diff options
| author | nsfisis <nsfisis@gmail.com> | 2024-08-01 21:08:31 +0900 |
|---|---|---|
| committer | nsfisis <nsfisis@gmail.com> | 2024-08-01 21:08:31 +0900 |
| commit | 6767acd3d9cc2cf5b778048ec6339b8c9123fbb5 (patch) | |
| tree | dc436435b043158275501eab4b79e5a64cf013d9 /backend | |
| parent | 5e6775c9c1efbbd3b08363ffda421a5996dc7143 (diff) | |
| download | iosdc-japan-2024-albatross-6767acd3d9cc2cf5b778048ec6339b8c9123fbb5.tar.gz iosdc-japan-2024-albatross-6767acd3d9cc2cf5b778048ec6339b8c9123fbb5.tar.zst iosdc-japan-2024-albatross-6767acd3d9cc2cf5b778048ec6339b8c9123fbb5.zip | |
refactor(backend): wrap ApiHandler with user authentication
Diffstat (limited to 'backend')
| -rw-r--r-- | backend/api/handler_wrapper.go | 134 | ||||
| -rw-r--r-- | backend/api/handlers.go | 93 | ||||
| -rw-r--r-- | backend/gen/api_handler_wrapper_gen.go | 164 | ||||
| -rw-r--r-- | backend/gen/gen.go | 1 | ||||
| -rw-r--r-- | backend/main.go | 4 |
5 files changed, 307 insertions, 89 deletions
diff --git a/backend/api/handler_wrapper.go b/backend/api/handler_wrapper.go new file mode 100644 index 0000000..37a199b --- /dev/null +++ b/backend/api/handler_wrapper.go @@ -0,0 +1,134 @@ +// Code generated by go generate; DO NOT EDIT. + +package api + +import ( + "context" + "errors" + "strings" + + "github.com/nsfisis/iosdc-japan-2024-albatross/backend/auth" + "github.com/nsfisis/iosdc-japan-2024-albatross/backend/db" +) + +var _ StrictServerInterface = (*ApiHandlerWrapper)(nil) + +type ApiHandlerWrapper struct { + innerHandler ApiHandler +} + +func NewHandler(queries *db.Queries, hubs GameHubsInterface) *ApiHandlerWrapper { + return &ApiHandlerWrapper{ + innerHandler: ApiHandler{ + q: queries, + hubs: hubs, + }, + } +} + +func parseJWTClaimsFromAuthorizationHeader(authorization string) (*auth.JWTClaims, error) { + const prefix = "Bearer " + if !strings.HasPrefix(authorization, prefix) { + return nil, errors.New("invalid authorization header") + } + token := authorization[len(prefix):] + claims, err := auth.ParseJWT(token) + if err != nil { + return nil, err + } + return claims, nil +} + +func (h *ApiHandlerWrapper) AdminGetGame(ctx context.Context, request AdminGetGameRequestObject) (AdminGetGameResponseObject, error) { + user, err := parseJWTClaimsFromAuthorizationHeader(request.Params.Authorization) + if err != nil { + return AdminGetGame401JSONResponse{ + Message: "Unauthorized", + }, nil + } + if !user.IsAdmin { + return AdminGetGame403JSONResponse{ + Message: "Forbidden", + }, nil + } + return h.innerHandler.AdminGetGame(ctx, request, user) +} + +func (h *ApiHandlerWrapper) AdminGetGames(ctx context.Context, request AdminGetGamesRequestObject) (AdminGetGamesResponseObject, error) { + user, err := parseJWTClaimsFromAuthorizationHeader(request.Params.Authorization) + if err != nil { + return AdminGetGames401JSONResponse{ + Message: "Unauthorized", + }, nil + } + if !user.IsAdmin { + return AdminGetGames403JSONResponse{ + Message: "Forbidden", + }, nil + } + return h.innerHandler.AdminGetGames(ctx, request, user) +} + +func (h *ApiHandlerWrapper) AdminGetUsers(ctx context.Context, request AdminGetUsersRequestObject) (AdminGetUsersResponseObject, error) { + user, err := parseJWTClaimsFromAuthorizationHeader(request.Params.Authorization) + if err != nil { + return AdminGetUsers401JSONResponse{ + Message: "Unauthorized", + }, nil + } + if !user.IsAdmin { + return AdminGetUsers403JSONResponse{ + Message: "Forbidden", + }, nil + } + return h.innerHandler.AdminGetUsers(ctx, request, user) +} + +func (h *ApiHandlerWrapper) AdminPutGame(ctx context.Context, request AdminPutGameRequestObject) (AdminPutGameResponseObject, error) { + user, err := parseJWTClaimsFromAuthorizationHeader(request.Params.Authorization) + if err != nil { + return AdminPutGame401JSONResponse{ + Message: "Unauthorized", + }, nil + } + if !user.IsAdmin { + return AdminPutGame403JSONResponse{ + Message: "Forbidden", + }, nil + } + return h.innerHandler.AdminPutGame(ctx, request, user) +} + +func (h *ApiHandlerWrapper) GetGame(ctx context.Context, request GetGameRequestObject) (GetGameResponseObject, error) { + user, err := parseJWTClaimsFromAuthorizationHeader(request.Params.Authorization) + if err != nil { + return GetGame401JSONResponse{ + Message: "Unauthorized", + }, nil + } + return h.innerHandler.GetGame(ctx, request, user) +} + +func (h *ApiHandlerWrapper) GetGames(ctx context.Context, request GetGamesRequestObject) (GetGamesResponseObject, error) { + user, err := parseJWTClaimsFromAuthorizationHeader(request.Params.Authorization) + if err != nil { + return GetGames401JSONResponse{ + Message: "Unauthorized", + }, nil + } + return h.innerHandler.GetGames(ctx, request, user) +} + +func (h *ApiHandlerWrapper) GetToken(ctx context.Context, request GetTokenRequestObject) (GetTokenResponseObject, error) { + user, err := parseJWTClaimsFromAuthorizationHeader(request.Params.Authorization) + if err != nil { + return GetToken401JSONResponse{ + Message: "Unauthorized", + }, nil + } + return h.innerHandler.GetToken(ctx, request, user) +} + +func (h *ApiHandlerWrapper) PostLogin(ctx context.Context, request PostLoginRequestObject) (PostLoginResponseObject, error) { + return h.innerHandler.PostLogin(ctx, request) +} diff --git a/backend/api/handlers.go b/backend/api/handlers.go index a250629..ea9ddea 100644 --- a/backend/api/handlers.go +++ b/backend/api/handlers.go @@ -4,7 +4,6 @@ import ( "context" "errors" "net/http" - "strings" "time" "github.com/jackc/pgx/v5" @@ -15,8 +14,6 @@ import ( "github.com/nsfisis/iosdc-japan-2024-albatross/backend/db" ) -var _ StrictServerInterface = (*ApiHandler)(nil) - type ApiHandler struct { q *db.Queries hubs GameHubsInterface @@ -26,20 +23,7 @@ type GameHubsInterface interface { StartGame(gameID int) error } -func NewHandler(queries *db.Queries, hubs GameHubsInterface) *ApiHandler { - return &ApiHandler{ - q: queries, - hubs: hubs, - } -} - -func (h *ApiHandler) AdminGetGames(ctx context.Context, request AdminGetGamesRequestObject) (AdminGetGamesResponseObject, error) { - user := ctx.Value("user").(*auth.JWTClaims) - if !user.IsAdmin { - return AdminGetGames403JSONResponse{ - Message: "Forbidden", - }, nil - } +func (h *ApiHandler) AdminGetGames(ctx context.Context, request AdminGetGamesRequestObject, user *auth.JWTClaims) (AdminGetGamesResponseObject, error) { gameRows, err := h.q.ListGames(ctx) if err != nil { return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) @@ -76,13 +60,7 @@ func (h *ApiHandler) AdminGetGames(ctx context.Context, request AdminGetGamesReq }, nil } -func (h *ApiHandler) AdminGetGame(ctx context.Context, request AdminGetGameRequestObject) (AdminGetGameResponseObject, error) { - user := ctx.Value("user").(*auth.JWTClaims) - if !user.IsAdmin { - return AdminGetGame403JSONResponse{ - Message: "Forbidden", - }, nil - } +func (h *ApiHandler) AdminGetGame(ctx context.Context, request AdminGetGameRequestObject, user *auth.JWTClaims) (AdminGetGameResponseObject, error) { gameId := request.GameId row, err := h.q.GetGameById(ctx, int32(gameId)) if err != nil { @@ -123,13 +101,7 @@ func (h *ApiHandler) AdminGetGame(ctx context.Context, request AdminGetGameReque }, nil } -func (h *ApiHandler) AdminPutGame(ctx context.Context, request AdminPutGameRequestObject) (AdminPutGameResponseObject, error) { - user := ctx.Value("user").(*auth.JWTClaims) - if !user.IsAdmin { - return AdminPutGame403JSONResponse{ - Message: "Forbidden", - }, nil - } +func (h *ApiHandler) AdminPutGame(ctx context.Context, request AdminPutGameRequestObject, user *auth.JWTClaims) (AdminPutGameResponseObject, error) { gameID := request.GameId displayName := request.Body.DisplayName durationSeconds := request.Body.DurationSeconds @@ -210,13 +182,7 @@ func (h *ApiHandler) AdminPutGame(ctx context.Context, request AdminPutGameReque return AdminPutGame204Response{}, nil } -func (h *ApiHandler) AdminGetUsers(ctx context.Context, request AdminGetUsersRequestObject) (AdminGetUsersResponseObject, error) { - user := ctx.Value("user").(*auth.JWTClaims) - if !user.IsAdmin { - return AdminGetUsers403JSONResponse{ - Message: "Forbidden", - }, nil - } +func (h *ApiHandler) AdminGetUsers(ctx context.Context, request AdminGetUsersRequestObject, user *auth.JWTClaims) (AdminGetUsersResponseObject, error) { users, err := h.q.ListUsers(ctx) if err != nil { return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) @@ -263,8 +229,7 @@ func (h *ApiHandler) PostLogin(ctx context.Context, request PostLoginRequestObje }, nil } -func (h *ApiHandler) GetToken(ctx context.Context, request GetTokenRequestObject) (GetTokenResponseObject, error) { - user := ctx.Value("user").(*auth.JWTClaims) +func (h *ApiHandler) GetToken(ctx context.Context, request GetTokenRequestObject, user *auth.JWTClaims) (GetTokenResponseObject, error) { newToken, err := auth.NewShortLivedJWT(user) if err != nil { return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) @@ -274,8 +239,7 @@ func (h *ApiHandler) GetToken(ctx context.Context, request GetTokenRequestObject }, nil } -func (h *ApiHandler) GetGames(ctx context.Context, request GetGamesRequestObject) (GetGamesResponseObject, error) { - user := ctx.Value("user").(*auth.JWTClaims) +func (h *ApiHandler) GetGames(ctx context.Context, request GetGamesRequestObject, user *auth.JWTClaims) (GetGamesResponseObject, error) { playerId := request.Params.PlayerId if !user.IsAdmin { if playerId == nil || *playerId != user.UserID { @@ -357,8 +321,7 @@ func (h *ApiHandler) GetGames(ctx context.Context, request GetGamesRequestObject } } -func (h *ApiHandler) GetGame(ctx context.Context, request GetGameRequestObject) (GetGameResponseObject, error) { - user := ctx.Value("user").(*auth.JWTClaims) +func (h *ApiHandler) GetGame(ctx context.Context, request GetGameRequestObject, user *auth.JWTClaims) (GetGameResponseObject, error) { // TODO: check user permission gameId := request.GameId row, err := h.q.GetGameById(ctx, int32(gameId)) @@ -401,45 +364,3 @@ func (h *ApiHandler) GetGame(ctx context.Context, request GetGameRequestObject) Game: game, }, nil } - -func _assertUserResponseIsCompatibleWithJWTClaims() { - var c auth.JWTClaims - var u User - u.UserId = c.UserID - u.Username = c.Username - u.DisplayName = c.DisplayName - u.IconPath = c.IconPath - u.IsAdmin = c.IsAdmin - _ = u -} - -func setupJWTFromAuthorizationHeader(c echo.Context) error { - authorization := c.Request().Header.Get("Authorization") - const prefix = "Bearer " - if !strings.HasPrefix(authorization, prefix) { - return echo.NewHTTPError(http.StatusUnauthorized) - } - token := authorization[len(prefix):] - claims, err := auth.ParseJWT(token) - if err != nil { - return echo.NewHTTPError(http.StatusUnauthorized, err.Error()) - } - c.SetRequest(c.Request().WithContext(context.WithValue(c.Request().Context(), "user", claims))) - return nil -} - -func NewJWTMiddleware() StrictMiddlewareFunc { - return func(handler StrictHandlerFunc, operationID string) StrictHandlerFunc { - if operationID == "PostLogin" { - return handler - } - - return func(c echo.Context, request interface{}) (interface{}, error) { - err := setupJWTFromAuthorizationHeader(c) - if err != nil { - return nil, echo.NewHTTPError(http.StatusUnauthorized, err.Error()) - } - return handler(c, request) - } - } -} diff --git a/backend/gen/api_handler_wrapper_gen.go b/backend/gen/api_handler_wrapper_gen.go new file mode 100644 index 0000000..01d05bf --- /dev/null +++ b/backend/gen/api_handler_wrapper_gen.go @@ -0,0 +1,164 @@ +package main + +import ( + "bytes" + "flag" + "go/ast" + "go/format" + "go/parser" + "go/token" + "os" + "slices" + "strings" + "text/template" +) + +func main() { + inputFile := flag.String("i", "", "input file") + outputFile := flag.String("o", "", "output file") + flag.Parse() + + if inputFile == nil || *inputFile == "" || outputFile == nil || *outputFile == "" { + flag.PrintDefaults() + os.Exit(1) + } + + // Parse the input file + fileSet := token.NewFileSet() + parsedFile, err := parser.ParseFile(fileSet, *inputFile, nil, parser.SkipObjectResolution) + if err != nil { + panic(err) + } + + // Find methods in StrictServerInterface + var methods []string + for _, decl := range parsedFile.Decls { + genDecl, ok := decl.(*ast.GenDecl) + if !ok { + continue + } + for _, spec := range genDecl.Specs { + typeSpec, ok := spec.(*ast.TypeSpec) + if !ok { + continue + } + if typeSpec.Name.Name != "StrictServerInterface" { + continue + } + interfaceType, ok := typeSpec.Type.(*ast.InterfaceType) + if !ok { + continue + } + for _, method := range interfaceType.Methods.List { + if len(method.Names) != 0 { + methods = append(methods, method.Names[0].Name) + } + } + } + } + if len(methods) == 0 { + panic("StrictServerInterface not found") + } + slices.Sort(methods) + + type TemplateParameter struct { + Name string + RequiresLogin bool + RequiresAdminRole bool + } + templateParameters := make([]TemplateParameter, len(methods)) + for i, method := range methods { + templateParameters[i] = TemplateParameter{ + Name: method, + RequiresLogin: method != "PostLogin", + RequiresAdminRole: strings.Contains(method, "Admin"), + } + } + + // Generate code. + tmpl, err := template.New("code").Parse(templateText) + if err != nil { + panic(err) + } + + var buf bytes.Buffer + err = tmpl.Execute(&buf, templateParameters) + if err != nil { + panic(err) + } + + formatted, err := format.Source(buf.Bytes()) + if err != nil { + panic(err) + } + + err = os.WriteFile(*outputFile, formatted, 0644) + if err != nil { + panic(err) + } +} + +const templateText = `// Code generated by go generate; DO NOT EDIT. + +package api + +import ( + "context" + "errors" + "strings" + + "github.com/nsfisis/iosdc-japan-2024-albatross/backend/auth" + "github.com/nsfisis/iosdc-japan-2024-albatross/backend/db" +) + +var _ StrictServerInterface = (*ApiHandlerWrapper)(nil) + +type ApiHandlerWrapper struct { + innerHandler ApiHandler +} + +func NewHandler(queries *db.Queries, hubs GameHubsInterface) *ApiHandlerWrapper { + return &ApiHandlerWrapper{ + innerHandler: ApiHandler{ + q: queries, + hubs: hubs, + }, + } +} + +func parseJWTClaimsFromAuthorizationHeader(authorization string) (*auth.JWTClaims, error) { + const prefix = "Bearer " + if !strings.HasPrefix(authorization, prefix) { + return nil, errors.New("invalid authorization header") + } + token := authorization[len(prefix):] + claims, err := auth.ParseJWT(token) + if err != nil { + return nil, err + } + return claims, nil +} + +{{ range . }} + func (h *ApiHandlerWrapper) {{ .Name }}(ctx context.Context, request {{ .Name }}RequestObject) ({{ .Name }}ResponseObject, error) { + {{ if .RequiresLogin -}} + user, err := parseJWTClaimsFromAuthorizationHeader(request.Params.Authorization) + if err != nil { + return {{ .Name }}401JSONResponse{ + Message: "Unauthorized", + }, nil + } + {{ if .RequiresAdminRole -}} + if !user.IsAdmin { + return {{ .Name }}403JSONResponse{ + Message: "Forbidden", + }, nil + } + {{ end -}} + return h.innerHandler.{{ .Name }}(ctx, request, user) + {{ else -}} + return h.innerHandler.{{ .Name }}(ctx, request) + {{ end -}} + } +{{ end }} +` diff --git a/backend/gen/gen.go b/backend/gen/gen.go index 6af912d..6fb430f 100644 --- a/backend/gen/gen.go +++ b/backend/gen/gen.go @@ -2,3 +2,4 @@ package main //go:generate go run github.com/sqlc-dev/sqlc/cmd/sqlc generate //go:generate go run github.com/oapi-codegen/oapi-codegen/v2/cmd/oapi-codegen -config oapi-codegen.yaml ../../openapi.yaml +//go:generate go run ./api_handler_wrapper_gen.go -i ../api/generated.go -o ../api/handler_wrapper.go diff --git a/backend/main.go b/backend/main.go index d636af7..0257113 100644 --- a/backend/main.go +++ b/backend/main.go @@ -75,9 +75,7 @@ func main() { apiGroup := e.Group("/api") apiGroup.Use(oapimiddleware.OapiRequestValidator(openApiSpec)) apiHandler := api.NewHandler(queries, gameHubs) - api.RegisterHandlers(apiGroup, api.NewStrictHandler(apiHandler, []api.StrictMiddlewareFunc{ - api.NewJWTMiddleware(), - })) + api.RegisterHandlers(apiGroup, api.NewStrictHandler(apiHandler, nil)) gameHubs.Run() |
