diff options
Diffstat (limited to 'backend/auth/auth.go')
| -rw-r--r-- | backend/auth/auth.go | 76 |
1 files changed, 29 insertions, 47 deletions
diff --git a/backend/auth/auth.go b/backend/auth/auth.go index 7d9a4c2..a1fcb64 100644 --- a/backend/auth/auth.go +++ b/backend/auth/auth.go @@ -7,7 +7,6 @@ import ( "time" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgxpool" "golang.org/x/crypto/bcrypt" "albatross-2026-backend/account" @@ -15,28 +14,32 @@ import ( "albatross-2026-backend/fortee" ) -var ( - ErrForteeLoginTimeout = errors.New("fortee login timeout") -) +var ErrForteeLoginTimeout = errors.New("fortee login timeout") const ( forteeAPITimeout = 3 * time.Second ) -func Login( +type Authenticator struct { + q db.Querier + txm db.TxManager +} + +func NewAuthenticator(q db.Querier, txm db.TxManager) *Authenticator { + return &Authenticator{q: q, txm: txm} +} + +func (a *Authenticator) Login( ctx context.Context, - queries *db.Queries, - pool *pgxpool.Pool, username string, password string, ) (int, error) { - userAuth, err := queries.GetUserAuthByUsername(ctx, username) + userAuth, err := a.q.GetUserAuthByUsername(ctx, username) if err != nil && !errors.Is(err, pgx.ErrNoRows) { return 0, err } if userAuth.AuthType == "password" { - // Authenticate with password. passwordHash := userAuth.PasswordHash if passwordHash == nil { return 0, errors.New("inconsistent data: password auth type but no password hash") @@ -48,14 +51,11 @@ func Login( return int(userAuth.UserID), nil } - // Authenticate with fortee. - return verifyForteeAccountOrSignup(ctx, queries, pool, username, password) + return a.verifyForteeAccountOrSignup(ctx, username, password) } -func verifyForteeAccountOrSignup( +func (a *Authenticator) verifyForteeAccountOrSignup( ctx context.Context, - queries *db.Queries, - pool *pgxpool.Pool, username string, password string, ) (int, error) { @@ -63,58 +63,40 @@ func verifyForteeAccountOrSignup( if err != nil { return 0, err } - userID, err := queries.GetUserIDByUsername(ctx, canonicalizedUsername) + userID, err := a.q.GetUserIDByUsername(ctx, canonicalizedUsername) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - return signup( - ctx, - queries, - pool, - canonicalizedUsername, - ) + return a.signup(ctx, canonicalizedUsername) } return 0, err } return int(userID), nil } -func signup( +func (a *Authenticator) signup( ctx context.Context, - queries *db.Queries, - pool *pgxpool.Pool, username string, ) (int, error) { - tx, err := pool.Begin(ctx) - if err != nil { - return 0, err - } - defer func() { - if err := tx.Rollback(ctx); err != nil && err != pgx.ErrTxClosed { - slog.Error("failed to rollback transaction", "error", err) + var userID int32 + err := a.txm.RunInTx(ctx, func(qtx db.Querier) error { + var err error + userID, err = qtx.CreateUser(ctx, username) + if err != nil { + return err } - }() - - qtx := queries.WithTx(tx) - userID, err := qtx.CreateUser(ctx, username) + return qtx.CreateUserAuth(ctx, db.CreateUserAuthParams{ + UserID: userID, + AuthType: "fortee", + }) + }) if err != nil { return 0, err } - if err := qtx.CreateUserAuth(ctx, db.CreateUserAuthParams{ - UserID: userID, - AuthType: "fortee", - }); err != nil { - return 0, err - } - - if err := tx.Commit(ctx); err != nil { - return 0, err - } go func() { - err := account.FetchIcon(context.Background(), queries, int(userID)) + err := account.FetchIcon(context.Background(), a.q, int(userID)) if err != nil { slog.Error("failed to fetch icon", "error", err) - // The failure is intentionally ignored. Retry manually if needed. } }() return int(userID), nil |
