aboutsummaryrefslogtreecommitdiffhomepage
path: root/backend/gen/api_handler_wrapper_gen.go
diff options
context:
space:
mode:
Diffstat (limited to 'backend/gen/api_handler_wrapper_gen.go')
-rw-r--r--backend/gen/api_handler_wrapper_gen.go164
1 files changed, 164 insertions, 0 deletions
diff --git a/backend/gen/api_handler_wrapper_gen.go b/backend/gen/api_handler_wrapper_gen.go
new file mode 100644
index 0000000..01d05bf
--- /dev/null
+++ b/backend/gen/api_handler_wrapper_gen.go
@@ -0,0 +1,164 @@
+package main
+
+import (
+ "bytes"
+ "flag"
+ "go/ast"
+ "go/format"
+ "go/parser"
+ "go/token"
+ "os"
+ "slices"
+ "strings"
+ "text/template"
+)
+
+func main() {
+ inputFile := flag.String("i", "", "input file")
+ outputFile := flag.String("o", "", "output file")
+ flag.Parse()
+
+ if inputFile == nil || *inputFile == "" || outputFile == nil || *outputFile == "" {
+ flag.PrintDefaults()
+ os.Exit(1)
+ }
+
+ // Parse the input file
+ fileSet := token.NewFileSet()
+ parsedFile, err := parser.ParseFile(fileSet, *inputFile, nil, parser.SkipObjectResolution)
+ if err != nil {
+ panic(err)
+ }
+
+ // Find methods in StrictServerInterface
+ var methods []string
+ for _, decl := range parsedFile.Decls {
+ genDecl, ok := decl.(*ast.GenDecl)
+ if !ok {
+ continue
+ }
+ for _, spec := range genDecl.Specs {
+ typeSpec, ok := spec.(*ast.TypeSpec)
+ if !ok {
+ continue
+ }
+ if typeSpec.Name.Name != "StrictServerInterface" {
+ continue
+ }
+ interfaceType, ok := typeSpec.Type.(*ast.InterfaceType)
+ if !ok {
+ continue
+ }
+ for _, method := range interfaceType.Methods.List {
+ if len(method.Names) != 0 {
+ methods = append(methods, method.Names[0].Name)
+ }
+ }
+ }
+ }
+ if len(methods) == 0 {
+ panic("StrictServerInterface not found")
+ }
+ slices.Sort(methods)
+
+ type TemplateParameter struct {
+ Name string
+ RequiresLogin bool
+ RequiresAdminRole bool
+ }
+ templateParameters := make([]TemplateParameter, len(methods))
+ for i, method := range methods {
+ templateParameters[i] = TemplateParameter{
+ Name: method,
+ RequiresLogin: method != "PostLogin",
+ RequiresAdminRole: strings.Contains(method, "Admin"),
+ }
+ }
+
+ // Generate code.
+ tmpl, err := template.New("code").Parse(templateText)
+ if err != nil {
+ panic(err)
+ }
+
+ var buf bytes.Buffer
+ err = tmpl.Execute(&buf, templateParameters)
+ if err != nil {
+ panic(err)
+ }
+
+ formatted, err := format.Source(buf.Bytes())
+ if err != nil {
+ panic(err)
+ }
+
+ err = os.WriteFile(*outputFile, formatted, 0644)
+ if err != nil {
+ panic(err)
+ }
+}
+
+const templateText = `// Code generated by go generate; DO NOT EDIT.
+
+package api
+
+import (
+ "context"
+ "errors"
+ "strings"
+
+ "github.com/nsfisis/iosdc-japan-2024-albatross/backend/auth"
+ "github.com/nsfisis/iosdc-japan-2024-albatross/backend/db"
+)
+
+var _ StrictServerInterface = (*ApiHandlerWrapper)(nil)
+
+type ApiHandlerWrapper struct {
+ innerHandler ApiHandler
+}
+
+func NewHandler(queries *db.Queries, hubs GameHubsInterface) *ApiHandlerWrapper {
+ return &ApiHandlerWrapper{
+ innerHandler: ApiHandler{
+ q: queries,
+ hubs: hubs,
+ },
+ }
+}
+
+func parseJWTClaimsFromAuthorizationHeader(authorization string) (*auth.JWTClaims, error) {
+ const prefix = "Bearer "
+ if !strings.HasPrefix(authorization, prefix) {
+ return nil, errors.New("invalid authorization header")
+ }
+ token := authorization[len(prefix):]
+ claims, err := auth.ParseJWT(token)
+ if err != nil {
+ return nil, err
+ }
+ return claims, nil
+}
+
+{{ range . }}
+ func (h *ApiHandlerWrapper) {{ .Name }}(ctx context.Context, request {{ .Name }}RequestObject) ({{ .Name }}ResponseObject, error) {
+ {{ if .RequiresLogin -}}
+ user, err := parseJWTClaimsFromAuthorizationHeader(request.Params.Authorization)
+ if err != nil {
+ return {{ .Name }}401JSONResponse{
+ Message: "Unauthorized",
+ }, nil
+ }
+ {{ if .RequiresAdminRole -}}
+ if !user.IsAdmin {
+ return {{ .Name }}403JSONResponse{
+ Message: "Forbidden",
+ }, nil
+ }
+ {{ end -}}
+ return h.innerHandler.{{ .Name }}(ctx, request, user)
+ {{ else -}}
+ return h.innerHandler.{{ .Name }}(ctx, request)
+ {{ end -}}
+ }
+{{ end }}
+`