diff options
Diffstat (limited to 'backend')
| -rw-r--r-- | backend/api/handler_test.go | 10 | ||||
| -rw-r--r-- | backend/config/config_test.go | 19 | ||||
| -rw-r--r-- | backend/game/service.go | 11 | ||||
| -rw-r--r-- | backend/game/service_test.go | 112 |
4 files changed, 148 insertions, 4 deletions
diff --git a/backend/api/handler_test.go b/backend/api/handler_test.go index 2dcde29..41c3403 100644 --- a/backend/api/handler_test.go +++ b/backend/api/handler_test.go @@ -672,7 +672,15 @@ func TestGetGameWatchRanking_EmptyRanking(t *testing.T) { } func TestGetGameWatchLatestStates_Empty(t *testing.T) { - h := newTestHandler(&mockQuerier{}) + h := newTestHandler(&mockQuerier{ + getGameByIDFunc: func(_ context.Context, _ int32) (db.GetGameByIDRow, error) { + return db.GetGameByIDRow{ + GameID: 1, + DurationSeconds: 300, + StartedAt: pgtype.Timestamp{Time: time.Now().Add(-10 * time.Minute), Valid: true}, + }, nil + }, + }) user := &db.User{UserID: 1} resp, err := h.GetGameWatchLatestStates(context.Background(), GetGameWatchLatestStatesRequestObject{GameID: 1}, user) if err != nil { diff --git a/backend/config/config_test.go b/backend/config/config_test.go index 5110e0c..210e89b 100644 --- a/backend/config/config_test.go +++ b/backend/config/config_test.go @@ -1,6 +1,7 @@ package config import ( + "os" "testing" ) @@ -125,10 +126,24 @@ func TestNewConfigFromEnv_MissingRequired(t *testing.T) { }, } + allKeys := []string{ + "ALBATROSS_DB_HOST", + "ALBATROSS_DB_PORT", + "ALBATROSS_DB_USER", + "ALBATROSS_DB_PASSWORD", + "ALBATROSS_DB_NAME", + "ALBATROSS_BASE_PATH", + } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - for k, v := range tt.envVars { - t.Setenv(k, v) + for _, k := range allKeys { + if v, ok := tt.envVars[k]; ok { + t.Setenv(k, v) + } else { + t.Setenv(k, "") + os.Unsetenv(k) + } } _, err := NewConfigFromEnv() if err == nil { diff --git a/backend/game/service.go b/backend/game/service.go index a054388..dc23061 100644 --- a/backend/game/service.go +++ b/backend/game/service.go @@ -293,6 +293,15 @@ func (s *Service) GetLatestState(ctx context.Context, gameID int, userID int32) } func (s *Service) GetWatchLatestStates(ctx context.Context, gameID int, userID *int32, isAdmin bool) (map[int]LatestState, error) { + gameRow, err := s.q.GetGameByID(ctx, int32(gameID)) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrNotFound + } + return nil, err + } + finished := IsGameFinished(gameRow.StartedAt, gameRow.DurationSeconds) + rows, err := s.q.GetLatestStatesOfMainPlayers(ctx, int32(gameID)) if err != nil { return nil, err @@ -320,7 +329,7 @@ func (s *Service) GetWatchLatestStates(ctx context.Context, gameID int, userID * submittedAt = &ts } - if userID != nil && row.UserID == *userID && !isAdmin { + if userID != nil && row.UserID == *userID && !isAdmin && !finished { return nil, ErrForbidden } diff --git a/backend/game/service_test.go b/backend/game/service_test.go index 95ceef6..93e62f7 100644 --- a/backend/game/service_test.go +++ b/backend/game/service_test.go @@ -1,12 +1,31 @@ package game import ( + "context" "testing" "time" "github.com/jackc/pgx/v5/pgtype" + + "albatross-2026-backend/db" ) +// stubQuerier implements db.Querier with only the methods needed for tests. +// All unimplemented methods panic so missing stubs are caught immediately. +type stubQuerier struct { + db.Querier + getGameByID func(ctx context.Context, gameID int32) (db.GetGameByIDRow, error) + getLatestStatesOfMainPlayers func(ctx context.Context, gameID int32) ([]db.GetLatestStatesOfMainPlayersRow, error) +} + +func (s *stubQuerier) GetGameByID(ctx context.Context, gameID int32) (db.GetGameByIDRow, error) { + return s.getGameByID(ctx, gameID) +} + +func (s *stubQuerier) GetLatestStatesOfMainPlayers(ctx context.Context, gameID int32) ([]db.GetLatestStatesOfMainPlayersRow, error) { + return s.getLatestStatesOfMainPlayers(ctx, gameID) +} + func TestIsGameRunning(t *testing.T) { now := time.Now() tests := []struct { @@ -80,3 +99,96 @@ func TestIsGameFinished(t *testing.T) { }) } } + +func TestGetWatchLatestStates_ParticipantRestriction(t *testing.T) { + now := time.Now() + var playerID int32 = 1 + var otherID int32 = 2 + + code := "<?php echo 1;" + status := "pass" + var codeSize int32 = 14 + + mainPlayerRows := []db.GetLatestStatesOfMainPlayersRow{ + { + GameID: 1, + UserID: playerID, + Code: &code, + Status: &status, + CodeSize: &codeSize, + }, + { + GameID: 1, + UserID: otherID, + Code: &code, + Status: &status, + CodeSize: &codeSize, + }, + } + + tests := []struct { + name string + startedAt pgtype.Timestamp + userID *int32 + isAdmin bool + wantErr error + }{ + { + name: "participant blocked while game is running", + startedAt: pgtype.Timestamp{Time: now.Add(-1 * time.Minute), Valid: true}, + userID: &playerID, + isAdmin: false, + wantErr: ErrForbidden, + }, + { + name: "participant allowed after game finished", + startedAt: pgtype.Timestamp{Time: now.Add(-10 * time.Minute), Valid: true}, + userID: &playerID, + isAdmin: false, + wantErr: nil, + }, + { + name: "participant blocked before game starts", + startedAt: pgtype.Timestamp{Valid: false}, + userID: &playerID, + isAdmin: false, + wantErr: ErrForbidden, + }, + { + name: "admin always allowed even while running", + startedAt: pgtype.Timestamp{Time: now.Add(-1 * time.Minute), Valid: true}, + userID: &playerID, + isAdmin: true, + wantErr: nil, + }, + { + name: "non-participant allowed while running", + startedAt: pgtype.Timestamp{Time: now.Add(-1 * time.Minute), Valid: true}, + userID: nil, + isAdmin: false, + wantErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := &stubQuerier{ + getGameByID: func(_ context.Context, _ int32) (db.GetGameByIDRow, error) { + return db.GetGameByIDRow{ + StartedAt: tt.startedAt, + DurationSeconds: 300, + }, nil + }, + getLatestStatesOfMainPlayers: func(_ context.Context, _ int32) ([]db.GetLatestStatesOfMainPlayersRow, error) { + return mainPlayerRows, nil + }, + } + svc := NewService(q, nil, nil) + + _, err := svc.GetWatchLatestStates(context.Background(), 1, tt.userID, tt.isAdmin) + if err != tt.wantErr { + t.Errorf("got err=%v, want %v", err, tt.wantErr) + } + }) + } +} |
