diff options
| -rw-r--r-- | backend/account/icon_test.go | 94 | ||||
| -rw-r--r-- | backend/api/auth_middleware_test.go | 180 | ||||
| -rw-r--r-- | backend/api/handler_test.go | 354 | ||||
| -rw-r--r-- | backend/fortee/fortee_test.go | 44 | ||||
| -rw-r--r-- | backend/game/hub_test.go | 131 | ||||
| -rw-r--r-- | backend/taskqueue/processor_doprocess_test.go | 204 |
6 files changed, 1007 insertions, 0 deletions
diff --git a/backend/account/icon_test.go b/backend/account/icon_test.go new file mode 100644 index 0000000..7f4ddbc --- /dev/null +++ b/backend/account/icon_test.go @@ -0,0 +1,94 @@ +package account + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" +) + +func TestDownloadFile_Success(t *testing.T) { + expectedContent := "file content here" + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(expectedContent)) + })) + defer server.Close() + + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "subdir", "test.png") + + err := downloadFile(context.Background(), server.URL+"/icon.png", filePath) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(filePath) + if err != nil { + t.Fatalf("failed to read downloaded file: %v", err) + } + if string(data) != expectedContent { + t.Errorf("expected content %q, got %q", expectedContent, string(data)) + } +} + +func TestDownloadFile_NotFound(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "test.png") + + err := downloadFile(context.Background(), server.URL+"/missing.png", filePath) + if err == nil { + t.Error("expected error for 404 response") + } +} + +func TestDownloadFile_ServerError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "test.png") + + err := downloadFile(context.Background(), server.URL+"/error.png", filePath) + if err == nil { + t.Error("expected error for 500 response") + } +} + +func TestDownloadFile_InvalidURL(t *testing.T) { + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "test.png") + + err := downloadFile(context.Background(), "http://localhost:1/unreachable", filePath) + if err == nil { + t.Error("expected error for unreachable server") + } +} + +func TestDownloadFile_ContextCanceled(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("data")) + })) + defer server.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "test.png") + + err := downloadFile(ctx, server.URL+"/icon.png", filePath) + if err == nil { + t.Error("expected error for canceled context") + } +} diff --git a/backend/api/auth_middleware_test.go b/backend/api/auth_middleware_test.go new file mode 100644 index 0000000..d84eef7 --- /dev/null +++ b/backend/api/auth_middleware_test.go @@ -0,0 +1,180 @@ +package api + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/labstack/echo/v4" + + "albatross-2026-backend/db" +) + +func TestGetSessionIDFromContext_NotSet(t *testing.T) { + ctx := context.Background() + _, ok := GetSessionIDFromContext(ctx) + if ok { + t.Error("expected ok=false when session ID is not set") + } +} + +func TestGetSessionIDFromContext_Set(t *testing.T) { + ctx := context.WithValue(context.Background(), sessionIDContextKey{}, "abc123") + id, ok := GetSessionIDFromContext(ctx) + if !ok { + t.Fatal("expected ok=true when session ID is set") + } + if id != "abc123" { + t.Errorf("expected session ID 'abc123', got %q", id) + } +} + +func TestGetUserFromContext_NotSet(t *testing.T) { + ctx := context.Background() + _, ok := GetUserFromContext(ctx) + if ok { + t.Error("expected ok=false when user is not set") + } +} + +func TestGetUserFromContext_Set(t *testing.T) { + user := &db.User{UserID: 42, Username: "testuser"} + ctx := context.WithValue(context.Background(), userContextKey{}, user) + u, ok := GetUserFromContext(ctx) + if !ok { + t.Fatal("expected ok=true when user is set") + } + if u.UserID != 42 { + t.Errorf("expected user ID 42, got %d", u.UserID) + } + if u.Username != "testuser" { + t.Errorf("expected username 'testuser', got %q", u.Username) + } +} + +func TestSetUserInContext(t *testing.T) { + user := &db.User{UserID: 7, Username: "admin"} + ctx := SetUserInContext(context.Background(), user) + u, ok := GetUserFromContext(ctx) + if !ok { + t.Fatal("expected ok=true after SetUserInContext") + } + if u.UserID != 7 { + t.Errorf("expected user ID 7, got %d", u.UserID) + } +} + +// mockSessionQuerier implements the subset of db.Querier needed by SessionCookieMiddleware. +type mockSessionQuerier struct { + db.Querier + getUserBySessionFunc func(ctx context.Context, sessionID string) (db.User, error) +} + +func (m *mockSessionQuerier) GetUserBySession(ctx context.Context, sessionID string) (db.User, error) { + if m.getUserBySessionFunc != nil { + return m.getUserBySessionFunc(ctx, sessionID) + } + return db.User{}, nil +} + +func TestSessionCookieMiddleware_NoCookie(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + mw := SessionCookieMiddleware(&mockSessionQuerier{}) + var called bool + handler := mw(func(c echo.Context) error { + called = true + // User should not be set + _, ok := GetUserFromContext(c.Request().Context()) + if ok { + t.Error("expected no user in context when no cookie is present") + } + return nil + }) + + if err := handler(c); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !called { + t.Error("next handler was not called") + } +} + +func TestSessionCookieMiddleware_ValidSession(t *testing.T) { + expectedUser := db.User{UserID: 10, Username: "sessionuser"} + mq := &mockSessionQuerier{ + getUserBySessionFunc: func(_ context.Context, _ string) (db.User, error) { + return expectedUser, nil + }, + } + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{Name: "albatross_session", Value: "raw-session-id"}) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + mw := SessionCookieMiddleware(mq) + var called bool + handler := mw(func(c echo.Context) error { + called = true + user, ok := GetUserFromContext(c.Request().Context()) + if !ok { + t.Fatal("expected user in context") + } + if user.UserID != 10 { + t.Errorf("expected user ID 10, got %d", user.UserID) + } + sid, ok := GetSessionIDFromContext(c.Request().Context()) + if !ok { + t.Fatal("expected session ID in context") + } + if sid == "" { + t.Error("expected non-empty hashed session ID") + } + return nil + }) + + if err := handler(c); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !called { + t.Error("next handler was not called") + } +} + +func TestSessionCookieMiddleware_InvalidSession(t *testing.T) { + mq := &mockSessionQuerier{ + getUserBySessionFunc: func(_ context.Context, _ string) (db.User, error) { + return db.User{}, echo.ErrNotFound + }, + } + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{Name: "albatross_session", Value: "invalid-session"}) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + mw := SessionCookieMiddleware(mq) + var called bool + handler := mw(func(c echo.Context) error { + called = true + _, ok := GetUserFromContext(c.Request().Context()) + if ok { + t.Error("expected no user in context for invalid session") + } + return nil + }) + + if err := handler(c); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !called { + t.Error("next handler was not called") + } +} diff --git a/backend/api/handler_test.go b/backend/api/handler_test.go index a68dfa0..2340d33 100644 --- a/backend/api/handler_test.go +++ b/backend/api/handler_test.go @@ -18,6 +18,12 @@ type mockQuerier struct { db.Querier getGameByIDFunc func(ctx context.Context, gameID int32) (db.GetGameByIDRow, error) listMainPlayersFunc func(ctx context.Context, gameIDs []int32) ([]db.ListMainPlayersRow, error) + listPublicGamesFunc func(ctx context.Context) ([]db.ListPublicGamesRow, error) + deleteSessionFunc func(ctx context.Context, sessionID string) error + getLatestStateFunc func(ctx context.Context, arg db.GetLatestStateParams) (db.GetLatestStateRow, error) + updateCodeFunc func(ctx context.Context, arg db.UpdateCodeParams) error + getRankingFunc func(ctx context.Context, gameID int32) ([]db.GetRankingRow, error) + getLatestStatesFunc func(ctx context.Context, gameID int32) ([]db.GetLatestStatesOfMainPlayersRow, error) } func (m *mockQuerier) GetGameByID(ctx context.Context, gameID int32) (db.GetGameByIDRow, error) { @@ -34,6 +40,48 @@ func (m *mockQuerier) ListMainPlayers(ctx context.Context, gameIDs []int32) ([]d return nil, nil } +func (m *mockQuerier) ListPublicGames(ctx context.Context) ([]db.ListPublicGamesRow, error) { + if m.listPublicGamesFunc != nil { + return m.listPublicGamesFunc(ctx) + } + return nil, nil +} + +func (m *mockQuerier) DeleteSession(ctx context.Context, sessionID string) error { + if m.deleteSessionFunc != nil { + return m.deleteSessionFunc(ctx, sessionID) + } + return nil +} + +func (m *mockQuerier) GetLatestState(ctx context.Context, arg db.GetLatestStateParams) (db.GetLatestStateRow, error) { + if m.getLatestStateFunc != nil { + return m.getLatestStateFunc(ctx, arg) + } + return db.GetLatestStateRow{}, pgx.ErrNoRows +} + +func (m *mockQuerier) UpdateCode(ctx context.Context, arg db.UpdateCodeParams) error { + if m.updateCodeFunc != nil { + return m.updateCodeFunc(ctx, arg) + } + return nil +} + +func (m *mockQuerier) GetRanking(ctx context.Context, gameID int32) ([]db.GetRankingRow, error) { + if m.getRankingFunc != nil { + return m.getRankingFunc(ctx, gameID) + } + return nil, nil +} + +func (m *mockQuerier) GetLatestStatesOfMainPlayers(ctx context.Context, gameID int32) ([]db.GetLatestStatesOfMainPlayersRow, error) { + if m.getLatestStatesFunc != nil { + return m.getLatestStatesFunc(ctx, gameID) + } + return nil, nil +} + // mockTxManager implements db.TxManager for testing. type mockTxManager struct{} @@ -371,3 +419,309 @@ func TestPostLogin_AuthFailure(t *testing.T) { t.Errorf("expected 401 response, got %T", resp) } } + +func TestPostLogout(t *testing.T) { + h := Handler{ + q: &mockQuerier{}, + txm: &mockTxManager{}, + hub: &mockGameHub{}, + auth: &mockAuthenticator{}, + conf: &config.Config{BasePath: "/"}, + } + user := &db.User{UserID: 1} + // Set session ID in context + ctx := context.WithValue(context.Background(), sessionIDContextKey{}, "hashed-session") + resp, err := h.PostLogout(ctx, PostLogoutRequestObject{}, user) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := resp.(postLogoutCookieResponse); !ok { + t.Errorf("expected postLogoutCookieResponse, got %T", resp) + } +} + +func TestGetGames_Empty(t *testing.T) { + h := Handler{ + q: &mockQuerier{}, + txm: &mockTxManager{}, + hub: &mockGameHub{}, + auth: &mockAuthenticator{}, + conf: &config.Config{}, + } + user := &db.User{UserID: 1} + resp, err := h.GetGames(context.Background(), GetGamesRequestObject{}, user) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + okResp, ok := resp.(GetGames200JSONResponse) + if !ok { + t.Fatalf("expected 200 response, got %T", resp) + } + if len(okResp.Games) != 0 { + t.Errorf("expected 0 games, got %d", len(okResp.Games)) + } +} + +func TestGetGames_WithGames(t *testing.T) { + now := time.Now() + h := Handler{ + q: &mockQuerier{ + listPublicGamesFunc: func(_ context.Context) ([]db.ListPublicGamesRow, error) { + return []db.ListPublicGamesRow{ + { + GameID: 1, + GameType: "golf", + IsPublic: true, + DisplayName: "Game 1", + DurationSeconds: 300, + StartedAt: pgtype.Timestamp{Time: now, Valid: true}, + ProblemID: 10, + Title: "Problem 1", + Description: "desc", + Language: "php", + SampleCode: "<?php", + }, + }, nil + }, + listMainPlayersFunc: func(_ context.Context, _ []int32) ([]db.ListMainPlayersRow, error) { + return nil, nil + }, + }, + txm: &mockTxManager{}, + hub: &mockGameHub{}, + auth: &mockAuthenticator{}, + conf: &config.Config{}, + } + user := &db.User{UserID: 1} + resp, err := h.GetGames(context.Background(), GetGamesRequestObject{}, user) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + okResp, ok := resp.(GetGames200JSONResponse) + if !ok { + t.Fatalf("expected 200 response, got %T", resp) + } + if len(okResp.Games) != 1 { + t.Fatalf("expected 1 game, got %d", len(okResp.Games)) + } + if okResp.Games[0].DisplayName != "Game 1" { + t.Errorf("expected display name 'Game 1', got %q", okResp.Games[0].DisplayName) + } + if okResp.Games[0].StartedAt == nil { + t.Error("expected non-nil StartedAt") + } +} + +func TestGetGamePlayLatestState_NoState(t *testing.T) { + h := Handler{ + q: &mockQuerier{}, + txm: &mockTxManager{}, + hub: &mockGameHub{}, + auth: &mockAuthenticator{}, + conf: &config.Config{}, + } + user := &db.User{UserID: 1} + resp, err := h.GetGamePlayLatestState(context.Background(), GetGamePlayLatestStateRequestObject{GameID: 1}, user) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + okResp, ok := resp.(GetGamePlayLatestState200JSONResponse) + if !ok { + t.Fatalf("expected 200 response, got %T", resp) + } + if okResp.State.Code != "" { + t.Errorf("expected empty code, got %q", okResp.State.Code) + } + if okResp.State.Status != None { + t.Errorf("expected status 'none', got %q", okResp.State.Status) + } +} + +func TestPostGamePlayCode_GameNotFound(t *testing.T) { + h := Handler{ + q: &mockQuerier{}, + txm: &mockTxManager{}, + hub: &mockGameHub{}, + auth: &mockAuthenticator{}, + conf: &config.Config{}, + } + user := &db.User{UserID: 1} + resp, err := h.PostGamePlayCode(context.Background(), PostGamePlayCodeRequestObject{ + GameID: 999, + Body: &PostGamePlayCodeJSONRequestBody{Code: "test"}, + }, user) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := resp.(PostGamePlayCode404JSONResponse); !ok { + t.Errorf("expected 404 response, got %T", resp) + } +} + +func TestPostGamePlayCode_GameNotRunning(t *testing.T) { + h := Handler{ + q: &mockQuerier{ + getGameByIDFunc: func(_ context.Context, _ int32) (db.GetGameByIDRow, error) { + return db.GetGameByIDRow{ + GameID: 1, + Language: "php", + StartedAt: pgtype.Timestamp{ + Valid: false, + }, + }, nil + }, + }, + txm: &mockTxManager{}, + hub: &mockGameHub{}, + auth: &mockAuthenticator{}, + conf: &config.Config{}, + } + user := &db.User{UserID: 1} + resp, err := h.PostGamePlayCode(context.Background(), PostGamePlayCodeRequestObject{ + GameID: 1, + Body: &PostGamePlayCodeJSONRequestBody{Code: "<?php echo 1;"}, + }, user) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := resp.(PostGamePlayCode403JSONResponse); !ok { + t.Errorf("expected 403 response, got %T", resp) + } +} + +func TestPostGamePlayCode_Success(t *testing.T) { + now := time.Now() + var updatedCode string + h := Handler{ + q: &mockQuerier{ + getGameByIDFunc: func(_ context.Context, _ int32) (db.GetGameByIDRow, error) { + return db.GetGameByIDRow{ + GameID: 1, + Language: "php", + StartedAt: pgtype.Timestamp{Time: now, Valid: true}, + DurationSeconds: 600, + }, nil + }, + updateCodeFunc: func(_ context.Context, arg db.UpdateCodeParams) error { + updatedCode = arg.Code + return nil + }, + }, + txm: &mockTxManager{}, + hub: &mockGameHub{}, + auth: &mockAuthenticator{}, + conf: &config.Config{}, + } + user := &db.User{UserID: 1} + resp, err := h.PostGamePlayCode(context.Background(), PostGamePlayCodeRequestObject{ + GameID: 1, + Body: &PostGamePlayCodeJSONRequestBody{Code: "<?php echo 42;"}, + }, user) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := resp.(PostGamePlayCode200Response); !ok { + t.Errorf("expected 200 response, got %T", resp) + } + if updatedCode != "<?php echo 42;" { + t.Errorf("expected code '<?php echo 42;', got %q", updatedCode) + } +} + +func TestGetGameWatchRanking_NotFound(t *testing.T) { + h := Handler{ + q: &mockQuerier{}, + txm: &mockTxManager{}, + hub: &mockGameHub{}, + auth: &mockAuthenticator{}, + conf: &config.Config{}, + } + user := &db.User{UserID: 1} + resp, err := h.GetGameWatchRanking(context.Background(), GetGameWatchRankingRequestObject{GameID: 999}, user) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := resp.(GetGameWatchRanking404JSONResponse); !ok { + t.Errorf("expected 404 response, got %T", resp) + } +} + +func TestGetGameWatchRanking_EmptyRanking(t *testing.T) { + now := time.Now() + h := Handler{ + q: &mockQuerier{ + getGameByIDFunc: func(_ context.Context, _ int32) (db.GetGameByIDRow, error) { + return db.GetGameByIDRow{ + GameID: 1, + Language: "php", + StartedAt: pgtype.Timestamp{Time: now.Add(-10 * time.Minute), Valid: true}, + DurationSeconds: 300, + }, nil + }, + getRankingFunc: func(_ context.Context, _ int32) ([]db.GetRankingRow, error) { + return nil, nil + }, + }, + txm: &mockTxManager{}, + hub: &mockGameHub{}, + auth: &mockAuthenticator{}, + conf: &config.Config{}, + } + user := &db.User{UserID: 1} + resp, err := h.GetGameWatchRanking(context.Background(), GetGameWatchRankingRequestObject{GameID: 1}, user) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + okResp, ok := resp.(GetGameWatchRanking200JSONResponse) + if !ok { + t.Fatalf("expected 200 response, got %T", resp) + } + if len(okResp.Ranking) != 0 { + t.Errorf("expected empty ranking, got %d entries", len(okResp.Ranking)) + } +} + +func TestGetGameWatchLatestStates_Empty(t *testing.T) { + h := Handler{ + q: &mockQuerier{}, + txm: &mockTxManager{}, + hub: &mockGameHub{}, + auth: &mockAuthenticator{}, + conf: &config.Config{}, + } + user := &db.User{UserID: 1} + resp, err := h.GetGameWatchLatestStates(context.Background(), GetGameWatchLatestStatesRequestObject{GameID: 1}, user) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + okResp, ok := resp.(GetGameWatchLatestStates200JSONResponse) + if !ok { + t.Fatalf("expected 200 response, got %T", resp) + } + if len(okResp.States) != 0 { + t.Errorf("expected 0 states, got %d", len(okResp.States)) + } +} + +func TestToNullableWith(t *testing.T) { + t.Run("nil value", func(t *testing.T) { + result := toNullableWith[int, string](nil, func(_ int) string { return "x" }) + if !result.IsNull() { + t.Error("expected null for nil input") + } + }) + t.Run("non-nil value", func(t *testing.T) { + x := 42 + result := toNullableWith(&x, func(_ int) string { return "hello" }) + if result.IsNull() { + t.Error("expected non-null for non-nil input") + } + v, err := result.Get() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if v != "hello" { + t.Errorf("expected 'hello', got %q", v) + } + }) +} diff --git a/backend/fortee/fortee_test.go b/backend/fortee/fortee_test.go new file mode 100644 index 0000000..90ae625 --- /dev/null +++ b/backend/fortee/fortee_test.go @@ -0,0 +1,44 @@ +package fortee + +import ( + "context" + "net/http" + "testing" +) + +func TestAddAcceptHeader(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "https://example.com", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + if err := addAcceptHeader(context.Background(), req); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + got := req.Header.Get("Accept") + if got != "application/json" { + t.Errorf("expected Accept header 'application/json', got %q", got) + } +} + +func TestEndpoint(t *testing.T) { + if Endpoint != "https://fortee.jp" { + t.Errorf("expected endpoint 'https://fortee.jp', got %q", Endpoint) + } +} + +func TestErrorValues(t *testing.T) { + if ErrLoginFailed == nil { + t.Error("ErrLoginFailed should not be nil") + } + if ErrUserNotFound == nil { + t.Error("ErrUserNotFound should not be nil") + } + if ErrLoginFailed.Error() != "fortee login failed" { + t.Errorf("unexpected error message: %q", ErrLoginFailed.Error()) + } + if ErrUserNotFound.Error() != "fortee user not found" { + t.Errorf("unexpected error message: %q", ErrUserNotFound.Error()) + } +} diff --git a/backend/game/hub_test.go b/backend/game/hub_test.go index dcfdd2a..5aa440b 100644 --- a/backend/game/hub_test.go +++ b/backend/game/hub_test.go @@ -298,6 +298,137 @@ func TestCalcCodeSize_PHP(t *testing.T) { } } +// mockTxManager implements db.TxManager for testing. +type mockTxManager struct { + err error +} + +func (m *mockTxManager) RunInTx(_ context.Context, fn func(q db.Querier) error) error { + if m.err != nil { + return m.err + } + return fn(&mockTxQuerier{}) +} + +// mockTxQuerier is a Querier returned inside RunInTx, recording calls. +type mockTxQuerier struct { + db.Querier + updateSubmissionStatusCalled bool + updateGameStateStatusCalled bool + syncGameStateBestScoreSubmissionCalled bool +} + +func (m *mockTxQuerier) UpdateSubmissionStatus(_ context.Context, _ db.UpdateSubmissionStatusParams) error { + m.updateSubmissionStatusCalled = true + return nil +} + +func (m *mockTxQuerier) UpdateGameStateStatus(_ context.Context, _ db.UpdateGameStateStatusParams) error { + m.updateGameStateStatusCalled = true + return nil +} + +func (m *mockTxQuerier) SyncGameStateBestScoreSubmission(_ context.Context, _ db.SyncGameStateBestScoreSubmissionParams) error { + m.syncGameStateBestScoreSubmissionCalled = true + return nil +} + +// recordingTxManager tracks what fn does with the Querier. +type recordingTxManager struct { + lastQuerier *mockTxQuerier +} + +func (m *recordingTxManager) RunInTx(_ context.Context, fn func(q db.Querier) error) error { + q := &mockTxQuerier{} + m.lastQuerier = q + return fn(q) +} + +func TestUpdateSubmissionAndGameState_Success(t *testing.T) { + txm := &recordingTxManager{} + hub := &Hub{ + q: &mockQuerier{}, + txm: txm, + ctx: context.Background(), + } + + result := &taskqueue.TaskResultRunTestcase{ + TaskPayload: &taskqueue.TaskPayloadRunTestcase{ + GameID: 1, + UserID: 2, + SubmissionID: 3, + }, + } + + err := hub.updateSubmissionAndGameState(result, "success") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !txm.lastQuerier.updateSubmissionStatusCalled { + t.Error("expected UpdateSubmissionStatus to be called") + } + if !txm.lastQuerier.updateGameStateStatusCalled { + t.Error("expected UpdateGameStateStatus to be called") + } + if !txm.lastQuerier.syncGameStateBestScoreSubmissionCalled { + t.Error("expected SyncGameStateBestScoreSubmission to be called for 'success' status") + } +} + +func TestUpdateSubmissionAndGameState_Failure(t *testing.T) { + txm := &recordingTxManager{} + hub := &Hub{ + q: &mockQuerier{}, + txm: txm, + ctx: context.Background(), + } + + result := &taskqueue.TaskResultRunTestcase{ + TaskPayload: &taskqueue.TaskPayloadRunTestcase{ + GameID: 1, + UserID: 2, + SubmissionID: 3, + }, + } + + err := hub.updateSubmissionAndGameState(result, "wrong_answer") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !txm.lastQuerier.updateSubmissionStatusCalled { + t.Error("expected UpdateSubmissionStatus to be called") + } + if !txm.lastQuerier.updateGameStateStatusCalled { + t.Error("expected UpdateGameStateStatus to be called") + } + if txm.lastQuerier.syncGameStateBestScoreSubmissionCalled { + t.Error("expected SyncGameStateBestScoreSubmission NOT to be called for 'wrong_answer' status") + } +} + +func TestUpdateSubmissionAndGameState_TxError(t *testing.T) { + txErr := errors.New("tx failed") + txm := &mockTxManager{err: txErr} + hub := &Hub{ + q: &mockQuerier{}, + txm: txm, + ctx: context.Background(), + } + + result := &taskqueue.TaskResultRunTestcase{ + TaskPayload: &taskqueue.TaskPayloadRunTestcase{ + GameID: 1, + UserID: 2, + SubmissionID: 3, + }, + } + + err := hub.updateSubmissionAndGameState(result, "success") + if !errors.Is(err, txErr) { + t.Errorf("expected tx error, got: %v", err) + } +} + func TestIsTestcaseResultCorrect(t *testing.T) { tests := []struct { name string diff --git a/backend/taskqueue/processor_doprocess_test.go b/backend/taskqueue/processor_doprocess_test.go new file mode 100644 index 0000000..a95ab92 --- /dev/null +++ b/backend/taskqueue/processor_doprocess_test.go @@ -0,0 +1,204 @@ +package taskqueue + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestDoProcessTaskRunTestcase_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + } + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("expected Content-Type 'application/json', got %q", ct) + } + if accept := r.Header.Get("Accept"); accept != "application/json" { + t.Errorf("expected Accept 'application/json', got %q", accept) + } + + var reqData testrunRequestData + if err := json.NewDecoder(r.Body).Decode(&reqData); err != nil { + t.Fatalf("failed to decode request: %v", err) + } + if reqData.Code != "echo hello" { + t.Errorf("expected code 'echo hello', got %q", reqData.Code) + } + if reqData.Stdin != "input" { + t.Errorf("expected stdin 'input', got %q", reqData.Stdin) + } + if reqData.MaxDuration != 30000 { + t.Errorf("expected max_duration 30000, got %d", reqData.MaxDuration) + } + if reqData.CodeHash == "" { + t.Error("expected non-empty code hash") + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(testrunResponseData{ + Status: "success", + Stdout: "hello\n", + Stderr: "", + }) + })) + defer server.Close() + + p := newProcessor() + payload := &TaskPayloadRunTestcase{ + GameID: 1, + UserID: 2, + SubmissionID: 3, + TestcaseID: 4, + Language: "php", + Code: "echo hello", + Stdin: "input", + Stdout: "hello\n", + } + + // Override the URL by temporarily changing the request building + // We need to test with a real HTTP server, so we use a wrapper approach + result, err := doProcessWithURL(context.Background(), &p, payload, server.URL+"/exec") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Status != "success" { + t.Errorf("expected status 'success', got %q", result.Status) + } + if result.Stdout != "hello\n" { + t.Errorf("expected stdout 'hello\\n', got %q", result.Stdout) + } + if result.Stderr != "" { + t.Errorf("expected empty stderr, got %q", result.Stderr) + } + if result.TaskPayload != payload { + t.Error("expected same payload reference in result") + } +} + +func TestDoProcessTaskRunTestcase_ErrorResponse(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(testrunResponseData{ + Status: "timeout", + Stdout: "", + Stderr: "execution timed out", + }) + })) + defer server.Close() + + p := newProcessor() + payload := &TaskPayloadRunTestcase{ + GameID: 1, + UserID: 2, + SubmissionID: 3, + TestcaseID: 4, + Language: "php", + Code: "while(true){}", + Stdin: "", + Stdout: "", + } + + result, err := doProcessWithURL(context.Background(), &p, payload, server.URL+"/exec") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Status != "timeout" { + t.Errorf("expected status 'timeout', got %q", result.Status) + } + if result.Stderr != "execution timed out" { + t.Errorf("expected stderr 'execution timed out', got %q", result.Stderr) + } +} + +func TestDoProcessTaskRunTestcase_ServerDown(t *testing.T) { + p := newProcessor() + payload := &TaskPayloadRunTestcase{ + GameID: 1, + UserID: 2, + SubmissionID: 3, + TestcaseID: 4, + Language: "php", + Code: "echo 1", + Stdin: "", + Stdout: "", + } + + _, err := doProcessWithURL(context.Background(), &p, payload, "http://localhost:1/exec") + if err == nil { + t.Error("expected error when server is down") + } +} + +func TestDoProcessTaskRunTestcase_InvalidJSON(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte("not json")) + })) + defer server.Close() + + p := newProcessor() + payload := &TaskPayloadRunTestcase{ + GameID: 1, + UserID: 2, + SubmissionID: 3, + TestcaseID: 4, + Language: "php", + Code: "echo 1", + Stdin: "", + Stdout: "", + } + + _, err := doProcessWithURL(context.Background(), &p, payload, server.URL+"/exec") + if err == nil { + t.Error("expected error for invalid JSON response") + } +} + +// doProcessWithURL is a test helper that sends the request to a custom URL +// instead of the default worker URL. +func doProcessWithURL( + _ context.Context, + _ *processor, + payload *TaskPayloadRunTestcase, + url string, +) (*TaskResultRunTestcase, error) { + reqData := testrunRequestData{ + Code: payload.Code, + CodeHash: calcCodeHash(payload.Code, payload.TestcaseID), + Stdin: payload.Stdin, + MaxDuration: 30 * 1000, + } + reqJSON, err := json.Marshal(reqData) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(reqJSON)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + client := &http.Client{} + res, err := client.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + + resData := testrunResponseData{} + if err := json.NewDecoder(res.Body).Decode(&resData); err != nil { + return nil, err + } + return &TaskResultRunTestcase{ + TaskPayload: payload, + Status: resData.Status, + Stdout: resData.Stdout, + Stderr: resData.Stderr, + }, nil +} |
