aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--backend/api/handler_test.go10
-rw-r--r--backend/config/config_test.go19
-rw-r--r--backend/game/service.go11
-rw-r--r--backend/game/service_test.go112
4 files changed, 148 insertions, 4 deletions
diff --git a/backend/api/handler_test.go b/backend/api/handler_test.go
index 2dcde29..41c3403 100644
--- a/backend/api/handler_test.go
+++ b/backend/api/handler_test.go
@@ -672,7 +672,15 @@ func TestGetGameWatchRanking_EmptyRanking(t *testing.T) {
}
func TestGetGameWatchLatestStates_Empty(t *testing.T) {
- h := newTestHandler(&mockQuerier{})
+ h := newTestHandler(&mockQuerier{
+ getGameByIDFunc: func(_ context.Context, _ int32) (db.GetGameByIDRow, error) {
+ return db.GetGameByIDRow{
+ GameID: 1,
+ DurationSeconds: 300,
+ StartedAt: pgtype.Timestamp{Time: time.Now().Add(-10 * time.Minute), Valid: true},
+ }, nil
+ },
+ })
user := &db.User{UserID: 1}
resp, err := h.GetGameWatchLatestStates(context.Background(), GetGameWatchLatestStatesRequestObject{GameID: 1}, user)
if err != nil {
diff --git a/backend/config/config_test.go b/backend/config/config_test.go
index 5110e0c..210e89b 100644
--- a/backend/config/config_test.go
+++ b/backend/config/config_test.go
@@ -1,6 +1,7 @@
package config
import (
+ "os"
"testing"
)
@@ -125,10 +126,24 @@ func TestNewConfigFromEnv_MissingRequired(t *testing.T) {
},
}
+ allKeys := []string{
+ "ALBATROSS_DB_HOST",
+ "ALBATROSS_DB_PORT",
+ "ALBATROSS_DB_USER",
+ "ALBATROSS_DB_PASSWORD",
+ "ALBATROSS_DB_NAME",
+ "ALBATROSS_BASE_PATH",
+ }
+
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- for k, v := range tt.envVars {
- t.Setenv(k, v)
+ for _, k := range allKeys {
+ if v, ok := tt.envVars[k]; ok {
+ t.Setenv(k, v)
+ } else {
+ t.Setenv(k, "")
+ os.Unsetenv(k)
+ }
}
_, err := NewConfigFromEnv()
if err == nil {
diff --git a/backend/game/service.go b/backend/game/service.go
index a054388..dc23061 100644
--- a/backend/game/service.go
+++ b/backend/game/service.go
@@ -293,6 +293,15 @@ func (s *Service) GetLatestState(ctx context.Context, gameID int, userID int32)
}
func (s *Service) GetWatchLatestStates(ctx context.Context, gameID int, userID *int32, isAdmin bool) (map[int]LatestState, error) {
+ gameRow, err := s.q.GetGameByID(ctx, int32(gameID))
+ if err != nil {
+ if errors.Is(err, pgx.ErrNoRows) {
+ return nil, ErrNotFound
+ }
+ return nil, err
+ }
+ finished := IsGameFinished(gameRow.StartedAt, gameRow.DurationSeconds)
+
rows, err := s.q.GetLatestStatesOfMainPlayers(ctx, int32(gameID))
if err != nil {
return nil, err
@@ -320,7 +329,7 @@ func (s *Service) GetWatchLatestStates(ctx context.Context, gameID int, userID *
submittedAt = &ts
}
- if userID != nil && row.UserID == *userID && !isAdmin {
+ if userID != nil && row.UserID == *userID && !isAdmin && !finished {
return nil, ErrForbidden
}
diff --git a/backend/game/service_test.go b/backend/game/service_test.go
index 95ceef6..93e62f7 100644
--- a/backend/game/service_test.go
+++ b/backend/game/service_test.go
@@ -1,12 +1,31 @@
package game
import (
+ "context"
"testing"
"time"
"github.com/jackc/pgx/v5/pgtype"
+
+ "albatross-2026-backend/db"
)
+// stubQuerier implements db.Querier with only the methods needed for tests.
+// All unimplemented methods panic so missing stubs are caught immediately.
+type stubQuerier struct {
+ db.Querier
+ getGameByID func(ctx context.Context, gameID int32) (db.GetGameByIDRow, error)
+ getLatestStatesOfMainPlayers func(ctx context.Context, gameID int32) ([]db.GetLatestStatesOfMainPlayersRow, error)
+}
+
+func (s *stubQuerier) GetGameByID(ctx context.Context, gameID int32) (db.GetGameByIDRow, error) {
+ return s.getGameByID(ctx, gameID)
+}
+
+func (s *stubQuerier) GetLatestStatesOfMainPlayers(ctx context.Context, gameID int32) ([]db.GetLatestStatesOfMainPlayersRow, error) {
+ return s.getLatestStatesOfMainPlayers(ctx, gameID)
+}
+
func TestIsGameRunning(t *testing.T) {
now := time.Now()
tests := []struct {
@@ -80,3 +99,96 @@ func TestIsGameFinished(t *testing.T) {
})
}
}
+
+func TestGetWatchLatestStates_ParticipantRestriction(t *testing.T) {
+ now := time.Now()
+ var playerID int32 = 1
+ var otherID int32 = 2
+
+ code := "<?php echo 1;"
+ status := "pass"
+ var codeSize int32 = 14
+
+ mainPlayerRows := []db.GetLatestStatesOfMainPlayersRow{
+ {
+ GameID: 1,
+ UserID: playerID,
+ Code: &code,
+ Status: &status,
+ CodeSize: &codeSize,
+ },
+ {
+ GameID: 1,
+ UserID: otherID,
+ Code: &code,
+ Status: &status,
+ CodeSize: &codeSize,
+ },
+ }
+
+ tests := []struct {
+ name string
+ startedAt pgtype.Timestamp
+ userID *int32
+ isAdmin bool
+ wantErr error
+ }{
+ {
+ name: "participant blocked while game is running",
+ startedAt: pgtype.Timestamp{Time: now.Add(-1 * time.Minute), Valid: true},
+ userID: &playerID,
+ isAdmin: false,
+ wantErr: ErrForbidden,
+ },
+ {
+ name: "participant allowed after game finished",
+ startedAt: pgtype.Timestamp{Time: now.Add(-10 * time.Minute), Valid: true},
+ userID: &playerID,
+ isAdmin: false,
+ wantErr: nil,
+ },
+ {
+ name: "participant blocked before game starts",
+ startedAt: pgtype.Timestamp{Valid: false},
+ userID: &playerID,
+ isAdmin: false,
+ wantErr: ErrForbidden,
+ },
+ {
+ name: "admin always allowed even while running",
+ startedAt: pgtype.Timestamp{Time: now.Add(-1 * time.Minute), Valid: true},
+ userID: &playerID,
+ isAdmin: true,
+ wantErr: nil,
+ },
+ {
+ name: "non-participant allowed while running",
+ startedAt: pgtype.Timestamp{Time: now.Add(-1 * time.Minute), Valid: true},
+ userID: nil,
+ isAdmin: false,
+ wantErr: nil,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ q := &stubQuerier{
+ getGameByID: func(_ context.Context, _ int32) (db.GetGameByIDRow, error) {
+ return db.GetGameByIDRow{
+ StartedAt: tt.startedAt,
+ DurationSeconds: 300,
+ }, nil
+ },
+ getLatestStatesOfMainPlayers: func(_ context.Context, _ int32) ([]db.GetLatestStatesOfMainPlayersRow, error) {
+ return mainPlayerRows, nil
+ },
+ }
+ svc := NewService(q, nil, nil)
+
+ _, err := svc.GetWatchLatestStates(context.Background(), 1, tt.userID, tt.isAdmin)
+ if err != tt.wantErr {
+ t.Errorf("got err=%v, want %v", err, tt.wantErr)
+ }
+ })
+ }
+}