aboutsummaryrefslogtreecommitdiffhomepage
path: root/backend
diff options
context:
space:
mode:
Diffstat (limited to 'backend')
-rw-r--r--backend/admin/handler.go12
-rw-r--r--backend/api/auth_middleware.go42
-rw-r--r--backend/api/handler.go76
-rw-r--r--backend/api/handler_wrapper.go20
-rw-r--r--backend/auth/jwt.go60
-rw-r--r--backend/auth/session.go21
-rw-r--r--backend/db/models.go7
-rw-r--r--backend/db/query.sql.go54
-rw-r--r--backend/gen/api/handler_wrapper_gen.go2
-rw-r--r--backend/go.mod1
-rw-r--r--backend/go.sum2
-rw-r--r--backend/main.go20
-rw-r--r--backend/query.sql14
-rw-r--r--backend/schema.sql10
14 files changed, 208 insertions, 133 deletions
diff --git a/backend/admin/handler.go b/backend/admin/handler.go
index 28e7970..a18e32a 100644
--- a/backend/admin/handler.go
+++ b/backend/admin/handler.go
@@ -13,7 +13,7 @@ import (
"github.com/labstack/echo/v4"
"albatross-2026-backend/account"
- "albatross-2026-backend/auth"
+ "albatross-2026-backend/api"
"albatross-2026-backend/config"
"albatross-2026-backend/db"
)
@@ -32,15 +32,11 @@ func NewHandler(q *db.Queries, conf *config.Config) *Handler {
func (h *Handler) newAdminMiddleware() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
- jwt, err := c.Cookie("albatross_token")
- if err != nil {
- return c.Redirect(http.StatusSeeOther, h.conf.BasePath+"login")
- }
- claims, err := auth.ParseJWT(jwt.Value)
- if err != nil {
+ user, ok := api.GetUserFromContext(c.Request().Context())
+ if !ok {
return c.Redirect(http.StatusSeeOther, h.conf.BasePath+"login")
}
- if !claims.IsAdmin {
+ if !user.IsAdmin {
return echo.NewHTTPError(http.StatusForbidden)
}
return next(c)
diff --git a/backend/api/auth_middleware.go b/backend/api/auth_middleware.go
index 97f8946..d721f1d 100644
--- a/backend/api/auth_middleware.go
+++ b/backend/api/auth_middleware.go
@@ -6,27 +6,39 @@ import (
"github.com/labstack/echo/v4"
"albatross-2026-backend/auth"
+ "albatross-2026-backend/db"
)
-type contextKey struct{}
+type sessionIDContextKey struct{}
+type userContextKey struct{}
-func JWTCookieMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c echo.Context) error {
- cookie, err := c.Cookie("albatross_token")
- if err != nil {
+func SessionCookieMiddleware(q *db.Queries) echo.MiddlewareFunc {
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ cookie, err := c.Cookie("albatross_session")
+ if err != nil {
+ return next(c)
+ }
+ hashedID := auth.HashSessionID(cookie.Value)
+ user, err := q.GetUserBySession(c.Request().Context(), hashedID)
+ if err != nil {
+ return next(c)
+ }
+ ctx := c.Request().Context()
+ ctx = context.WithValue(ctx, sessionIDContextKey{}, hashedID)
+ ctx = context.WithValue(ctx, userContextKey{}, &user)
+ c.SetRequest(c.Request().WithContext(ctx))
return next(c)
}
- claims, err := auth.ParseJWT(cookie.Value)
- if err != nil {
- return next(c)
- }
- ctx := context.WithValue(c.Request().Context(), contextKey{}, claims)
- c.SetRequest(c.Request().WithContext(ctx))
- return next(c)
}
}
-func GetJWTClaimsFromContext(ctx context.Context) (*auth.JWTClaims, bool) {
- claims, ok := ctx.Value(contextKey{}).(*auth.JWTClaims)
- return claims, ok
+func GetSessionIDFromContext(ctx context.Context) (string, bool) {
+ sessionID, ok := ctx.Value(sessionIDContextKey{}).(string)
+ return sessionID, ok
+}
+
+func GetUserFromContext(ctx context.Context) (*db.User, bool) {
+ user, ok := ctx.Value(userContextKey{}).(*db.User)
+ return user, ok
}
diff --git a/backend/api/handler.go b/backend/api/handler.go
index 25aea01..9f8849c 100644
--- a/backend/api/handler.go
+++ b/backend/api/handler.go
@@ -7,8 +7,10 @@ import (
"log"
"net/http"
"strconv"
+ "time"
"github.com/jackc/pgx/v5"
+ "github.com/jackc/pgx/v5/pgtype"
"github.com/labstack/echo/v4"
"github.com/oapi-codegen/nullable"
@@ -64,15 +66,25 @@ func (h *Handler) PostLogin(ctx context.Context, request PostLoginRequestObject)
}, nil
}
- jwt, err := auth.NewJWT(&dbUser)
+ sessionID, err := auth.GenerateSessionID()
if err != nil {
return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
+ hashedID := auth.HashSessionID(sessionID)
+ expiresAt := pgtype.Timestamp{Time: time.Now().Add(24 * time.Hour), Valid: true}
+ if err := h.q.CreateSession(ctx, db.CreateSessionParams{
+ SessionID: hashedID,
+ UserID: dbUser.UserID,
+ ExpiresAt: expiresAt,
+ }); err != nil {
+ return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
+ }
+
return postLoginCookieResponse{
cookie: http.Cookie{
- Name: "albatross_token",
- Value: jwt,
+ Name: "albatross_session",
+ Value: sessionID,
Path: h.conf.BasePath,
MaxAge: 86400,
HttpOnly: true,
@@ -92,21 +104,15 @@ func (h *Handler) PostLogin(ctx context.Context, request PostLoginRequestObject)
}, nil
}
-func (h *Handler) GetMe(ctx context.Context, _ GetMeRequestObject, claims *auth.JWTClaims) (GetMeResponseObject, error) {
- dbUser, err := h.q.GetUserByID(ctx, int32(claims.UserID))
- if err != nil {
- return GetMe401JSONResponse{
- Message: "Unauthorized",
- }, nil
- }
+func (h *Handler) GetMe(_ context.Context, _ GetMeRequestObject, user *db.User) (GetMeResponseObject, error) {
return GetMe200JSONResponse{
User: User{
- UserID: int(dbUser.UserID),
- Username: dbUser.Username,
- DisplayName: dbUser.DisplayName,
- IconPath: dbUser.IconPath,
- IsAdmin: dbUser.IsAdmin,
- Label: toNullable(dbUser.Label),
+ UserID: int(user.UserID),
+ Username: user.Username,
+ DisplayName: user.DisplayName,
+ IconPath: user.IconPath,
+ IsAdmin: user.IsAdmin,
+ Label: toNullable(user.Label),
},
}, nil
}
@@ -121,10 +127,13 @@ func (r postLogoutCookieResponse) VisitPostLogoutResponse(w http.ResponseWriter)
return nil
}
-func (h *Handler) PostLogout(_ context.Context, _ PostLogoutRequestObject, _ *auth.JWTClaims) (PostLogoutResponseObject, error) {
+func (h *Handler) PostLogout(ctx context.Context, _ PostLogoutRequestObject, _ *db.User) (PostLogoutResponseObject, error) {
+ if sessionID, ok := GetSessionIDFromContext(ctx); ok {
+ _ = h.q.DeleteSession(ctx, sessionID)
+ }
return postLogoutCookieResponse{
cookie: http.Cookie{
- Name: "albatross_token",
+ Name: "albatross_session",
Value: "",
Path: h.conf.BasePath,
MaxAge: -1,
@@ -135,7 +144,7 @@ func (h *Handler) PostLogout(_ context.Context, _ PostLogoutRequestObject, _ *au
}, nil
}
-func (h *Handler) GetGames(ctx context.Context, _ GetGamesRequestObject, _ *auth.JWTClaims) (GetGamesResponseObject, error) {
+func (h *Handler) GetGames(ctx context.Context, _ GetGamesRequestObject, _ *db.User) (GetGamesResponseObject, error) {
gameRows, err := h.q.ListPublicGames(ctx)
if err != nil {
return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
@@ -188,7 +197,7 @@ func (h *Handler) GetGames(ctx context.Context, _ GetGamesRequestObject, _ *auth
}, nil
}
-func (h *Handler) GetGame(ctx context.Context, request GetGameRequestObject, user *auth.JWTClaims) (GetGameResponseObject, error) {
+func (h *Handler) GetGame(ctx context.Context, request GetGameRequestObject, user *db.User) (GetGameResponseObject, error) {
gameID := request.GameID
row, err := h.q.GetGameByID(ctx, int32(gameID))
if err != nil {
@@ -245,12 +254,11 @@ func (h *Handler) GetGame(ctx context.Context, request GetGameRequestObject, use
}, nil
}
-func (h *Handler) GetGamePlayLatestState(ctx context.Context, request GetGamePlayLatestStateRequestObject, user *auth.JWTClaims) (GetGamePlayLatestStateResponseObject, error) {
+func (h *Handler) GetGamePlayLatestState(ctx context.Context, request GetGamePlayLatestStateRequestObject, user *db.User) (GetGamePlayLatestStateResponseObject, error) {
gameID := request.GameID
- userID := user.UserID
row, err := h.q.GetLatestState(ctx, db.GetLatestStateParams{
GameID: int32(gameID),
- UserID: int32(userID),
+ UserID: user.UserID,
})
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
@@ -275,7 +283,7 @@ func (h *Handler) GetGamePlayLatestState(ctx context.Context, request GetGamePla
}, nil
}
-func (h *Handler) GetGameWatchLatestStates(ctx context.Context, request GetGameWatchLatestStatesRequestObject, user *auth.JWTClaims) (GetGameWatchLatestStatesResponseObject, error) {
+func (h *Handler) GetGameWatchLatestStates(ctx context.Context, request GetGameWatchLatestStatesRequestObject, user *db.User) (GetGameWatchLatestStatesResponseObject, error) {
gameID := request.GameID
rows, err := h.q.GetLatestStatesOfMainPlayers(ctx, int32(gameID))
if err != nil {
@@ -300,7 +308,7 @@ func (h *Handler) GetGameWatchLatestStates(ctx context.Context, request GetGameW
Status: status,
}
- if int(row.UserID) == user.UserID && !user.IsAdmin {
+ if row.UserID == user.UserID && !user.IsAdmin {
return GetGameWatchLatestStates403JSONResponse{
Message: "You are one of the main players of this game",
}, nil
@@ -311,7 +319,7 @@ func (h *Handler) GetGameWatchLatestStates(ctx context.Context, request GetGameW
}, nil
}
-func (h *Handler) GetGameWatchRanking(ctx context.Context, request GetGameWatchRankingRequestObject, _ *auth.JWTClaims) (GetGameWatchRankingResponseObject, error) {
+func (h *Handler) GetGameWatchRanking(ctx context.Context, request GetGameWatchRankingRequestObject, _ *db.User) (GetGameWatchRankingResponseObject, error) {
gameID := request.GameID
rows, err := h.q.GetRanking(ctx, int32(gameID))
if err != nil {
@@ -342,13 +350,12 @@ func (h *Handler) GetGameWatchRanking(ctx context.Context, request GetGameWatchR
}, nil
}
-func (h *Handler) PostGamePlayCode(ctx context.Context, request PostGamePlayCodeRequestObject, user *auth.JWTClaims) (PostGamePlayCodeResponseObject, error) {
+func (h *Handler) PostGamePlayCode(ctx context.Context, request PostGamePlayCodeRequestObject, user *db.User) (PostGamePlayCodeResponseObject, error) {
gameID := request.GameID
- userID := user.UserID
// TODO: check if the game is running
err := h.q.UpdateCode(ctx, db.UpdateCodeParams{
GameID: int32(gameID),
- UserID: int32(userID),
+ UserID: user.UserID,
Code: request.Body.Code,
Status: "none",
})
@@ -358,9 +365,8 @@ func (h *Handler) PostGamePlayCode(ctx context.Context, request PostGamePlayCode
return PostGamePlayCode200Response{}, nil
}
-func (h *Handler) PostGamePlaySubmit(ctx context.Context, request PostGamePlaySubmitRequestObject, user *auth.JWTClaims) (PostGamePlaySubmitResponseObject, error) {
+func (h *Handler) PostGamePlaySubmit(ctx context.Context, request PostGamePlaySubmitRequestObject, user *db.User) (PostGamePlaySubmitResponseObject, error) {
gameID := request.GameID
- userID := user.UserID
code := request.Body.Code
gameRow, err := h.q.GetGameByID(ctx, int32(gameID))
@@ -377,7 +383,7 @@ func (h *Handler) PostGamePlaySubmit(ctx context.Context, request PostGamePlaySu
// TODO: transaction
err = h.q.UpdateCodeAndStatus(ctx, db.UpdateCodeAndStatusParams{
GameID: int32(gameID),
- UserID: int32(userID),
+ UserID: user.UserID,
Code: code,
Status: "running",
})
@@ -386,21 +392,21 @@ func (h *Handler) PostGamePlaySubmit(ctx context.Context, request PostGamePlaySu
}
submissionID, err := h.q.CreateSubmission(ctx, db.CreateSubmissionParams{
GameID: int32(gameID),
- UserID: int32(userID),
+ UserID: user.UserID,
Code: code,
CodeSize: int32(codeSize),
})
if err != nil {
return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
- err = h.hub.EnqueueTestTasks(ctx, int(submissionID), gameID, userID, language, code)
+ err = h.hub.EnqueueTestTasks(ctx, int(submissionID), gameID, int(user.UserID), language, code)
if err != nil {
return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return PostGamePlaySubmit200Response{}, nil
}
-func (h *Handler) GetTournament(ctx context.Context, request GetTournamentRequestObject, _ *auth.JWTClaims) (GetTournamentResponseObject, error) {
+func (h *Handler) GetTournament(ctx context.Context, request GetTournamentRequestObject, _ *db.User) (GetTournamentResponseObject, error) {
gameIDs := []int32{
int32(request.Params.Game1),
int32(request.Params.Game2),
diff --git a/backend/api/handler_wrapper.go b/backend/api/handler_wrapper.go
index 5feaac7..8e3e8cd 100644
--- a/backend/api/handler_wrapper.go
+++ b/backend/api/handler_wrapper.go
@@ -26,7 +26,7 @@ func NewHandler(queries *db.Queries, hub GameHubInterface, conf *config.Config)
}
func (h *HandlerWrapper) GetGame(ctx context.Context, request GetGameRequestObject) (GetGameResponseObject, error) {
- user, ok := GetJWTClaimsFromContext(ctx)
+ user, ok := GetUserFromContext(ctx)
if !ok {
return GetGame401JSONResponse{
Message: "Unauthorized",
@@ -36,7 +36,7 @@ func (h *HandlerWrapper) GetGame(ctx context.Context, request GetGameRequestObje
}
func (h *HandlerWrapper) GetGamePlayLatestState(ctx context.Context, request GetGamePlayLatestStateRequestObject) (GetGamePlayLatestStateResponseObject, error) {
- user, ok := GetJWTClaimsFromContext(ctx)
+ user, ok := GetUserFromContext(ctx)
if !ok {
return GetGamePlayLatestState401JSONResponse{
Message: "Unauthorized",
@@ -46,7 +46,7 @@ func (h *HandlerWrapper) GetGamePlayLatestState(ctx context.Context, request Get
}
func (h *HandlerWrapper) GetGameWatchLatestStates(ctx context.Context, request GetGameWatchLatestStatesRequestObject) (GetGameWatchLatestStatesResponseObject, error) {
- user, ok := GetJWTClaimsFromContext(ctx)
+ user, ok := GetUserFromContext(ctx)
if !ok {
return GetGameWatchLatestStates401JSONResponse{
Message: "Unauthorized",
@@ -56,7 +56,7 @@ func (h *HandlerWrapper) GetGameWatchLatestStates(ctx context.Context, request G
}
func (h *HandlerWrapper) GetGameWatchRanking(ctx context.Context, request GetGameWatchRankingRequestObject) (GetGameWatchRankingResponseObject, error) {
- user, ok := GetJWTClaimsFromContext(ctx)
+ user, ok := GetUserFromContext(ctx)
if !ok {
return GetGameWatchRanking401JSONResponse{
Message: "Unauthorized",
@@ -66,7 +66,7 @@ func (h *HandlerWrapper) GetGameWatchRanking(ctx context.Context, request GetGam
}
func (h *HandlerWrapper) GetGames(ctx context.Context, request GetGamesRequestObject) (GetGamesResponseObject, error) {
- user, ok := GetJWTClaimsFromContext(ctx)
+ user, ok := GetUserFromContext(ctx)
if !ok {
return GetGames401JSONResponse{
Message: "Unauthorized",
@@ -76,7 +76,7 @@ func (h *HandlerWrapper) GetGames(ctx context.Context, request GetGamesRequestOb
}
func (h *HandlerWrapper) GetMe(ctx context.Context, request GetMeRequestObject) (GetMeResponseObject, error) {
- user, ok := GetJWTClaimsFromContext(ctx)
+ user, ok := GetUserFromContext(ctx)
if !ok {
return GetMe401JSONResponse{
Message: "Unauthorized",
@@ -86,7 +86,7 @@ func (h *HandlerWrapper) GetMe(ctx context.Context, request GetMeRequestObject)
}
func (h *HandlerWrapper) GetTournament(ctx context.Context, request GetTournamentRequestObject) (GetTournamentResponseObject, error) {
- user, ok := GetJWTClaimsFromContext(ctx)
+ user, ok := GetUserFromContext(ctx)
if !ok {
return GetTournament401JSONResponse{
Message: "Unauthorized",
@@ -96,7 +96,7 @@ func (h *HandlerWrapper) GetTournament(ctx context.Context, request GetTournamen
}
func (h *HandlerWrapper) PostGamePlayCode(ctx context.Context, request PostGamePlayCodeRequestObject) (PostGamePlayCodeResponseObject, error) {
- user, ok := GetJWTClaimsFromContext(ctx)
+ user, ok := GetUserFromContext(ctx)
if !ok {
return PostGamePlayCode401JSONResponse{
Message: "Unauthorized",
@@ -106,7 +106,7 @@ func (h *HandlerWrapper) PostGamePlayCode(ctx context.Context, request PostGameP
}
func (h *HandlerWrapper) PostGamePlaySubmit(ctx context.Context, request PostGamePlaySubmitRequestObject) (PostGamePlaySubmitResponseObject, error) {
- user, ok := GetJWTClaimsFromContext(ctx)
+ user, ok := GetUserFromContext(ctx)
if !ok {
return PostGamePlaySubmit401JSONResponse{
Message: "Unauthorized",
@@ -120,7 +120,7 @@ func (h *HandlerWrapper) PostLogin(ctx context.Context, request PostLoginRequest
}
func (h *HandlerWrapper) PostLogout(ctx context.Context, request PostLogoutRequestObject) (PostLogoutResponseObject, error) {
- user, ok := GetJWTClaimsFromContext(ctx)
+ user, ok := GetUserFromContext(ctx)
if !ok {
return PostLogout401JSONResponse{
Message: "Unauthorized",
diff --git a/backend/auth/jwt.go b/backend/auth/jwt.go
deleted file mode 100644
index 217d384..0000000
--- a/backend/auth/jwt.go
+++ /dev/null
@@ -1,60 +0,0 @@
-package auth
-
-import (
- "errors"
- "os"
- "time"
-
- "github.com/golang-jwt/jwt/v5"
-
- "albatross-2026-backend/db"
-)
-
-var (
- jwtSecret []byte
-)
-
-func init() {
- jwtSecret = []byte(os.Getenv("ALBATROSS_JWT_SECRET"))
- if len(jwtSecret) == 0 {
- panic("ALBATROSS_JWT_SECRET is not set")
- }
-}
-
-type JWTClaims struct {
- UserID int `json:"user_id"`
- Username string `json:"username"`
- DisplayName string `json:"display_name"`
- IconPath *string `json:"icon_path"`
- IsAdmin bool `json:"is_admin"`
- jwt.RegisteredClaims
-}
-
-func NewJWT(user *db.User) (string, error) {
- claims := &JWTClaims{
- UserID: int(user.UserID),
- Username: user.Username,
- DisplayName: user.DisplayName,
- IconPath: user.IconPath,
- IsAdmin: user.IsAdmin,
- RegisteredClaims: jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour * 24)),
- },
- }
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- return token.SignedString(jwtSecret)
-}
-
-func ParseJWT(token string) (*JWTClaims, error) {
- claims := new(JWTClaims)
- t, err := jwt.ParseWithClaims(token, claims, func(*jwt.Token) (any, error) {
- return jwtSecret, nil
- })
- if err != nil {
- return nil, err
- }
- if !t.Valid {
- return nil, errors.New("invalid token")
- }
- return claims, nil
-}
diff --git a/backend/auth/session.go b/backend/auth/session.go
new file mode 100644
index 0000000..a0d5aa4
--- /dev/null
+++ b/backend/auth/session.go
@@ -0,0 +1,21 @@
+package auth
+
+import (
+ "crypto/rand"
+ "crypto/sha256"
+ "encoding/hex"
+ "fmt"
+)
+
+func GenerateSessionID() (string, error) {
+ b := make([]byte, 32)
+ if _, err := rand.Read(b); err != nil {
+ return "", fmt.Errorf("generate session ID: %w", err)
+ }
+ return hex.EncodeToString(b), nil
+}
+
+func HashSessionID(raw string) string {
+ h := sha256.Sum256([]byte(raw))
+ return hex.EncodeToString(h[:])
+}
diff --git a/backend/db/models.go b/backend/db/models.go
index c6ef25f..c4a713d 100644
--- a/backend/db/models.go
+++ b/backend/db/models.go
@@ -40,6 +40,13 @@ type Problem struct {
SampleCode string
}
+type Session struct {
+ SessionID string
+ UserID int32
+ ExpiresAt pgtype.Timestamp
+ CreatedAt pgtype.Timestamp
+}
+
type Submission struct {
SubmissionID int32
GameID int32
diff --git a/backend/db/query.sql.go b/backend/db/query.sql.go
index 6ec3aa4..1d6d11c 100644
--- a/backend/db/query.sql.go
+++ b/backend/db/query.sql.go
@@ -103,6 +103,21 @@ func (q *Queries) CreateProblem(ctx context.Context, arg CreateProblemParams) (i
return problem_id, err
}
+const createSession = `-- name: CreateSession :exec
+INSERT INTO sessions (session_id, user_id, expires_at) VALUES ($1, $2, $3)
+`
+
+type CreateSessionParams struct {
+ SessionID string
+ UserID int32
+ ExpiresAt pgtype.Timestamp
+}
+
+func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) error {
+ _, err := q.db.Exec(ctx, createSession, arg.SessionID, arg.UserID, arg.ExpiresAt)
+ return err
+}
+
const createSubmission = `-- name: CreateSubmission :one
INSERT INTO submissions (game_id, user_id, code, code_size, status)
VALUES ($1, $2, $3, $4, 'running')
@@ -199,6 +214,24 @@ func (q *Queries) CreateUserAuth(ctx context.Context, arg CreateUserAuthParams)
return err
}
+const deleteExpiredSessions = `-- name: DeleteExpiredSessions :exec
+DELETE FROM sessions WHERE expires_at < NOW()
+`
+
+func (q *Queries) DeleteExpiredSessions(ctx context.Context) error {
+ _, err := q.db.Exec(ctx, deleteExpiredSessions)
+ return err
+}
+
+const deleteSession = `-- name: DeleteSession :exec
+DELETE FROM sessions WHERE session_id = $1
+`
+
+func (q *Queries) DeleteSession(ctx context.Context, sessionID string) error {
+ _, err := q.db.Exec(ctx, deleteSession, sessionID)
+ return err
+}
+
const deleteTestcase = `-- name: DeleteTestcase :exec
DELETE FROM testcases
WHERE testcase_id = $1
@@ -671,6 +704,27 @@ func (q *Queries) GetUserByID(ctx context.Context, userID int32) (User, error) {
return i, err
}
+const getUserBySession = `-- name: GetUserBySession :one
+SELECT users.user_id, users.username, users.display_name, users.icon_path, users.is_admin, users.label, users.created_at FROM sessions
+JOIN users ON sessions.user_id = users.user_id
+WHERE sessions.session_id = $1 AND sessions.expires_at > NOW()
+`
+
+func (q *Queries) GetUserBySession(ctx context.Context, sessionID string) (User, error) {
+ row := q.db.QueryRow(ctx, getUserBySession, sessionID)
+ var i User
+ err := row.Scan(
+ &i.UserID,
+ &i.Username,
+ &i.DisplayName,
+ &i.IconPath,
+ &i.IsAdmin,
+ &i.Label,
+ &i.CreatedAt,
+ )
+ return i, err
+}
+
const getUserIDByUsername = `-- name: GetUserIDByUsername :one
SELECT user_id FROM users
WHERE users.username = $1
diff --git a/backend/gen/api/handler_wrapper_gen.go b/backend/gen/api/handler_wrapper_gen.go
index c6e3e8a..3a9d31f 100644
--- a/backend/gen/api/handler_wrapper_gen.go
+++ b/backend/gen/api/handler_wrapper_gen.go
@@ -128,7 +128,7 @@ func NewHandler(queries *db.Queries, hub GameHubInterface, conf *config.Config)
{{ range . }}
func (h *HandlerWrapper) {{ .Name }}(ctx context.Context, request {{ .Name }}RequestObject) ({{ .Name }}ResponseObject, error) {
{{ if .RequiresLogin -}}
- user, ok := GetJWTClaimsFromContext(ctx)
+ user, ok := GetUserFromContext(ctx)
if !ok {
return {{ .Name }}401JSONResponse{
Message: "Unauthorized",
diff --git a/backend/go.mod b/backend/go.mod
index 388f706..3c73ff6 100644
--- a/backend/go.mod
+++ b/backend/go.mod
@@ -4,7 +4,6 @@ go 1.25.0
require (
github.com/getkin/kin-openapi v0.133.0
- github.com/golang-jwt/jwt/v5 v5.3.1
github.com/hibiken/asynq v0.26.0
github.com/jackc/pgx/v5 v5.8.0
github.com/labstack/echo/v4 v4.15.0
diff --git a/backend/go.sum b/backend/go.sum
index c533670..c170256 100644
--- a/backend/go.sum
+++ b/backend/go.sum
@@ -247,8 +247,6 @@ github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJA
github.com/gofrs/flock v0.12.1 h1:MTLVXXHf8ekldpJk3AKicLij9MdwOWkZ+a/jHHZby9E=
github.com/gofrs/flock v0.12.1/go.mod h1:9zxTsyu5xtJ9DK+1tFZyibEV7y3uwDxPPfbxeeHCoD0=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
-github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
-github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
diff --git a/backend/main.go b/backend/main.go
index 311c3dd..c1b670a 100644
--- a/backend/main.go
+++ b/backend/main.go
@@ -73,13 +73,14 @@ func main() {
apiGroup := e.Group(conf.BasePath + "api")
apiGroup.Use(ratelimit.LoginRateLimitMiddleware(loginRL))
- apiGroup.Use(api.JWTCookieMiddleware)
+ apiGroup.Use(api.SessionCookieMiddleware(queries))
apiGroup.Use(oapimiddleware.OapiRequestValidator(openAPISpec))
apiHandler := api.NewHandler(queries, gameHub, conf)
api.RegisterHandlers(apiGroup, api.NewStrictHandler(apiHandler, nil))
adminHandler := admin.NewHandler(queries, conf)
adminGroup := e.Group(conf.BasePath + "admin")
+ adminGroup.Use(api.SessionCookieMiddleware(queries))
adminHandler.RegisterHandlers(adminGroup)
if conf.IsLocal {
@@ -104,6 +105,23 @@ func main() {
}))
}
+ sessionCleanupCtx, cancelSessionCleanup := context.WithCancel(context.Background())
+ defer cancelSessionCleanup()
+ go func() {
+ ticker := time.NewTicker(time.Hour)
+ defer ticker.Stop()
+ for {
+ select {
+ case <-sessionCleanupCtx.Done():
+ return
+ case <-ticker.C:
+ if err := queries.DeleteExpiredSessions(sessionCleanupCtx); err != nil {
+ log.Printf("failed to delete expired sessions: %v", err)
+ }
+ }
+ }
+ }()
+
go gameHub.Run()
if err := e.Start(":80"); err != http.ErrServerClosed {
diff --git a/backend/query.sql b/backend/query.sql
index 0d84652..4297e42 100644
--- a/backend/query.sql
+++ b/backend/query.sql
@@ -276,3 +276,17 @@ SELECT *
FROM testcase_results
WHERE submission_id = $1
ORDER BY created_at;
+
+-- name: CreateSession :exec
+INSERT INTO sessions (session_id, user_id, expires_at) VALUES ($1, $2, $3);
+
+-- name: GetUserBySession :one
+SELECT users.* FROM sessions
+JOIN users ON sessions.user_id = users.user_id
+WHERE sessions.session_id = $1 AND sessions.expires_at > NOW();
+
+-- name: DeleteSession :exec
+DELETE FROM sessions WHERE session_id = $1;
+
+-- name: DeleteExpiredSessions :exec
+DELETE FROM sessions WHERE expires_at < NOW();
diff --git a/backend/schema.sql b/backend/schema.sql
index 5e427ce..4a4b1ac 100644
--- a/backend/schema.sql
+++ b/backend/schema.sql
@@ -94,3 +94,13 @@ CREATE TABLE testcase_results (
CONSTRAINT uq_submission_id_testcase_id UNIQUE(submission_id, testcase_id)
);
CREATE INDEX idx_testcase_results_submission_id ON testcase_results(submission_id);
+
+CREATE TABLE sessions (
+ session_id VARCHAR(64) PRIMARY KEY,
+ user_id INT NOT NULL,
+ expires_at TIMESTAMP NOT NULL,
+ created_at TIMESTAMP NOT NULL DEFAULT NOW(),
+ CONSTRAINT fk_sessions_user_id FOREIGN KEY(user_id) REFERENCES users(user_id) ON DELETE CASCADE
+);
+CREATE INDEX idx_sessions_user_id ON sessions(user_id);
+CREATE INDEX idx_sessions_expires_at ON sessions(expires_at);