aboutsummaryrefslogtreecommitdiffhomepage
path: root/backend
diff options
context:
space:
mode:
authornsfisis <nsfisis@gmail.com>2024-08-01 21:08:31 +0900
committernsfisis <nsfisis@gmail.com>2024-08-01 21:08:31 +0900
commit6767acd3d9cc2cf5b778048ec6339b8c9123fbb5 (patch)
treedc436435b043158275501eab4b79e5a64cf013d9 /backend
parent5e6775c9c1efbbd3b08363ffda421a5996dc7143 (diff)
downloadiosdc-japan-2024-albatross-6767acd3d9cc2cf5b778048ec6339b8c9123fbb5.tar.gz
iosdc-japan-2024-albatross-6767acd3d9cc2cf5b778048ec6339b8c9123fbb5.tar.zst
iosdc-japan-2024-albatross-6767acd3d9cc2cf5b778048ec6339b8c9123fbb5.zip
refactor(backend): wrap ApiHandler with user authentication
Diffstat (limited to 'backend')
-rw-r--r--backend/api/handler_wrapper.go134
-rw-r--r--backend/api/handlers.go93
-rw-r--r--backend/gen/api_handler_wrapper_gen.go164
-rw-r--r--backend/gen/gen.go1
-rw-r--r--backend/main.go4
5 files changed, 307 insertions, 89 deletions
diff --git a/backend/api/handler_wrapper.go b/backend/api/handler_wrapper.go
new file mode 100644
index 0000000..37a199b
--- /dev/null
+++ b/backend/api/handler_wrapper.go
@@ -0,0 +1,134 @@
+// Code generated by go generate; DO NOT EDIT.
+
+package api
+
+import (
+ "context"
+ "errors"
+ "strings"
+
+ "github.com/nsfisis/iosdc-japan-2024-albatross/backend/auth"
+ "github.com/nsfisis/iosdc-japan-2024-albatross/backend/db"
+)
+
+var _ StrictServerInterface = (*ApiHandlerWrapper)(nil)
+
+type ApiHandlerWrapper struct {
+ innerHandler ApiHandler
+}
+
+func NewHandler(queries *db.Queries, hubs GameHubsInterface) *ApiHandlerWrapper {
+ return &ApiHandlerWrapper{
+ innerHandler: ApiHandler{
+ q: queries,
+ hubs: hubs,
+ },
+ }
+}
+
+func parseJWTClaimsFromAuthorizationHeader(authorization string) (*auth.JWTClaims, error) {
+ const prefix = "Bearer "
+ if !strings.HasPrefix(authorization, prefix) {
+ return nil, errors.New("invalid authorization header")
+ }
+ token := authorization[len(prefix):]
+ claims, err := auth.ParseJWT(token)
+ if err != nil {
+ return nil, err
+ }
+ return claims, nil
+}
+
+func (h *ApiHandlerWrapper) AdminGetGame(ctx context.Context, request AdminGetGameRequestObject) (AdminGetGameResponseObject, error) {
+ user, err := parseJWTClaimsFromAuthorizationHeader(request.Params.Authorization)
+ if err != nil {
+ return AdminGetGame401JSONResponse{
+ Message: "Unauthorized",
+ }, nil
+ }
+ if !user.IsAdmin {
+ return AdminGetGame403JSONResponse{
+ Message: "Forbidden",
+ }, nil
+ }
+ return h.innerHandler.AdminGetGame(ctx, request, user)
+}
+
+func (h *ApiHandlerWrapper) AdminGetGames(ctx context.Context, request AdminGetGamesRequestObject) (AdminGetGamesResponseObject, error) {
+ user, err := parseJWTClaimsFromAuthorizationHeader(request.Params.Authorization)
+ if err != nil {
+ return AdminGetGames401JSONResponse{
+ Message: "Unauthorized",
+ }, nil
+ }
+ if !user.IsAdmin {
+ return AdminGetGames403JSONResponse{
+ Message: "Forbidden",
+ }, nil
+ }
+ return h.innerHandler.AdminGetGames(ctx, request, user)
+}
+
+func (h *ApiHandlerWrapper) AdminGetUsers(ctx context.Context, request AdminGetUsersRequestObject) (AdminGetUsersResponseObject, error) {
+ user, err := parseJWTClaimsFromAuthorizationHeader(request.Params.Authorization)
+ if err != nil {
+ return AdminGetUsers401JSONResponse{
+ Message: "Unauthorized",
+ }, nil
+ }
+ if !user.IsAdmin {
+ return AdminGetUsers403JSONResponse{
+ Message: "Forbidden",
+ }, nil
+ }
+ return h.innerHandler.AdminGetUsers(ctx, request, user)
+}
+
+func (h *ApiHandlerWrapper) AdminPutGame(ctx context.Context, request AdminPutGameRequestObject) (AdminPutGameResponseObject, error) {
+ user, err := parseJWTClaimsFromAuthorizationHeader(request.Params.Authorization)
+ if err != nil {
+ return AdminPutGame401JSONResponse{
+ Message: "Unauthorized",
+ }, nil
+ }
+ if !user.IsAdmin {
+ return AdminPutGame403JSONResponse{
+ Message: "Forbidden",
+ }, nil
+ }
+ return h.innerHandler.AdminPutGame(ctx, request, user)
+}
+
+func (h *ApiHandlerWrapper) GetGame(ctx context.Context, request GetGameRequestObject) (GetGameResponseObject, error) {
+ user, err := parseJWTClaimsFromAuthorizationHeader(request.Params.Authorization)
+ if err != nil {
+ return GetGame401JSONResponse{
+ Message: "Unauthorized",
+ }, nil
+ }
+ return h.innerHandler.GetGame(ctx, request, user)
+}
+
+func (h *ApiHandlerWrapper) GetGames(ctx context.Context, request GetGamesRequestObject) (GetGamesResponseObject, error) {
+ user, err := parseJWTClaimsFromAuthorizationHeader(request.Params.Authorization)
+ if err != nil {
+ return GetGames401JSONResponse{
+ Message: "Unauthorized",
+ }, nil
+ }
+ return h.innerHandler.GetGames(ctx, request, user)
+}
+
+func (h *ApiHandlerWrapper) GetToken(ctx context.Context, request GetTokenRequestObject) (GetTokenResponseObject, error) {
+ user, err := parseJWTClaimsFromAuthorizationHeader(request.Params.Authorization)
+ if err != nil {
+ return GetToken401JSONResponse{
+ Message: "Unauthorized",
+ }, nil
+ }
+ return h.innerHandler.GetToken(ctx, request, user)
+}
+
+func (h *ApiHandlerWrapper) PostLogin(ctx context.Context, request PostLoginRequestObject) (PostLoginResponseObject, error) {
+ return h.innerHandler.PostLogin(ctx, request)
+}
diff --git a/backend/api/handlers.go b/backend/api/handlers.go
index a250629..ea9ddea 100644
--- a/backend/api/handlers.go
+++ b/backend/api/handlers.go
@@ -4,7 +4,6 @@ import (
"context"
"errors"
"net/http"
- "strings"
"time"
"github.com/jackc/pgx/v5"
@@ -15,8 +14,6 @@ import (
"github.com/nsfisis/iosdc-japan-2024-albatross/backend/db"
)
-var _ StrictServerInterface = (*ApiHandler)(nil)
-
type ApiHandler struct {
q *db.Queries
hubs GameHubsInterface
@@ -26,20 +23,7 @@ type GameHubsInterface interface {
StartGame(gameID int) error
}
-func NewHandler(queries *db.Queries, hubs GameHubsInterface) *ApiHandler {
- return &ApiHandler{
- q: queries,
- hubs: hubs,
- }
-}
-
-func (h *ApiHandler) AdminGetGames(ctx context.Context, request AdminGetGamesRequestObject) (AdminGetGamesResponseObject, error) {
- user := ctx.Value("user").(*auth.JWTClaims)
- if !user.IsAdmin {
- return AdminGetGames403JSONResponse{
- Message: "Forbidden",
- }, nil
- }
+func (h *ApiHandler) AdminGetGames(ctx context.Context, request AdminGetGamesRequestObject, user *auth.JWTClaims) (AdminGetGamesResponseObject, error) {
gameRows, err := h.q.ListGames(ctx)
if err != nil {
return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
@@ -76,13 +60,7 @@ func (h *ApiHandler) AdminGetGames(ctx context.Context, request AdminGetGamesReq
}, nil
}
-func (h *ApiHandler) AdminGetGame(ctx context.Context, request AdminGetGameRequestObject) (AdminGetGameResponseObject, error) {
- user := ctx.Value("user").(*auth.JWTClaims)
- if !user.IsAdmin {
- return AdminGetGame403JSONResponse{
- Message: "Forbidden",
- }, nil
- }
+func (h *ApiHandler) AdminGetGame(ctx context.Context, request AdminGetGameRequestObject, user *auth.JWTClaims) (AdminGetGameResponseObject, error) {
gameId := request.GameId
row, err := h.q.GetGameById(ctx, int32(gameId))
if err != nil {
@@ -123,13 +101,7 @@ func (h *ApiHandler) AdminGetGame(ctx context.Context, request AdminGetGameReque
}, nil
}
-func (h *ApiHandler) AdminPutGame(ctx context.Context, request AdminPutGameRequestObject) (AdminPutGameResponseObject, error) {
- user := ctx.Value("user").(*auth.JWTClaims)
- if !user.IsAdmin {
- return AdminPutGame403JSONResponse{
- Message: "Forbidden",
- }, nil
- }
+func (h *ApiHandler) AdminPutGame(ctx context.Context, request AdminPutGameRequestObject, user *auth.JWTClaims) (AdminPutGameResponseObject, error) {
gameID := request.GameId
displayName := request.Body.DisplayName
durationSeconds := request.Body.DurationSeconds
@@ -210,13 +182,7 @@ func (h *ApiHandler) AdminPutGame(ctx context.Context, request AdminPutGameReque
return AdminPutGame204Response{}, nil
}
-func (h *ApiHandler) AdminGetUsers(ctx context.Context, request AdminGetUsersRequestObject) (AdminGetUsersResponseObject, error) {
- user := ctx.Value("user").(*auth.JWTClaims)
- if !user.IsAdmin {
- return AdminGetUsers403JSONResponse{
- Message: "Forbidden",
- }, nil
- }
+func (h *ApiHandler) AdminGetUsers(ctx context.Context, request AdminGetUsersRequestObject, user *auth.JWTClaims) (AdminGetUsersResponseObject, error) {
users, err := h.q.ListUsers(ctx)
if err != nil {
return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
@@ -263,8 +229,7 @@ func (h *ApiHandler) PostLogin(ctx context.Context, request PostLoginRequestObje
}, nil
}
-func (h *ApiHandler) GetToken(ctx context.Context, request GetTokenRequestObject) (GetTokenResponseObject, error) {
- user := ctx.Value("user").(*auth.JWTClaims)
+func (h *ApiHandler) GetToken(ctx context.Context, request GetTokenRequestObject, user *auth.JWTClaims) (GetTokenResponseObject, error) {
newToken, err := auth.NewShortLivedJWT(user)
if err != nil {
return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
@@ -274,8 +239,7 @@ func (h *ApiHandler) GetToken(ctx context.Context, request GetTokenRequestObject
}, nil
}
-func (h *ApiHandler) GetGames(ctx context.Context, request GetGamesRequestObject) (GetGamesResponseObject, error) {
- user := ctx.Value("user").(*auth.JWTClaims)
+func (h *ApiHandler) GetGames(ctx context.Context, request GetGamesRequestObject, user *auth.JWTClaims) (GetGamesResponseObject, error) {
playerId := request.Params.PlayerId
if !user.IsAdmin {
if playerId == nil || *playerId != user.UserID {
@@ -357,8 +321,7 @@ func (h *ApiHandler) GetGames(ctx context.Context, request GetGamesRequestObject
}
}
-func (h *ApiHandler) GetGame(ctx context.Context, request GetGameRequestObject) (GetGameResponseObject, error) {
- user := ctx.Value("user").(*auth.JWTClaims)
+func (h *ApiHandler) GetGame(ctx context.Context, request GetGameRequestObject, user *auth.JWTClaims) (GetGameResponseObject, error) {
// TODO: check user permission
gameId := request.GameId
row, err := h.q.GetGameById(ctx, int32(gameId))
@@ -401,45 +364,3 @@ func (h *ApiHandler) GetGame(ctx context.Context, request GetGameRequestObject)
Game: game,
}, nil
}
-
-func _assertUserResponseIsCompatibleWithJWTClaims() {
- var c auth.JWTClaims
- var u User
- u.UserId = c.UserID
- u.Username = c.Username
- u.DisplayName = c.DisplayName
- u.IconPath = c.IconPath
- u.IsAdmin = c.IsAdmin
- _ = u
-}
-
-func setupJWTFromAuthorizationHeader(c echo.Context) error {
- authorization := c.Request().Header.Get("Authorization")
- const prefix = "Bearer "
- if !strings.HasPrefix(authorization, prefix) {
- return echo.NewHTTPError(http.StatusUnauthorized)
- }
- token := authorization[len(prefix):]
- claims, err := auth.ParseJWT(token)
- if err != nil {
- return echo.NewHTTPError(http.StatusUnauthorized, err.Error())
- }
- c.SetRequest(c.Request().WithContext(context.WithValue(c.Request().Context(), "user", claims)))
- return nil
-}
-
-func NewJWTMiddleware() StrictMiddlewareFunc {
- return func(handler StrictHandlerFunc, operationID string) StrictHandlerFunc {
- if operationID == "PostLogin" {
- return handler
- }
-
- return func(c echo.Context, request interface{}) (interface{}, error) {
- err := setupJWTFromAuthorizationHeader(c)
- if err != nil {
- return nil, echo.NewHTTPError(http.StatusUnauthorized, err.Error())
- }
- return handler(c, request)
- }
- }
-}
diff --git a/backend/gen/api_handler_wrapper_gen.go b/backend/gen/api_handler_wrapper_gen.go
new file mode 100644
index 0000000..01d05bf
--- /dev/null
+++ b/backend/gen/api_handler_wrapper_gen.go
@@ -0,0 +1,164 @@
+package main
+
+import (
+ "bytes"
+ "flag"
+ "go/ast"
+ "go/format"
+ "go/parser"
+ "go/token"
+ "os"
+ "slices"
+ "strings"
+ "text/template"
+)
+
+func main() {
+ inputFile := flag.String("i", "", "input file")
+ outputFile := flag.String("o", "", "output file")
+ flag.Parse()
+
+ if inputFile == nil || *inputFile == "" || outputFile == nil || *outputFile == "" {
+ flag.PrintDefaults()
+ os.Exit(1)
+ }
+
+ // Parse the input file
+ fileSet := token.NewFileSet()
+ parsedFile, err := parser.ParseFile(fileSet, *inputFile, nil, parser.SkipObjectResolution)
+ if err != nil {
+ panic(err)
+ }
+
+ // Find methods in StrictServerInterface
+ var methods []string
+ for _, decl := range parsedFile.Decls {
+ genDecl, ok := decl.(*ast.GenDecl)
+ if !ok {
+ continue
+ }
+ for _, spec := range genDecl.Specs {
+ typeSpec, ok := spec.(*ast.TypeSpec)
+ if !ok {
+ continue
+ }
+ if typeSpec.Name.Name != "StrictServerInterface" {
+ continue
+ }
+ interfaceType, ok := typeSpec.Type.(*ast.InterfaceType)
+ if !ok {
+ continue
+ }
+ for _, method := range interfaceType.Methods.List {
+ if len(method.Names) != 0 {
+ methods = append(methods, method.Names[0].Name)
+ }
+ }
+ }
+ }
+ if len(methods) == 0 {
+ panic("StrictServerInterface not found")
+ }
+ slices.Sort(methods)
+
+ type TemplateParameter struct {
+ Name string
+ RequiresLogin bool
+ RequiresAdminRole bool
+ }
+ templateParameters := make([]TemplateParameter, len(methods))
+ for i, method := range methods {
+ templateParameters[i] = TemplateParameter{
+ Name: method,
+ RequiresLogin: method != "PostLogin",
+ RequiresAdminRole: strings.Contains(method, "Admin"),
+ }
+ }
+
+ // Generate code.
+ tmpl, err := template.New("code").Parse(templateText)
+ if err != nil {
+ panic(err)
+ }
+
+ var buf bytes.Buffer
+ err = tmpl.Execute(&buf, templateParameters)
+ if err != nil {
+ panic(err)
+ }
+
+ formatted, err := format.Source(buf.Bytes())
+ if err != nil {
+ panic(err)
+ }
+
+ err = os.WriteFile(*outputFile, formatted, 0644)
+ if err != nil {
+ panic(err)
+ }
+}
+
+const templateText = `// Code generated by go generate; DO NOT EDIT.
+
+package api
+
+import (
+ "context"
+ "errors"
+ "strings"
+
+ "github.com/nsfisis/iosdc-japan-2024-albatross/backend/auth"
+ "github.com/nsfisis/iosdc-japan-2024-albatross/backend/db"
+)
+
+var _ StrictServerInterface = (*ApiHandlerWrapper)(nil)
+
+type ApiHandlerWrapper struct {
+ innerHandler ApiHandler
+}
+
+func NewHandler(queries *db.Queries, hubs GameHubsInterface) *ApiHandlerWrapper {
+ return &ApiHandlerWrapper{
+ innerHandler: ApiHandler{
+ q: queries,
+ hubs: hubs,
+ },
+ }
+}
+
+func parseJWTClaimsFromAuthorizationHeader(authorization string) (*auth.JWTClaims, error) {
+ const prefix = "Bearer "
+ if !strings.HasPrefix(authorization, prefix) {
+ return nil, errors.New("invalid authorization header")
+ }
+ token := authorization[len(prefix):]
+ claims, err := auth.ParseJWT(token)
+ if err != nil {
+ return nil, err
+ }
+ return claims, nil
+}
+
+{{ range . }}
+ func (h *ApiHandlerWrapper) {{ .Name }}(ctx context.Context, request {{ .Name }}RequestObject) ({{ .Name }}ResponseObject, error) {
+ {{ if .RequiresLogin -}}
+ user, err := parseJWTClaimsFromAuthorizationHeader(request.Params.Authorization)
+ if err != nil {
+ return {{ .Name }}401JSONResponse{
+ Message: "Unauthorized",
+ }, nil
+ }
+ {{ if .RequiresAdminRole -}}
+ if !user.IsAdmin {
+ return {{ .Name }}403JSONResponse{
+ Message: "Forbidden",
+ }, nil
+ }
+ {{ end -}}
+ return h.innerHandler.{{ .Name }}(ctx, request, user)
+ {{ else -}}
+ return h.innerHandler.{{ .Name }}(ctx, request)
+ {{ end -}}
+ }
+{{ end }}
+`
diff --git a/backend/gen/gen.go b/backend/gen/gen.go
index 6af912d..6fb430f 100644
--- a/backend/gen/gen.go
+++ b/backend/gen/gen.go
@@ -2,3 +2,4 @@ package main
//go:generate go run github.com/sqlc-dev/sqlc/cmd/sqlc generate
//go:generate go run github.com/oapi-codegen/oapi-codegen/v2/cmd/oapi-codegen -config oapi-codegen.yaml ../../openapi.yaml
+//go:generate go run ./api_handler_wrapper_gen.go -i ../api/generated.go -o ../api/handler_wrapper.go
diff --git a/backend/main.go b/backend/main.go
index d636af7..0257113 100644
--- a/backend/main.go
+++ b/backend/main.go
@@ -75,9 +75,7 @@ func main() {
apiGroup := e.Group("/api")
apiGroup.Use(oapimiddleware.OapiRequestValidator(openApiSpec))
apiHandler := api.NewHandler(queries, gameHubs)
- api.RegisterHandlers(apiGroup, api.NewStrictHandler(apiHandler, []api.StrictMiddlewareFunc{
- api.NewJWTMiddleware(),
- }))
+ api.RegisterHandlers(apiGroup, api.NewStrictHandler(apiHandler, nil))
gameHubs.Run()