From db87f85aa7055e597800481b8cc6d006c70bcc88 Mon Sep 17 00:00:00 2001 From: nsfisis Date: Mon, 16 Feb 2026 22:02:58 +0900 Subject: test(backend): add unit tests for auth_middleware, fortee, processor, account, and more handlers Cover previously untested code: SessionCookieMiddleware, context helpers, downloadFile, addAcceptHeader, doProcessTaskRunTestcase, updateSubmissionAndGameState, PostLogout, GetGames, PostGamePlayCode, GetGameWatchRanking, GetGameWatchLatestStates. Co-Authored-By: Claude Opus 4.6 --- backend/account/icon_test.go | 94 +++++++ backend/api/auth_middleware_test.go | 180 +++++++++++++ backend/api/handler_test.go | 354 ++++++++++++++++++++++++++ backend/fortee/fortee_test.go | 44 ++++ backend/game/hub_test.go | 131 ++++++++++ backend/taskqueue/processor_doprocess_test.go | 204 +++++++++++++++ 6 files changed, 1007 insertions(+) create mode 100644 backend/account/icon_test.go create mode 100644 backend/api/auth_middleware_test.go create mode 100644 backend/fortee/fortee_test.go create mode 100644 backend/taskqueue/processor_doprocess_test.go (limited to 'backend') diff --git a/backend/account/icon_test.go b/backend/account/icon_test.go new file mode 100644 index 0000000..7f4ddbc --- /dev/null +++ b/backend/account/icon_test.go @@ -0,0 +1,94 @@ +package account + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" +) + +func TestDownloadFile_Success(t *testing.T) { + expectedContent := "file content here" + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(expectedContent)) + })) + defer server.Close() + + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "subdir", "test.png") + + err := downloadFile(context.Background(), server.URL+"/icon.png", filePath) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(filePath) + if err != nil { + t.Fatalf("failed to read downloaded file: %v", err) + } + if string(data) != expectedContent { + t.Errorf("expected content %q, got %q", expectedContent, string(data)) + } +} + +func TestDownloadFile_NotFound(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "test.png") + + err := downloadFile(context.Background(), server.URL+"/missing.png", filePath) + if err == nil { + t.Error("expected error for 404 response") + } +} + +func TestDownloadFile_ServerError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "test.png") + + err := downloadFile(context.Background(), server.URL+"/error.png", filePath) + if err == nil { + t.Error("expected error for 500 response") + } +} + +func TestDownloadFile_InvalidURL(t *testing.T) { + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "test.png") + + err := downloadFile(context.Background(), "http://localhost:1/unreachable", filePath) + if err == nil { + t.Error("expected error for unreachable server") + } +} + +func TestDownloadFile_ContextCanceled(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("data")) + })) + defer server.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "test.png") + + err := downloadFile(ctx, server.URL+"/icon.png", filePath) + if err == nil { + t.Error("expected error for canceled context") + } +} diff --git a/backend/api/auth_middleware_test.go b/backend/api/auth_middleware_test.go new file mode 100644 index 0000000..d84eef7 --- /dev/null +++ b/backend/api/auth_middleware_test.go @@ -0,0 +1,180 @@ +package api + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/labstack/echo/v4" + + "albatross-2026-backend/db" +) + +func TestGetSessionIDFromContext_NotSet(t *testing.T) { + ctx := context.Background() + _, ok := GetSessionIDFromContext(ctx) + if ok { + t.Error("expected ok=false when session ID is not set") + } +} + +func TestGetSessionIDFromContext_Set(t *testing.T) { + ctx := context.WithValue(context.Background(), sessionIDContextKey{}, "abc123") + id, ok := GetSessionIDFromContext(ctx) + if !ok { + t.Fatal("expected ok=true when session ID is set") + } + if id != "abc123" { + t.Errorf("expected session ID 'abc123', got %q", id) + } +} + +func TestGetUserFromContext_NotSet(t *testing.T) { + ctx := context.Background() + _, ok := GetUserFromContext(ctx) + if ok { + t.Error("expected ok=false when user is not set") + } +} + +func TestGetUserFromContext_Set(t *testing.T) { + user := &db.User{UserID: 42, Username: "testuser"} + ctx := context.WithValue(context.Background(), userContextKey{}, user) + u, ok := GetUserFromContext(ctx) + if !ok { + t.Fatal("expected ok=true when user is set") + } + if u.UserID != 42 { + t.Errorf("expected user ID 42, got %d", u.UserID) + } + if u.Username != "testuser" { + t.Errorf("expected username 'testuser', got %q", u.Username) + } +} + +func TestSetUserInContext(t *testing.T) { + user := &db.User{UserID: 7, Username: "admin"} + ctx := SetUserInContext(context.Background(), user) + u, ok := GetUserFromContext(ctx) + if !ok { + t.Fatal("expected ok=true after SetUserInContext") + } + if u.UserID != 7 { + t.Errorf("expected user ID 7, got %d", u.UserID) + } +} + +// mockSessionQuerier implements the subset of db.Querier needed by SessionCookieMiddleware. +type mockSessionQuerier struct { + db.Querier + getUserBySessionFunc func(ctx context.Context, sessionID string) (db.User, error) +} + +func (m *mockSessionQuerier) GetUserBySession(ctx context.Context, sessionID string) (db.User, error) { + if m.getUserBySessionFunc != nil { + return m.getUserBySessionFunc(ctx, sessionID) + } + return db.User{}, nil +} + +func TestSessionCookieMiddleware_NoCookie(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + mw := SessionCookieMiddleware(&mockSessionQuerier{}) + var called bool + handler := mw(func(c echo.Context) error { + called = true + // User should not be set + _, ok := GetUserFromContext(c.Request().Context()) + if ok { + t.Error("expected no user in context when no cookie is present") + } + return nil + }) + + if err := handler(c); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !called { + t.Error("next handler was not called") + } +} + +func TestSessionCookieMiddleware_ValidSession(t *testing.T) { + expectedUser := db.User{UserID: 10, Username: "sessionuser"} + mq := &mockSessionQuerier{ + getUserBySessionFunc: func(_ context.Context, _ string) (db.User, error) { + return expectedUser, nil + }, + } + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{Name: "albatross_session", Value: "raw-session-id"}) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + mw := SessionCookieMiddleware(mq) + var called bool + handler := mw(func(c echo.Context) error { + called = true + user, ok := GetUserFromContext(c.Request().Context()) + if !ok { + t.Fatal("expected user in context") + } + if user.UserID != 10 { + t.Errorf("expected user ID 10, got %d", user.UserID) + } + sid, ok := GetSessionIDFromContext(c.Request().Context()) + if !ok { + t.Fatal("expected session ID in context") + } + if sid == "" { + t.Error("expected non-empty hashed session ID") + } + return nil + }) + + if err := handler(c); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !called { + t.Error("next handler was not called") + } +} + +func TestSessionCookieMiddleware_InvalidSession(t *testing.T) { + mq := &mockSessionQuerier{ + getUserBySessionFunc: func(_ context.Context, _ string) (db.User, error) { + return db.User{}, echo.ErrNotFound + }, + } + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{Name: "albatross_session", Value: "invalid-session"}) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + mw := SessionCookieMiddleware(mq) + var called bool + handler := mw(func(c echo.Context) error { + called = true + _, ok := GetUserFromContext(c.Request().Context()) + if ok { + t.Error("expected no user in context for invalid session") + } + return nil + }) + + if err := handler(c); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !called { + t.Error("next handler was not called") + } +} diff --git a/backend/api/handler_test.go b/backend/api/handler_test.go index a68dfa0..2340d33 100644 --- a/backend/api/handler_test.go +++ b/backend/api/handler_test.go @@ -18,6 +18,12 @@ type mockQuerier struct { db.Querier getGameByIDFunc func(ctx context.Context, gameID int32) (db.GetGameByIDRow, error) listMainPlayersFunc func(ctx context.Context, gameIDs []int32) ([]db.ListMainPlayersRow, error) + listPublicGamesFunc func(ctx context.Context) ([]db.ListPublicGamesRow, error) + deleteSessionFunc func(ctx context.Context, sessionID string) error + getLatestStateFunc func(ctx context.Context, arg db.GetLatestStateParams) (db.GetLatestStateRow, error) + updateCodeFunc func(ctx context.Context, arg db.UpdateCodeParams) error + getRankingFunc func(ctx context.Context, gameID int32) ([]db.GetRankingRow, error) + getLatestStatesFunc func(ctx context.Context, gameID int32) ([]db.GetLatestStatesOfMainPlayersRow, error) } func (m *mockQuerier) GetGameByID(ctx context.Context, gameID int32) (db.GetGameByIDRow, error) { @@ -34,6 +40,48 @@ func (m *mockQuerier) ListMainPlayers(ctx context.Context, gameIDs []int32) ([]d return nil, nil } +func (m *mockQuerier) ListPublicGames(ctx context.Context) ([]db.ListPublicGamesRow, error) { + if m.listPublicGamesFunc != nil { + return m.listPublicGamesFunc(ctx) + } + return nil, nil +} + +func (m *mockQuerier) DeleteSession(ctx context.Context, sessionID string) error { + if m.deleteSessionFunc != nil { + return m.deleteSessionFunc(ctx, sessionID) + } + return nil +} + +func (m *mockQuerier) GetLatestState(ctx context.Context, arg db.GetLatestStateParams) (db.GetLatestStateRow, error) { + if m.getLatestStateFunc != nil { + return m.getLatestStateFunc(ctx, arg) + } + return db.GetLatestStateRow{}, pgx.ErrNoRows +} + +func (m *mockQuerier) UpdateCode(ctx context.Context, arg db.UpdateCodeParams) error { + if m.updateCodeFunc != nil { + return m.updateCodeFunc(ctx, arg) + } + return nil +} + +func (m *mockQuerier) GetRanking(ctx context.Context, gameID int32) ([]db.GetRankingRow, error) { + if m.getRankingFunc != nil { + return m.getRankingFunc(ctx, gameID) + } + return nil, nil +} + +func (m *mockQuerier) GetLatestStatesOfMainPlayers(ctx context.Context, gameID int32) ([]db.GetLatestStatesOfMainPlayersRow, error) { + if m.getLatestStatesFunc != nil { + return m.getLatestStatesFunc(ctx, gameID) + } + return nil, nil +} + // mockTxManager implements db.TxManager for testing. type mockTxManager struct{} @@ -371,3 +419,309 @@ func TestPostLogin_AuthFailure(t *testing.T) { t.Errorf("expected 401 response, got %T", resp) } } + +func TestPostLogout(t *testing.T) { + h := Handler{ + q: &mockQuerier{}, + txm: &mockTxManager{}, + hub: &mockGameHub{}, + auth: &mockAuthenticator{}, + conf: &config.Config{BasePath: "/"}, + } + user := &db.User{UserID: 1} + // Set session ID in context + ctx := context.WithValue(context.Background(), sessionIDContextKey{}, "hashed-session") + resp, err := h.PostLogout(ctx, PostLogoutRequestObject{}, user) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := resp.(postLogoutCookieResponse); !ok { + t.Errorf("expected postLogoutCookieResponse, got %T", resp) + } +} + +func TestGetGames_Empty(t *testing.T) { + h := Handler{ + q: &mockQuerier{}, + txm: &mockTxManager{}, + hub: &mockGameHub{}, + auth: &mockAuthenticator{}, + conf: &config.Config{}, + } + user := &db.User{UserID: 1} + resp, err := h.GetGames(context.Background(), GetGamesRequestObject{}, user) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + okResp, ok := resp.(GetGames200JSONResponse) + if !ok { + t.Fatalf("expected 200 response, got %T", resp) + } + if len(okResp.Games) != 0 { + t.Errorf("expected 0 games, got %d", len(okResp.Games)) + } +} + +func TestGetGames_WithGames(t *testing.T) { + now := time.Now() + h := Handler{ + q: &mockQuerier{ + listPublicGamesFunc: func(_ context.Context) ([]db.ListPublicGamesRow, error) { + return []db.ListPublicGamesRow{ + { + GameID: 1, + GameType: "golf", + IsPublic: true, + DisplayName: "Game 1", + DurationSeconds: 300, + StartedAt: pgtype.Timestamp{Time: now, Valid: true}, + ProblemID: 10, + Title: "Problem 1", + Description: "desc", + Language: "php", + SampleCode: "