// This program is a code generator. It parses a Go source file, collects the
// suffixes of its TaskType* constants and writes a Go file containing one
// processTask<Suffix> wrapper method per constant.
//
// Usage:
//
//	<program> -i <input file> -o <output file>
package main

import (
	"bytes"
	"flag"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"os"
	"slices"
	"strings"
	"text/template"
)

// taskTypePrefix is the identifier prefix that marks a task-type constant.
// The bare identifier "TaskType" itself is not treated as a constant name.
const taskTypePrefix = "TaskType"

// main parses the -i/-o flags and runs the generator. It exits with status 1
// when a flag is missing or when generation fails.
func main() {
	inputFile := flag.String("i", "", "input file")
	outputFile := flag.String("o", "", "output file")
	flag.Parse()

	// flag.String never returns nil, so only the values need checking.
	if *inputFile == "" || *outputFile == "" {
		flag.Usage()
		os.Exit(1)
	}

	if err := run(*inputFile, *outputFile); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
}

// run parses inputPath, renders the wrapper code for every TaskType* constant
// found in it and writes the gofmt-formatted result to outputPath.
func run(inputPath, outputPath string) error {
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, inputPath, nil, parser.SkipObjectResolution)
	if err != nil {
		return fmt.Errorf("parsing input: %w", err)
	}

	names := taskTypeNames(file)
	if len(names) == 0 {
		return fmt.Errorf("%s: TaskType* constants not found", inputPath)
	}

	src, err := render(names)
	if err != nil {
		return err
	}

	if err := os.WriteFile(outputPath, src, 0o644); err != nil {
		return fmt.Errorf("writing output: %w", err)
	}
	return nil
}

// taskTypeNames returns the sorted suffixes of all constants in file whose
// name starts with "TaskType" (for example "TaskTypeFoo" yields "Foo").
//
// Only const declarations are inspected: a var such as TaskTypeNames shares
// the prefix but is not a task type, and must not produce a wrapper method.
func taskTypeNames(file *ast.File) []string {
	var names []string
	for _, decl := range file.Decls {
		gen, ok := decl.(*ast.GenDecl)
		if !ok || gen.Tok != token.CONST {
			continue
		}
		for _, spec := range gen.Specs {
			vs, ok := spec.(*ast.ValueSpec)
			if !ok {
				continue
			}
			for _, ident := range vs.Names {
				suffix, found := strings.CutPrefix(ident.Name, taskTypePrefix)
				if !found || suffix == "" {
					continue
				}
				names = append(names, suffix)
			}
		}
	}

	// Sort so the generated file is stable regardless of declaration order.
	slices.Sort(names)
	return names
}

// render executes templateText with names as its data and returns the result
// formatted by go/format.
func render(names []string) ([]byte, error) {
	tmpl, err := template.New("code").Parse(templateText)
	if err != nil {
		return nil, fmt.Errorf("parsing template: %w", err)
	}

	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, names); err != nil {
		return nil, fmt.Errorf("executing template: %w", err)
	}

	formatted, err := format.Source(buf.Bytes())
	if err != nil {
		return nil, fmt.Errorf("formatting generated code: %w", err)
	}
	return formatted, nil
}

// templateText is the text/template source of the generated file. It is
// executed with a []string of constant suffixes; each element produces one
// processTask<Suffix> method on processorWrapper that references the
// TaskPayload<Suffix>, TaskResult<Suffix> and doProcessTask<Suffix>
// identifiers, which this generator does not define.
const templateText = `// Code generated by go generate; DO NOT EDIT.
package taskqueue

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"

	"github.com/hibiken/asynq"
)

type processorWrapper struct {
	impl    processor
	results chan TaskResult
}

func newProcessorWrapper(impl processor) *processorWrapper {
	return &processorWrapper{
		impl:    impl,
		results: make(chan TaskResult),
	}
}

{{ range . }}
func (p *processorWrapper) processTask{{ . }}(ctx context.Context, t *asynq.Task) error {
	var payload TaskPayload{{ . }}
	if err := json.Unmarshal(t.Payload(), &payload); err != nil {
		err := fmt.Errorf("json.Unmarshal failed: %v: %w", err, asynq.SkipRetry)
		p.results <- &TaskResult{{ . }}{Err: err}
		return err
	}

	result, err := p.impl.doProcessTask{{ . }}(ctx, &payload)
	if err != nil {
		retryCount, _ := asynq.GetRetryCount(ctx)
		maxRetry, _ := asynq.GetMaxRetry(ctx)
		isRecoverable := !errors.Is(err, asynq.SkipRetry) && retryCount < maxRetry
		if !isRecoverable {
			p.results <- &TaskResult{{ . }}{Err: err}
		}
		return err
	}

	p.results <- result
	return nil
}
{{ end }}
`