diff options
Diffstat (limited to 'backend')
| -rw-r--r-- | backend/admin/handler.go | 12 | ||||
| -rw-r--r-- | backend/api/auth_middleware.go | 42 | ||||
| -rw-r--r-- | backend/api/handler.go | 76 | ||||
| -rw-r--r-- | backend/api/handler_wrapper.go | 20 | ||||
| -rw-r--r-- | backend/auth/jwt.go | 60 | ||||
| -rw-r--r-- | backend/auth/session.go | 21 | ||||
| -rw-r--r-- | backend/db/models.go | 7 | ||||
| -rw-r--r-- | backend/db/query.sql.go | 54 | ||||
| -rw-r--r-- | backend/gen/api/handler_wrapper_gen.go | 2 | ||||
| -rw-r--r-- | backend/go.mod | 1 | ||||
| -rw-r--r-- | backend/go.sum | 2 | ||||
| -rw-r--r-- | backend/main.go | 20 | ||||
| -rw-r--r-- | backend/query.sql | 14 | ||||
| -rw-r--r-- | backend/schema.sql | 10 |
14 files changed, 208 insertions, 133 deletions
diff --git a/backend/admin/handler.go b/backend/admin/handler.go index 28e7970..a18e32a 100644 --- a/backend/admin/handler.go +++ b/backend/admin/handler.go @@ -13,7 +13,7 @@ import ( "github.com/labstack/echo/v4" "albatross-2026-backend/account" - "albatross-2026-backend/auth" + "albatross-2026-backend/api" "albatross-2026-backend/config" "albatross-2026-backend/db" ) @@ -32,15 +32,11 @@ func NewHandler(q *db.Queries, conf *config.Config) *Handler { func (h *Handler) newAdminMiddleware() echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - jwt, err := c.Cookie("albatross_token") - if err != nil { - return c.Redirect(http.StatusSeeOther, h.conf.BasePath+"login") - } - claims, err := auth.ParseJWT(jwt.Value) - if err != nil { + user, ok := api.GetUserFromContext(c.Request().Context()) + if !ok { return c.Redirect(http.StatusSeeOther, h.conf.BasePath+"login") } - if !claims.IsAdmin { + if !user.IsAdmin { return echo.NewHTTPError(http.StatusForbidden) } return next(c) diff --git a/backend/api/auth_middleware.go b/backend/api/auth_middleware.go index 97f8946..d721f1d 100644 --- a/backend/api/auth_middleware.go +++ b/backend/api/auth_middleware.go @@ -6,27 +6,39 @@ import ( "github.com/labstack/echo/v4" "albatross-2026-backend/auth" + "albatross-2026-backend/db" ) -type contextKey struct{} +type sessionIDContextKey struct{} +type userContextKey struct{} -func JWTCookieMiddleware(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - cookie, err := c.Cookie("albatross_token") - if err != nil { +func SessionCookieMiddleware(q *db.Queries) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + cookie, err := c.Cookie("albatross_session") + if err != nil { + return next(c) + } + hashedID := auth.HashSessionID(cookie.Value) + user, err := q.GetUserBySession(c.Request().Context(), hashedID) + if err != nil { + return next(c) + } + ctx := c.Request().Context() + ctx = context.WithValue(ctx, sessionIDContextKey{}, hashedID) + ctx = context.WithValue(ctx, userContextKey{}, &user) + c.SetRequest(c.Request().WithContext(ctx)) return next(c) } - claims, err := auth.ParseJWT(cookie.Value) - if err != nil { - return next(c) - } - ctx := context.WithValue(c.Request().Context(), contextKey{}, claims) - c.SetRequest(c.Request().WithContext(ctx)) - return next(c) } } -func GetJWTClaimsFromContext(ctx context.Context) (*auth.JWTClaims, bool) { - claims, ok := ctx.Value(contextKey{}).(*auth.JWTClaims) - return claims, ok +func GetSessionIDFromContext(ctx context.Context) (string, bool) { + sessionID, ok := ctx.Value(sessionIDContextKey{}).(string) + return sessionID, ok +} + +func GetUserFromContext(ctx context.Context) (*db.User, bool) { + user, ok := ctx.Value(userContextKey{}).(*db.User) + return user, ok } diff --git a/backend/api/handler.go b/backend/api/handler.go index 25aea01..9f8849c 100644 --- a/backend/api/handler.go +++ b/backend/api/handler.go @@ -7,8 +7,10 @@ import ( "log" "net/http" "strconv" + "time" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" "github.com/labstack/echo/v4" "github.com/oapi-codegen/nullable" @@ -64,15 +66,25 @@ func (h *Handler) PostLogin(ctx context.Context, request PostLoginRequestObject) }, nil } - jwt, err := auth.NewJWT(&dbUser) + sessionID, err := auth.GenerateSessionID() if err != nil { return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } + hashedID := auth.HashSessionID(sessionID) + expiresAt := pgtype.Timestamp{Time: time.Now().Add(24 * time.Hour), Valid: true} + if err := h.q.CreateSession(ctx, db.CreateSessionParams{ + SessionID: hashedID, + UserID: dbUser.UserID, + ExpiresAt: expiresAt, + }); err != nil { + return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return postLoginCookieResponse{ cookie: http.Cookie{ - Name: "albatross_token", - Value: jwt, + Name: "albatross_session", + Value: sessionID, Path: h.conf.BasePath, MaxAge: 86400, HttpOnly: true, @@ -92,21 +104,15 @@ func (h *Handler) PostLogin(ctx context.Context, request PostLoginRequestObject) }, nil } -func (h *Handler) GetMe(ctx context.Context, _ GetMeRequestObject, claims *auth.JWTClaims) (GetMeResponseObject, error) { - dbUser, err := h.q.GetUserByID(ctx, int32(claims.UserID)) - if err != nil { - return GetMe401JSONResponse{ - Message: "Unauthorized", - }, nil - } +func (h *Handler) GetMe(_ context.Context, _ GetMeRequestObject, user *db.User) (GetMeResponseObject, error) { return GetMe200JSONResponse{ User: User{ - UserID: int(dbUser.UserID), - Username: dbUser.Username, - DisplayName: dbUser.DisplayName, - IconPath: dbUser.IconPath, - IsAdmin: dbUser.IsAdmin, - Label: toNullable(dbUser.Label), + UserID: int(user.UserID), + Username: user.Username, + DisplayName: user.DisplayName, + IconPath: user.IconPath, + IsAdmin: user.IsAdmin, + Label: toNullable(user.Label), }, }, nil } @@ -121,10 +127,13 @@ func (r postLogoutCookieResponse) VisitPostLogoutResponse(w http.ResponseWriter) return nil } -func (h *Handler) PostLogout(_ context.Context, _ PostLogoutRequestObject, _ *auth.JWTClaims) (PostLogoutResponseObject, error) { +func (h *Handler) PostLogout(ctx context.Context, _ PostLogoutRequestObject, _ *db.User) (PostLogoutResponseObject, error) { + if sessionID, ok := GetSessionIDFromContext(ctx); ok { + _ = h.q.DeleteSession(ctx, sessionID) + } return postLogoutCookieResponse{ cookie: http.Cookie{ - Name: "albatross_token", + Name: "albatross_session", Value: "", Path: h.conf.BasePath, MaxAge: -1, @@ -135,7 +144,7 @@ func (h *Handler) PostLogout(_ context.Context, _ PostLogoutRequestObject, _ *au }, nil } -func (h *Handler) GetGames(ctx context.Context, _ GetGamesRequestObject, _ *auth.JWTClaims) (GetGamesResponseObject, error) { +func (h *Handler) GetGames(ctx context.Context, _ GetGamesRequestObject, _ *db.User) (GetGamesResponseObject, error) { gameRows, err := h.q.ListPublicGames(ctx) if err != nil { return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) @@ -188,7 +197,7 @@ func (h *Handler) GetGames(ctx context.Context, _ GetGamesRequestObject, _ *auth }, nil } -func (h *Handler) GetGame(ctx context.Context, request GetGameRequestObject, user *auth.JWTClaims) (GetGameResponseObject, error) { +func (h *Handler) GetGame(ctx context.Context, request GetGameRequestObject, user *db.User) (GetGameResponseObject, error) { gameID := request.GameID row, err := h.q.GetGameByID(ctx, int32(gameID)) if err != nil { @@ -245,12 +254,11 @@ func (h *Handler) GetGame(ctx context.Context, request GetGameRequestObject, use }, nil } -func (h *Handler) GetGamePlayLatestState(ctx context.Context, request GetGamePlayLatestStateRequestObject, user *auth.JWTClaims) (GetGamePlayLatestStateResponseObject, error) { +func (h *Handler) GetGamePlayLatestState(ctx context.Context, request GetGamePlayLatestStateRequestObject, user *db.User) (GetGamePlayLatestStateResponseObject, error) { gameID := request.GameID - userID := user.UserID row, err := h.q.GetLatestState(ctx, db.GetLatestStateParams{ GameID: int32(gameID), - UserID: int32(userID), + UserID: user.UserID, }) if err != nil { if errors.Is(err, pgx.ErrNoRows) { @@ -275,7 +283,7 @@ func (h *Handler) GetGamePlayLatestState(ctx context.Context, request GetGamePla }, nil } -func (h *Handler) GetGameWatchLatestStates(ctx context.Context, request GetGameWatchLatestStatesRequestObject, user *auth.JWTClaims) (GetGameWatchLatestStatesResponseObject, error) { +func (h *Handler) GetGameWatchLatestStates(ctx context.Context, request GetGameWatchLatestStatesRequestObject, user *db.User) (GetGameWatchLatestStatesResponseObject, error) { gameID := request.GameID rows, err := h.q.GetLatestStatesOfMainPlayers(ctx, int32(gameID)) if err != nil { @@ -300,7 +308,7 @@ func (h *Handler) GetGameWatchLatestStates(ctx context.Context, request GetGameW Status: status, } - if int(row.UserID) == user.UserID && !user.IsAdmin { + if row.UserID == user.UserID && !user.IsAdmin { return GetGameWatchLatestStates403JSONResponse{ Message: "You are one of the main players of this game", }, nil @@ -311,7 +319,7 @@ func (h *Handler) GetGameWatchLatestStates(ctx context.Context, request GetGameW }, nil } -func (h *Handler) GetGameWatchRanking(ctx context.Context, request GetGameWatchRankingRequestObject, _ *auth.JWTClaims) (GetGameWatchRankingResponseObject, error) { +func (h *Handler) GetGameWatchRanking(ctx context.Context, request GetGameWatchRankingRequestObject, _ *db.User) (GetGameWatchRankingResponseObject, error) { gameID := request.GameID rows, err := h.q.GetRanking(ctx, int32(gameID)) if err != nil { @@ -342,13 +350,12 @@ func (h *Handler) GetGameWatchRanking(ctx context.Context, request GetGameWatchR }, nil } -func (h *Handler) PostGamePlayCode(ctx context.Context, request PostGamePlayCodeRequestObject, user *auth.JWTClaims) (PostGamePlayCodeResponseObject, error) { +func (h *Handler) PostGamePlayCode(ctx context.Context, request PostGamePlayCodeRequestObject, user *db.User) (PostGamePlayCodeResponseObject, error) { gameID := request.GameID - userID := user.UserID // TODO: check if the game is running err := h.q.UpdateCode(ctx, db.UpdateCodeParams{ GameID: int32(gameID), - UserID: int32(userID), + UserID: user.UserID, Code: request.Body.Code, Status: "none", }) @@ -358,9 +365,8 @@ func (h *Handler) PostGamePlayCode(ctx context.Context, request PostGamePlayCode return PostGamePlayCode200Response{}, nil } -func (h *Handler) PostGamePlaySubmit(ctx context.Context, request PostGamePlaySubmitRequestObject, user *auth.JWTClaims) (PostGamePlaySubmitResponseObject, error) { +func (h *Handler) PostGamePlaySubmit(ctx context.Context, request PostGamePlaySubmitRequestObject, user *db.User) (PostGamePlaySubmitResponseObject, error) { gameID := request.GameID - userID := user.UserID code := request.Body.Code gameRow, err := h.q.GetGameByID(ctx, int32(gameID)) @@ -377,7 +383,7 @@ func (h *Handler) PostGamePlaySubmit(ctx context.Context, request PostGamePlaySu // TODO: transaction err = h.q.UpdateCodeAndStatus(ctx, db.UpdateCodeAndStatusParams{ GameID: int32(gameID), - UserID: int32(userID), + UserID: user.UserID, Code: code, Status: "running", }) @@ -386,21 +392,21 @@ func (h *Handler) PostGamePlaySubmit(ctx context.Context, request PostGamePlaySu } submissionID, err := h.q.CreateSubmission(ctx, db.CreateSubmissionParams{ GameID: int32(gameID), - UserID: int32(userID), + UserID: user.UserID, Code: code, CodeSize: int32(codeSize), }) if err != nil { return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - err = h.hub.EnqueueTestTasks(ctx, int(submissionID), gameID, userID, language, code) + err = h.hub.EnqueueTestTasks(ctx, int(submissionID), gameID, int(user.UserID), language, code) if err != nil { return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return PostGamePlaySubmit200Response{}, nil } -func (h *Handler) GetTournament(ctx context.Context, request GetTournamentRequestObject, _ *auth.JWTClaims) (GetTournamentResponseObject, error) { +func (h *Handler) GetTournament(ctx context.Context, request GetTournamentRequestObject, _ *db.User) (GetTournamentResponseObject, error) { gameIDs := []int32{ int32(request.Params.Game1), int32(request.Params.Game2), diff --git a/backend/api/handler_wrapper.go b/backend/api/handler_wrapper.go index 5feaac7..8e3e8cd 100644 --- a/backend/api/handler_wrapper.go +++ b/backend/api/handler_wrapper.go @@ -26,7 +26,7 @@ func NewHandler(queries *db.Queries, hub GameHubInterface, conf *config.Config) } func (h *HandlerWrapper) GetGame(ctx context.Context, request GetGameRequestObject) (GetGameResponseObject, error) { - user, ok := GetJWTClaimsFromContext(ctx) + user, ok := GetUserFromContext(ctx) if !ok { return GetGame401JSONResponse{ Message: "Unauthorized", @@ -36,7 +36,7 @@ func (h *HandlerWrapper) GetGame(ctx context.Context, request GetGameRequestObje } func (h *HandlerWrapper) GetGamePlayLatestState(ctx context.Context, request GetGamePlayLatestStateRequestObject) (GetGamePlayLatestStateResponseObject, error) { - user, ok := GetJWTClaimsFromContext(ctx) + user, ok := GetUserFromContext(ctx) if !ok { return GetGamePlayLatestState401JSONResponse{ Message: "Unauthorized", @@ -46,7 +46,7 @@ func (h *HandlerWrapper) GetGamePlayLatestState(ctx context.Context, request Get } func (h *HandlerWrapper) GetGameWatchLatestStates(ctx context.Context, request GetGameWatchLatestStatesRequestObject) (GetGameWatchLatestStatesResponseObject, error) { - user, ok := GetJWTClaimsFromContext(ctx) + user, ok := GetUserFromContext(ctx) if !ok { return GetGameWatchLatestStates401JSONResponse{ Message: "Unauthorized", @@ -56,7 +56,7 @@ func (h *HandlerWrapper) GetGameWatchLatestStates(ctx context.Context, request G } func (h *HandlerWrapper) GetGameWatchRanking(ctx context.Context, request GetGameWatchRankingRequestObject) (GetGameWatchRankingResponseObject, error) { - user, ok := GetJWTClaimsFromContext(ctx) + user, ok := GetUserFromContext(ctx) if !ok { return GetGameWatchRanking401JSONResponse{ Message: "Unauthorized", @@ -66,7 +66,7 @@ func (h *HandlerWrapper) GetGameWatchRanking(ctx context.Context, request GetGam } func (h *HandlerWrapper) GetGames(ctx context.Context, request GetGamesRequestObject) (GetGamesResponseObject, error) { - user, ok := GetJWTClaimsFromContext(ctx) + user, ok := GetUserFromContext(ctx) if !ok { return GetGames401JSONResponse{ Message: "Unauthorized", @@ -76,7 +76,7 @@ func (h *HandlerWrapper) GetGames(ctx context.Context, request GetGamesRequestOb } func (h *HandlerWrapper) GetMe(ctx context.Context, request GetMeRequestObject) (GetMeResponseObject, error) { - user, ok := GetJWTClaimsFromContext(ctx) + user, ok := GetUserFromContext(ctx) if !ok { return GetMe401JSONResponse{ Message: "Unauthorized", @@ -86,7 +86,7 @@ func (h *HandlerWrapper) GetMe(ctx context.Context, request GetMeRequestObject) } func (h *HandlerWrapper) GetTournament(ctx context.Context, request GetTournamentRequestObject) (GetTournamentResponseObject, error) { - user, ok := GetJWTClaimsFromContext(ctx) + user, ok := GetUserFromContext(ctx) if !ok { return GetTournament401JSONResponse{ Message: "Unauthorized", @@ -96,7 +96,7 @@ func (h *HandlerWrapper) GetTournament(ctx context.Context, request GetTournamen } func (h *HandlerWrapper) PostGamePlayCode(ctx context.Context, request PostGamePlayCodeRequestObject) (PostGamePlayCodeResponseObject, error) { - user, ok := GetJWTClaimsFromContext(ctx) + user, ok := GetUserFromContext(ctx) if !ok { return PostGamePlayCode401JSONResponse{ Message: "Unauthorized", @@ -106,7 +106,7 @@ func (h *HandlerWrapper) PostGamePlayCode(ctx context.Context, request PostGameP } func (h *HandlerWrapper) PostGamePlaySubmit(ctx context.Context, request PostGamePlaySubmitRequestObject) (PostGamePlaySubmitResponseObject, error) { - user, ok := GetJWTClaimsFromContext(ctx) + user, ok := GetUserFromContext(ctx) if !ok { return PostGamePlaySubmit401JSONResponse{ Message: "Unauthorized", @@ -120,7 +120,7 @@ func (h *HandlerWrapper) PostLogin(ctx context.Context, request PostLoginRequest } func (h *HandlerWrapper) PostLogout(ctx context.Context, request PostLogoutRequestObject) (PostLogoutResponseObject, error) { - user, ok := GetJWTClaimsFromContext(ctx) + user, ok := GetUserFromContext(ctx) if !ok { return PostLogout401JSONResponse{ Message: "Unauthorized", diff --git a/backend/auth/jwt.go b/backend/auth/jwt.go deleted file mode 100644 index 217d384..0000000 --- a/backend/auth/jwt.go +++ /dev/null @@ -1,60 +0,0 @@ -package auth - -import ( - "errors" - "os" - "time" - - "github.com/golang-jwt/jwt/v5" - - "albatross-2026-backend/db" -) - -var ( - jwtSecret []byte -) - -func init() { - jwtSecret = []byte(os.Getenv("ALBATROSS_JWT_SECRET")) - if len(jwtSecret) == 0 { - panic("ALBATROSS_JWT_SECRET is not set") - } -} - -type JWTClaims struct { - UserID int `json:"user_id"` - Username string `json:"username"` - DisplayName string `json:"display_name"` - IconPath *string `json:"icon_path"` - IsAdmin bool `json:"is_admin"` - jwt.RegisteredClaims -} - -func NewJWT(user *db.User) (string, error) { - claims := &JWTClaims{ - UserID: int(user.UserID), - Username: user.Username, - DisplayName: user.DisplayName, - IconPath: user.IconPath, - IsAdmin: user.IsAdmin, - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour * 24)), - }, - } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - return token.SignedString(jwtSecret) -} - -func ParseJWT(token string) (*JWTClaims, error) { - claims := new(JWTClaims) - t, err := jwt.ParseWithClaims(token, claims, func(*jwt.Token) (any, error) { - return jwtSecret, nil - }) - if err != nil { - return nil, err - } - if !t.Valid { - return nil, errors.New("invalid token") - } - return claims, nil -} diff --git a/backend/auth/session.go b/backend/auth/session.go new file mode 100644 index 0000000..a0d5aa4 --- /dev/null +++ b/backend/auth/session.go @@ -0,0 +1,21 @@ +package auth + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "fmt" +) + +func GenerateSessionID() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("generate session ID: %w", err) + } + return hex.EncodeToString(b), nil +} + +func HashSessionID(raw string) string { + h := sha256.Sum256([]byte(raw)) + return hex.EncodeToString(h[:]) +} diff --git a/backend/db/models.go b/backend/db/models.go index c6ef25f..c4a713d 100644 --- a/backend/db/models.go +++ b/backend/db/models.go @@ -40,6 +40,13 @@ type Problem struct { SampleCode string } +type Session struct { + SessionID string + UserID int32 + ExpiresAt pgtype.Timestamp + CreatedAt pgtype.Timestamp +} + type Submission struct { SubmissionID int32 GameID int32 diff --git a/backend/db/query.sql.go b/backend/db/query.sql.go index 6ec3aa4..1d6d11c 100644 --- a/backend/db/query.sql.go +++ b/backend/db/query.sql.go @@ -103,6 +103,21 @@ func (q *Queries) CreateProblem(ctx context.Context, arg CreateProblemParams) (i return problem_id, err } +const createSession = `-- name: CreateSession :exec +INSERT INTO sessions (session_id, user_id, expires_at) VALUES ($1, $2, $3) +` + +type CreateSessionParams struct { + SessionID string + UserID int32 + ExpiresAt pgtype.Timestamp +} + +func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) error { + _, err := q.db.Exec(ctx, createSession, arg.SessionID, arg.UserID, arg.ExpiresAt) + return err +} + const createSubmission = `-- name: CreateSubmission :one INSERT INTO submissions (game_id, user_id, code, code_size, status) VALUES ($1, $2, $3, $4, 'running') @@ -199,6 +214,24 @@ func (q *Queries) CreateUserAuth(ctx context.Context, arg CreateUserAuthParams) return err } +const deleteExpiredSessions = `-- name: DeleteExpiredSessions :exec +DELETE FROM sessions WHERE expires_at < NOW() +` + +func (q *Queries) DeleteExpiredSessions(ctx context.Context) error { + _, err := q.db.Exec(ctx, deleteExpiredSessions) + return err +} + +const deleteSession = `-- name: DeleteSession :exec +DELETE FROM sessions WHERE session_id = $1 +` + +func (q *Queries) DeleteSession(ctx context.Context, sessionID string) error { + _, err := q.db.Exec(ctx, deleteSession, sessionID) + return err +} + const deleteTestcase = `-- name: DeleteTestcase :exec DELETE FROM testcases WHERE testcase_id = $1 @@ -671,6 +704,27 @@ func (q *Queries) GetUserByID(ctx context.Context, userID int32) (User, error) { return i, err } +const getUserBySession = `-- name: GetUserBySession :one +SELECT users.user_id, users.username, users.display_name, users.icon_path, users.is_admin, users.label, users.created_at FROM sessions +JOIN users ON sessions.user_id = users.user_id +WHERE sessions.session_id = $1 AND sessions.expires_at > NOW() +` + +func (q *Queries) GetUserBySession(ctx context.Context, sessionID string) (User, error) { + row := q.db.QueryRow(ctx, getUserBySession, sessionID) + var i User + err := row.Scan( + &i.UserID, + &i.Username, + &i.DisplayName, + &i.IconPath, + &i.IsAdmin, + &i.Label, + &i.CreatedAt, + ) + return i, err +} + const getUserIDByUsername = `-- name: GetUserIDByUsername :one SELECT user_id FROM users WHERE users.username = $1 diff --git a/backend/gen/api/handler_wrapper_gen.go b/backend/gen/api/handler_wrapper_gen.go index c6e3e8a..3a9d31f 100644 --- a/backend/gen/api/handler_wrapper_gen.go +++ b/backend/gen/api/handler_wrapper_gen.go @@ -128,7 +128,7 @@ func NewHandler(queries *db.Queries, hub GameHubInterface, conf *config.Config) {{ range . }} func (h *HandlerWrapper) {{ .Name }}(ctx context.Context, request {{ .Name }}RequestObject) ({{ .Name }}ResponseObject, error) { {{ if .RequiresLogin -}} - user, ok := GetJWTClaimsFromContext(ctx) + user, ok := GetUserFromContext(ctx) if !ok { return {{ .Name }}401JSONResponse{ Message: "Unauthorized", diff --git a/backend/go.mod b/backend/go.mod index 388f706..3c73ff6 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -4,7 +4,6 @@ go 1.25.0 require ( github.com/getkin/kin-openapi v0.133.0 - github.com/golang-jwt/jwt/v5 v5.3.1 github.com/hibiken/asynq v0.26.0 github.com/jackc/pgx/v5 v5.8.0 github.com/labstack/echo/v4 v4.15.0 diff --git a/backend/go.sum b/backend/go.sum index c533670..c170256 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -247,8 +247,6 @@ github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJA github.com/gofrs/flock v0.12.1 h1:MTLVXXHf8ekldpJk3AKicLij9MdwOWkZ+a/jHHZby9E= github.com/gofrs/flock v0.12.1/go.mod h1:9zxTsyu5xtJ9DK+1tFZyibEV7y3uwDxPPfbxeeHCoD0= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= -github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= -github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= diff --git a/backend/main.go b/backend/main.go index 311c3dd..c1b670a 100644 --- a/backend/main.go +++ b/backend/main.go @@ -73,13 +73,14 @@ func main() { apiGroup := e.Group(conf.BasePath + "api") apiGroup.Use(ratelimit.LoginRateLimitMiddleware(loginRL)) - apiGroup.Use(api.JWTCookieMiddleware) + apiGroup.Use(api.SessionCookieMiddleware(queries)) apiGroup.Use(oapimiddleware.OapiRequestValidator(openAPISpec)) apiHandler := api.NewHandler(queries, gameHub, conf) api.RegisterHandlers(apiGroup, api.NewStrictHandler(apiHandler, nil)) adminHandler := admin.NewHandler(queries, conf) adminGroup := e.Group(conf.BasePath + "admin") + adminGroup.Use(api.SessionCookieMiddleware(queries)) adminHandler.RegisterHandlers(adminGroup) if conf.IsLocal { @@ -104,6 +105,23 @@ func main() { })) } + sessionCleanupCtx, cancelSessionCleanup := context.WithCancel(context.Background()) + defer cancelSessionCleanup() + go func() { + ticker := time.NewTicker(time.Hour) + defer ticker.Stop() + for { + select { + case <-sessionCleanupCtx.Done(): + return + case <-ticker.C: + if err := queries.DeleteExpiredSessions(sessionCleanupCtx); err != nil { + log.Printf("failed to delete expired sessions: %v", err) + } + } + } + }() + go gameHub.Run() if err := e.Start(":80"); err != http.ErrServerClosed { diff --git a/backend/query.sql b/backend/query.sql index 0d84652..4297e42 100644 --- a/backend/query.sql +++ b/backend/query.sql @@ -276,3 +276,17 @@ SELECT * FROM testcase_results WHERE submission_id = $1 ORDER BY created_at; + +-- name: CreateSession :exec +INSERT INTO sessions (session_id, user_id, expires_at) VALUES ($1, $2, $3); + +-- name: GetUserBySession :one +SELECT users.* FROM sessions +JOIN users ON sessions.user_id = users.user_id +WHERE sessions.session_id = $1 AND sessions.expires_at > NOW(); + +-- name: DeleteSession :exec +DELETE FROM sessions WHERE session_id = $1; + +-- name: DeleteExpiredSessions :exec +DELETE FROM sessions WHERE expires_at < NOW(); diff --git a/backend/schema.sql b/backend/schema.sql index 5e427ce..4a4b1ac 100644 --- a/backend/schema.sql +++ b/backend/schema.sql @@ -94,3 +94,13 @@ CREATE TABLE testcase_results ( CONSTRAINT uq_submission_id_testcase_id UNIQUE(submission_id, testcase_id) ); CREATE INDEX idx_testcase_results_submission_id ON testcase_results(submission_id); + +CREATE TABLE sessions ( + session_id VARCHAR(64) PRIMARY KEY, + user_id INT NOT NULL, + expires_at TIMESTAMP NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT NOW(), + CONSTRAINT fk_sessions_user_id FOREIGN KEY(user_id) REFERENCES users(user_id) ON DELETE CASCADE +); +CREATE INDEX idx_sessions_user_id ON sessions(user_id); +CREATE INDEX idx_sessions_expires_at ON sessions(expires_at); |
