diff options
| author | nsfisis <nsfisis@gmail.com> | 2026-02-15 11:12:50 +0900 |
|---|---|---|
| committer | nsfisis <nsfisis@gmail.com> | 2026-02-15 11:14:28 +0900 |
| commit | 96fad1a4e78c7209e5a0f3496e8b59d591fbe500 (patch) | |
| tree | 8e43fb3918cd7401fe68cac933fe943c794b7634 /backend/api | |
| parent | 2f1a8a1c599300d0964d7fbbfd824e2d74f0bf4a (diff) | |
| download | phperkaigi-2026-albatross-96fad1a4e78c7209e5a0f3496e8b59d591fbe500.tar.gz phperkaigi-2026-albatross-96fad1a4e78c7209e5a0f3496e8b59d591fbe500.tar.zst phperkaigi-2026-albatross-96fad1a4e78c7209e5a0f3496e8b59d591fbe500.zip | |
refactor(auth): replace JWT authentication with server-side sessions
Migrate from stateless JWT tokens to server-side session management
backed by PostgreSQL. Sessions are hashed with SHA-256 before storage,
cleaned up periodically, and invalidated on logout. This removes the
need for JWT_SECRET/COOKIE_SECRET environment variables and the
golang-jwt dependency.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'backend/api')
| -rw-r--r-- | backend/api/auth_middleware.go | 42 | ||||
| -rw-r--r-- | backend/api/handler.go | 76 | ||||
| -rw-r--r-- | backend/api/handler_wrapper.go | 20 |
3 files changed, 78 insertions, 60 deletions
diff --git a/backend/api/auth_middleware.go b/backend/api/auth_middleware.go index 97f8946..d721f1d 100644 --- a/backend/api/auth_middleware.go +++ b/backend/api/auth_middleware.go @@ -6,27 +6,39 @@ import ( "github.com/labstack/echo/v4" "albatross-2026-backend/auth" + "albatross-2026-backend/db" ) -type contextKey struct{} +type sessionIDContextKey struct{} +type userContextKey struct{} -func JWTCookieMiddleware(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - cookie, err := c.Cookie("albatross_token") - if err != nil { +func SessionCookieMiddleware(q *db.Queries) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + cookie, err := c.Cookie("albatross_session") + if err != nil { + return next(c) + } + hashedID := auth.HashSessionID(cookie.Value) + user, err := q.GetUserBySession(c.Request().Context(), hashedID) + if err != nil { + return next(c) + } + ctx := c.Request().Context() + ctx = context.WithValue(ctx, sessionIDContextKey{}, hashedID) + ctx = context.WithValue(ctx, userContextKey{}, &user) + c.SetRequest(c.Request().WithContext(ctx)) return next(c) } - claims, err := auth.ParseJWT(cookie.Value) - if err != nil { - return next(c) - } - ctx := context.WithValue(c.Request().Context(), contextKey{}, claims) - c.SetRequest(c.Request().WithContext(ctx)) - return next(c) } } -func GetJWTClaimsFromContext(ctx context.Context) (*auth.JWTClaims, bool) { - claims, ok := ctx.Value(contextKey{}).(*auth.JWTClaims) - return claims, ok +func GetSessionIDFromContext(ctx context.Context) (string, bool) { + sessionID, ok := ctx.Value(sessionIDContextKey{}).(string) + return sessionID, ok +} + +func GetUserFromContext(ctx context.Context) (*db.User, bool) { + user, ok := ctx.Value(userContextKey{}).(*db.User) + return user, ok } diff --git a/backend/api/handler.go b/backend/api/handler.go index 25aea01..9f8849c 100644 --- a/backend/api/handler.go +++ b/backend/api/handler.go @@ -7,8 +7,10 @@ import ( "log" "net/http" "strconv" + "time" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" "github.com/labstack/echo/v4" "github.com/oapi-codegen/nullable" @@ -64,15 +66,25 @@ func (h *Handler) PostLogin(ctx context.Context, request PostLoginRequestObject) }, nil } - jwt, err := auth.NewJWT(&dbUser) + sessionID, err := auth.GenerateSessionID() if err != nil { return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } + hashedID := auth.HashSessionID(sessionID) + expiresAt := pgtype.Timestamp{Time: time.Now().Add(24 * time.Hour), Valid: true} + if err := h.q.CreateSession(ctx, db.CreateSessionParams{ + SessionID: hashedID, + UserID: dbUser.UserID, + ExpiresAt: expiresAt, + }); err != nil { + return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return postLoginCookieResponse{ cookie: http.Cookie{ - Name: "albatross_token", - Value: jwt, + Name: "albatross_session", + Value: sessionID, Path: h.conf.BasePath, MaxAge: 86400, HttpOnly: true, @@ -92,21 +104,15 @@ func (h *Handler) PostLogin(ctx context.Context, request PostLoginRequestObject) }, nil } -func (h *Handler) GetMe(ctx context.Context, _ GetMeRequestObject, claims *auth.JWTClaims) (GetMeResponseObject, error) { - dbUser, err := h.q.GetUserByID(ctx, int32(claims.UserID)) - if err != nil { - return GetMe401JSONResponse{ - Message: "Unauthorized", - }, nil - } +func (h *Handler) GetMe(_ context.Context, _ GetMeRequestObject, user *db.User) (GetMeResponseObject, error) { return GetMe200JSONResponse{ User: User{ - UserID: int(dbUser.UserID), - Username: dbUser.Username, - DisplayName: dbUser.DisplayName, - IconPath: dbUser.IconPath, - IsAdmin: dbUser.IsAdmin, - Label: toNullable(dbUser.Label), + UserID: int(user.UserID), + Username: user.Username, + DisplayName: user.DisplayName, + IconPath: user.IconPath, + IsAdmin: user.IsAdmin, + Label: toNullable(user.Label), }, }, nil } @@ -121,10 +127,13 @@ func (r postLogoutCookieResponse) VisitPostLogoutResponse(w http.ResponseWriter) return nil } -func (h *Handler) PostLogout(_ context.Context, _ PostLogoutRequestObject, _ *auth.JWTClaims) (PostLogoutResponseObject, error) { +func (h *Handler) PostLogout(ctx context.Context, _ PostLogoutRequestObject, _ *db.User) (PostLogoutResponseObject, error) { + if sessionID, ok := GetSessionIDFromContext(ctx); ok { + _ = h.q.DeleteSession(ctx, sessionID) + } return postLogoutCookieResponse{ cookie: http.Cookie{ - Name: "albatross_token", + Name: "albatross_session", Value: "", Path: h.conf.BasePath, MaxAge: -1, @@ -135,7 +144,7 @@ func (h *Handler) PostLogout(_ context.Context, _ PostLogoutRequestObject, _ *au }, nil } -func (h *Handler) GetGames(ctx context.Context, _ GetGamesRequestObject, _ *auth.JWTClaims) (GetGamesResponseObject, error) { +func (h *Handler) GetGames(ctx context.Context, _ GetGamesRequestObject, _ *db.User) (GetGamesResponseObject, error) { gameRows, err := h.q.ListPublicGames(ctx) if err != nil { return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) @@ -188,7 +197,7 @@ func (h *Handler) GetGames(ctx context.Context, _ GetGamesRequestObject, _ *auth }, nil } -func (h *Handler) GetGame(ctx context.Context, request GetGameRequestObject, user *auth.JWTClaims) (GetGameResponseObject, error) { +func (h *Handler) GetGame(ctx context.Context, request GetGameRequestObject, user *db.User) (GetGameResponseObject, error) { gameID := request.GameID row, err := h.q.GetGameByID(ctx, int32(gameID)) if err != nil { @@ -245,12 +254,11 @@ func (h *Handler) GetGame(ctx context.Context, request GetGameRequestObject, use }, nil } -func (h *Handler) GetGamePlayLatestState(ctx context.Context, request GetGamePlayLatestStateRequestObject, user *auth.JWTClaims) (GetGamePlayLatestStateResponseObject, error) { +func (h *Handler) GetGamePlayLatestState(ctx context.Context, request GetGamePlayLatestStateRequestObject, user *db.User) (GetGamePlayLatestStateResponseObject, error) { gameID := request.GameID - userID := user.UserID row, err := h.q.GetLatestState(ctx, db.GetLatestStateParams{ GameID: int32(gameID), - UserID: int32(userID), + UserID: user.UserID, }) if err != nil { if errors.Is(err, pgx.ErrNoRows) { @@ -275,7 +283,7 @@ func (h *Handler) GetGamePlayLatestState(ctx context.Context, request GetGamePla }, nil } -func (h *Handler) GetGameWatchLatestStates(ctx context.Context, request GetGameWatchLatestStatesRequestObject, user *auth.JWTClaims) (GetGameWatchLatestStatesResponseObject, error) { +func (h *Handler) GetGameWatchLatestStates(ctx context.Context, request GetGameWatchLatestStatesRequestObject, user *db.User) (GetGameWatchLatestStatesResponseObject, error) { gameID := request.GameID rows, err := h.q.GetLatestStatesOfMainPlayers(ctx, int32(gameID)) if err != nil { @@ -300,7 +308,7 @@ func (h *Handler) GetGameWatchLatestStates(ctx context.Context, request GetGameW Status: status, } - if int(row.UserID) == user.UserID && !user.IsAdmin { + if row.UserID == user.UserID && !user.IsAdmin { return GetGameWatchLatestStates403JSONResponse{ Message: "You are one of the main players of this game", }, nil @@ -311,7 +319,7 @@ func (h *Handler) GetGameWatchLatestStates(ctx context.Context, request GetGameW }, nil } -func (h *Handler) GetGameWatchRanking(ctx context.Context, request GetGameWatchRankingRequestObject, _ *auth.JWTClaims) (GetGameWatchRankingResponseObject, error) { +func (h *Handler) GetGameWatchRanking(ctx context.Context, request GetGameWatchRankingRequestObject, _ *db.User) (GetGameWatchRankingResponseObject, error) { gameID := request.GameID rows, err := h.q.GetRanking(ctx, int32(gameID)) if err != nil { @@ -342,13 +350,12 @@ func (h *Handler) GetGameWatchRanking(ctx context.Context, request GetGameWatchR }, nil } -func (h *Handler) PostGamePlayCode(ctx context.Context, request PostGamePlayCodeRequestObject, user *auth.JWTClaims) (PostGamePlayCodeResponseObject, error) { +func (h *Handler) PostGamePlayCode(ctx context.Context, request PostGamePlayCodeRequestObject, user *db.User) (PostGamePlayCodeResponseObject, error) { gameID := request.GameID - userID := user.UserID // TODO: check if the game is running err := h.q.UpdateCode(ctx, db.UpdateCodeParams{ GameID: int32(gameID), - UserID: int32(userID), + UserID: user.UserID, Code: request.Body.Code, Status: "none", }) @@ -358,9 +365,8 @@ func (h *Handler) PostGamePlayCode(ctx context.Context, request PostGamePlayCode return PostGamePlayCode200Response{}, nil } -func (h *Handler) PostGamePlaySubmit(ctx context.Context, request PostGamePlaySubmitRequestObject, user *auth.JWTClaims) (PostGamePlaySubmitResponseObject, error) { +func (h *Handler) PostGamePlaySubmit(ctx context.Context, request PostGamePlaySubmitRequestObject, user *db.User) (PostGamePlaySubmitResponseObject, error) { gameID := request.GameID - userID := user.UserID code := request.Body.Code gameRow, err := h.q.GetGameByID(ctx, int32(gameID)) @@ -377,7 +383,7 @@ func (h *Handler) PostGamePlaySubmit(ctx context.Context, request PostGamePlaySu // TODO: transaction err = h.q.UpdateCodeAndStatus(ctx, db.UpdateCodeAndStatusParams{ GameID: int32(gameID), - UserID: int32(userID), + UserID: user.UserID, Code: code, Status: "running", }) @@ -386,21 +392,21 @@ func (h *Handler) PostGamePlaySubmit(ctx context.Context, request PostGamePlaySu } submissionID, err := h.q.CreateSubmission(ctx, db.CreateSubmissionParams{ GameID: int32(gameID), - UserID: int32(userID), + UserID: user.UserID, Code: code, CodeSize: int32(codeSize), }) if err != nil { return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - err = h.hub.EnqueueTestTasks(ctx, int(submissionID), gameID, userID, language, code) + err = h.hub.EnqueueTestTasks(ctx, int(submissionID), gameID, int(user.UserID), language, code) if err != nil { return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return PostGamePlaySubmit200Response{}, nil } -func (h *Handler) GetTournament(ctx context.Context, request GetTournamentRequestObject, _ *auth.JWTClaims) (GetTournamentResponseObject, error) { +func (h *Handler) GetTournament(ctx context.Context, request GetTournamentRequestObject, _ *db.User) (GetTournamentResponseObject, error) { gameIDs := []int32{ int32(request.Params.Game1), int32(request.Params.Game2), diff --git a/backend/api/handler_wrapper.go b/backend/api/handler_wrapper.go index 5feaac7..8e3e8cd 100644 --- a/backend/api/handler_wrapper.go +++ b/backend/api/handler_wrapper.go @@ -26,7 +26,7 @@ func NewHandler(queries *db.Queries, hub GameHubInterface, conf *config.Config) } func (h *HandlerWrapper) GetGame(ctx context.Context, request GetGameRequestObject) (GetGameResponseObject, error) { - user, ok := GetJWTClaimsFromContext(ctx) + user, ok := GetUserFromContext(ctx) if !ok { return GetGame401JSONResponse{ Message: "Unauthorized", @@ -36,7 +36,7 @@ func (h *HandlerWrapper) GetGame(ctx context.Context, request GetGameRequestObje } func (h *HandlerWrapper) GetGamePlayLatestState(ctx context.Context, request GetGamePlayLatestStateRequestObject) (GetGamePlayLatestStateResponseObject, error) { - user, ok := GetJWTClaimsFromContext(ctx) + user, ok := GetUserFromContext(ctx) if !ok { return GetGamePlayLatestState401JSONResponse{ Message: "Unauthorized", @@ -46,7 +46,7 @@ func (h *HandlerWrapper) GetGamePlayLatestState(ctx context.Context, request Get } func (h *HandlerWrapper) GetGameWatchLatestStates(ctx context.Context, request GetGameWatchLatestStatesRequestObject) (GetGameWatchLatestStatesResponseObject, error) { - user, ok := GetJWTClaimsFromContext(ctx) + user, ok := GetUserFromContext(ctx) if !ok { return GetGameWatchLatestStates401JSONResponse{ Message: "Unauthorized", @@ -56,7 +56,7 @@ func (h *HandlerWrapper) GetGameWatchLatestStates(ctx context.Context, request G } func (h *HandlerWrapper) GetGameWatchRanking(ctx context.Context, request GetGameWatchRankingRequestObject) (GetGameWatchRankingResponseObject, error) { - user, ok := GetJWTClaimsFromContext(ctx) + user, ok := GetUserFromContext(ctx) if !ok { return GetGameWatchRanking401JSONResponse{ Message: "Unauthorized", @@ -66,7 +66,7 @@ func (h *HandlerWrapper) GetGameWatchRanking(ctx context.Context, request GetGam } func (h *HandlerWrapper) GetGames(ctx context.Context, request GetGamesRequestObject) (GetGamesResponseObject, error) { - user, ok := GetJWTClaimsFromContext(ctx) + user, ok := GetUserFromContext(ctx) if !ok { return GetGames401JSONResponse{ Message: "Unauthorized", @@ -76,7 +76,7 @@ func (h *HandlerWrapper) GetGames(ctx context.Context, request GetGamesRequestOb } func (h *HandlerWrapper) GetMe(ctx context.Context, request GetMeRequestObject) (GetMeResponseObject, error) { - user, ok := GetJWTClaimsFromContext(ctx) + user, ok := GetUserFromContext(ctx) if !ok { return GetMe401JSONResponse{ Message: "Unauthorized", @@ -86,7 +86,7 @@ func (h *HandlerWrapper) GetMe(ctx context.Context, request GetMeRequestObject) } func (h *HandlerWrapper) GetTournament(ctx context.Context, request GetTournamentRequestObject) (GetTournamentResponseObject, error) { - user, ok := GetJWTClaimsFromContext(ctx) + user, ok := GetUserFromContext(ctx) if !ok { return GetTournament401JSONResponse{ Message: "Unauthorized", @@ -96,7 +96,7 @@ func (h *HandlerWrapper) GetTournament(ctx context.Context, request GetTournamen } func (h *HandlerWrapper) PostGamePlayCode(ctx context.Context, request PostGamePlayCodeRequestObject) (PostGamePlayCodeResponseObject, error) { - user, ok := GetJWTClaimsFromContext(ctx) + user, ok := GetUserFromContext(ctx) if !ok { return PostGamePlayCode401JSONResponse{ Message: "Unauthorized", @@ -106,7 +106,7 @@ func (h *HandlerWrapper) PostGamePlayCode(ctx context.Context, request PostGameP } func (h *HandlerWrapper) PostGamePlaySubmit(ctx context.Context, request PostGamePlaySubmitRequestObject) (PostGamePlaySubmitResponseObject, error) { - user, ok := GetJWTClaimsFromContext(ctx) + user, ok := GetUserFromContext(ctx) if !ok { return PostGamePlaySubmit401JSONResponse{ Message: "Unauthorized", @@ -120,7 +120,7 @@ func (h *HandlerWrapper) PostLogin(ctx context.Context, request PostLoginRequest } func (h *HandlerWrapper) PostLogout(ctx context.Context, request PostLogoutRequestObject) (PostLogoutResponseObject, error) { - user, ok := GetJWTClaimsFromContext(ctx) + user, ok := GetUserFromContext(ctx) if !ok { return PostLogout401JSONResponse{ Message: "Unauthorized", |
