diff options
Diffstat (limited to 'backend/gen/api_handler_wrapper_gen.go')
| -rw-r--r-- | backend/gen/api_handler_wrapper_gen.go | 164 |
1 files changed, 164 insertions, 0 deletions
diff --git a/backend/gen/api_handler_wrapper_gen.go b/backend/gen/api_handler_wrapper_gen.go new file mode 100644 index 0000000..01d05bf --- /dev/null +++ b/backend/gen/api_handler_wrapper_gen.go @@ -0,0 +1,164 @@ +package main + +import ( + "bytes" + "flag" + "go/ast" + "go/format" + "go/parser" + "go/token" + "os" + "slices" + "strings" + "text/template" +) + +func main() { + inputFile := flag.String("i", "", "input file") + outputFile := flag.String("o", "", "output file") + flag.Parse() + + if inputFile == nil || *inputFile == "" || outputFile == nil || *outputFile == "" { + flag.PrintDefaults() + os.Exit(1) + } + + // Parse the input file + fileSet := token.NewFileSet() + parsedFile, err := parser.ParseFile(fileSet, *inputFile, nil, parser.SkipObjectResolution) + if err != nil { + panic(err) + } + + // Find methods in StrictServerInterface + var methods []string + for _, decl := range parsedFile.Decls { + genDecl, ok := decl.(*ast.GenDecl) + if !ok { + continue + } + for _, spec := range genDecl.Specs { + typeSpec, ok := spec.(*ast.TypeSpec) + if !ok { + continue + } + if typeSpec.Name.Name != "StrictServerInterface" { + continue + } + interfaceType, ok := typeSpec.Type.(*ast.InterfaceType) + if !ok { + continue + } + for _, method := range interfaceType.Methods.List { + if len(method.Names) != 0 { + methods = append(methods, method.Names[0].Name) + } + } + } + } + if len(methods) == 0 { + panic("StrictServerInterface not found") + } + slices.Sort(methods) + + type TemplateParameter struct { + Name string + RequiresLogin bool + RequiresAdminRole bool + } + templateParameters := make([]TemplateParameter, len(methods)) + for i, method := range methods { + templateParameters[i] = TemplateParameter{ + Name: method, + RequiresLogin: method != "PostLogin", + RequiresAdminRole: strings.Contains(method, "Admin"), + } + } + + // Generate code. + tmpl, err := template.New("code").Parse(templateText) + if err != nil { + panic(err) + } + + var buf bytes.Buffer + err = tmpl.Execute(&buf, templateParameters) + if err != nil { + panic(err) + } + + formatted, err := format.Source(buf.Bytes()) + if err != nil { + panic(err) + } + + err = os.WriteFile(*outputFile, formatted, 0644) + if err != nil { + panic(err) + } +} + +const templateText = `// Code generated by go generate; DO NOT EDIT. + +package api + +import ( + "context" + "errors" + "strings" + + "github.com/nsfisis/iosdc-japan-2024-albatross/backend/auth" + "github.com/nsfisis/iosdc-japan-2024-albatross/backend/db" +) + +var _ StrictServerInterface = (*ApiHandlerWrapper)(nil) + +type ApiHandlerWrapper struct { + innerHandler ApiHandler +} + +func NewHandler(queries *db.Queries, hubs GameHubsInterface) *ApiHandlerWrapper { + return &ApiHandlerWrapper{ + innerHandler: ApiHandler{ + q: queries, + hubs: hubs, + }, + } +} + +func parseJWTClaimsFromAuthorizationHeader(authorization string) (*auth.JWTClaims, error) { + const prefix = "Bearer " + if !strings.HasPrefix(authorization, prefix) { + return nil, errors.New("invalid authorization header") + } + token := authorization[len(prefix):] + claims, err := auth.ParseJWT(token) + if err != nil { + return nil, err + } + return claims, nil +} + +{{ range . }} + func (h *ApiHandlerWrapper) {{ .Name }}(ctx context.Context, request {{ .Name }}RequestObject) ({{ .Name }}ResponseObject, error) { + {{ if .RequiresLogin -}} + user, err := parseJWTClaimsFromAuthorizationHeader(request.Params.Authorization) + if err != nil { + return {{ .Name }}401JSONResponse{ + Message: "Unauthorized", + }, nil + } + {{ if .RequiresAdminRole -}} + if !user.IsAdmin { + return {{ .Name }}403JSONResponse{ + Message: "Forbidden", + }, nil + } + {{ end -}} + return h.innerHandler.{{ .Name }}(ctx, request, user) + {{ else -}} + return h.innerHandler.{{ .Name }}(ctx, request) + {{ end -}} + } +{{ end }} +` |
