package auth import ( "crypto/hmac" "crypto/rand" "crypto/sha256" "database/sql" "encoding/hex" "fmt" "io" "log" "net/http" "os" "strings" "time" ) // Secrets handles secret and token management type Secrets struct { db *sql.DB fixedRegistrationToken string // Fixed token from environment variable (reusable, never expires) } // NewSecrets creates a new secrets manager func NewSecrets(db *sql.DB) (*Secrets, error) { s := &Secrets{db: db} // Check for fixed registration token from environment fixedToken := os.Getenv("FIXED_REGISTRATION_TOKEN") if fixedToken != "" { s.fixedRegistrationToken = fixedToken log.Printf("Fixed registration token enabled (from FIXED_REGISTRATION_TOKEN env var)") log.Printf("WARNING: Fixed registration token is reusable and never expires - use only for testing/development!") } // Ensure manager secret exists if err := s.ensureManagerSecret(); err != nil { return nil, fmt.Errorf("failed to ensure manager secret: %w", err) } return s, nil } // ensureManagerSecret ensures a manager secret exists in the database func (s *Secrets) ensureManagerSecret() error { var count int err := s.db.QueryRow("SELECT COUNT(*) FROM manager_secrets").Scan(&count) if err != nil { return fmt.Errorf("failed to check manager secrets: %w", err) } if count == 0 { // Generate new manager secret secret, err := generateSecret(32) if err != nil { return fmt.Errorf("failed to generate manager secret: %w", err) } _, err = s.db.Exec("INSERT INTO manager_secrets (secret) VALUES (?)", secret) if err != nil { return fmt.Errorf("failed to store manager secret: %w", err) } } return nil } // GetManagerSecret retrieves the current manager secret func (s *Secrets) GetManagerSecret() (string, error) { var secret string err := s.db.QueryRow("SELECT secret FROM manager_secrets ORDER BY created_at DESC LIMIT 1").Scan(&secret) if err != nil { return "", fmt.Errorf("failed to get manager secret: %w", err) } return secret, nil } // GenerateRegistrationToken generates a new registration token // If expiresIn is 0, the token will never expire (uses far future date) // Note: Token expiration only affects whether the token can be used for registration. // Once a runner registers, it operates independently using its own secrets. func (s *Secrets) GenerateRegistrationToken(createdBy int64, expiresIn time.Duration) (string, error) { token, err := generateSecret(32) if err != nil { return "", fmt.Errorf("failed to generate token: %w", err) } var expiresAt time.Time if expiresIn == 0 { // Use far future date (year 9999) to represent infinite expiration expiresAt = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC) } else { expiresAt = time.Now().Add(expiresIn) } _, err = s.db.Exec( "INSERT INTO registration_tokens (token, expires_at, created_by) VALUES (?, ?, ?)", token, expiresAt, createdBy, ) if err != nil { return "", fmt.Errorf("failed to store registration token: %w", err) } return token, nil } // TokenValidationResult represents the result of token validation type TokenValidationResult struct { Valid bool Reason string // "valid", "not_found", "already_used", "expired" Error error } // ValidateRegistrationToken validates a registration token func (s *Secrets) ValidateRegistrationToken(token string) (bool, error) { result, err := s.ValidateRegistrationTokenDetailed(token) if err != nil { return false, err } // For backward compatibility, return just the valid boolean return result.Valid, nil } // ValidateRegistrationTokenDetailed validates a registration token and returns detailed result func (s *Secrets) ValidateRegistrationTokenDetailed(token string) (*TokenValidationResult, error) { // Check fixed token first (if set) - it's reusable and never expires if s.fixedRegistrationToken != "" && token == s.fixedRegistrationToken { log.Printf("Fixed registration token used (from FIXED_REGISTRATION_TOKEN env var)") return &TokenValidationResult{Valid: true, Reason: "valid"}, nil } // Check database tokens var used bool var expiresAt time.Time var id int64 err := s.db.QueryRow( "SELECT id, expires_at, used FROM registration_tokens WHERE token = ?", token, ).Scan(&id, &expiresAt, &used) if err == sql.ErrNoRows { return &TokenValidationResult{Valid: false, Reason: "not_found"}, nil } if err != nil { return nil, fmt.Errorf("failed to query token: %w", err) } if used { return &TokenValidationResult{Valid: false, Reason: "already_used"}, nil } // Check if token has infinite expiration (year 9999 or later) // Tokens with infinite expiration never expire infiniteExpirationThreshold := time.Date(3000, 1, 1, 0, 0, 0, 0, time.UTC) if expiresAt.Before(infiniteExpirationThreshold) { // Normal expiration check for tokens with finite expiration if time.Now().After(expiresAt) { return &TokenValidationResult{Valid: false, Reason: "expired"}, nil } } // If expiresAt is after the threshold, treat it as infinite (never expires) // Mark token as used _, err = s.db.Exec("UPDATE registration_tokens SET used = 1 WHERE id = ?", id) if err != nil { return nil, fmt.Errorf("failed to mark token as used: %w", err) } return &TokenValidationResult{Valid: true, Reason: "valid"}, nil } // ListRegistrationTokens lists all registration tokens func (s *Secrets) ListRegistrationTokens() ([]map[string]interface{}, error) { rows, err := s.db.Query( `SELECT id, token, expires_at, used, created_at, created_by FROM registration_tokens ORDER BY created_at DESC`, ) if err != nil { return nil, fmt.Errorf("failed to query tokens: %w", err) } defer rows.Close() var tokens []map[string]interface{} for rows.Next() { var id, createdBy sql.NullInt64 var token string var expiresAt, createdAt time.Time var used bool err := rows.Scan(&id, &token, &expiresAt, &used, &createdAt, &createdBy) if err != nil { continue } tokens = append(tokens, map[string]interface{}{ "id": id.Int64, "token": token, "expires_at": expiresAt, "used": used, "created_at": createdAt, "created_by": createdBy.Int64, }) } return tokens, nil } // RevokeRegistrationToken revokes a registration token func (s *Secrets) RevokeRegistrationToken(tokenID int64) error { _, err := s.db.Exec("UPDATE registration_tokens SET used = 1 WHERE id = ?", tokenID) return err } // GenerateRunnerSecret generates a unique secret for a runner func (s *Secrets) GenerateRunnerSecret() (string, error) { return generateSecret(32) } // SignRequest signs a request with the given secret func SignRequest(method, path, body, secret string, timestamp time.Time) string { message := fmt.Sprintf("%s\n%s\n%s\n%d", method, path, body, timestamp.Unix()) h := hmac.New(sha256.New, []byte(secret)) h.Write([]byte(message)) return hex.EncodeToString(h.Sum(nil)) } // VerifyRequest verifies a signed request func VerifyRequest(r *http.Request, secret string, maxAge time.Duration) (bool, error) { signature := r.Header.Get("X-Runner-Signature") if signature == "" { return false, fmt.Errorf("missing signature") } timestampStr := r.Header.Get("X-Runner-Timestamp") if timestampStr == "" { return false, fmt.Errorf("missing timestamp") } var timestampUnix int64 _, err := fmt.Sscanf(timestampStr, "%d", ×tampUnix) if err != nil { return false, fmt.Errorf("invalid timestamp: %w", err) } timestamp := time.Unix(timestampUnix, 0) // Check timestamp is not too old if time.Since(timestamp) > maxAge { return false, fmt.Errorf("request too old") } // Check timestamp is not in the future (allow 1 minute clock skew) if timestamp.After(time.Now().Add(1 * time.Minute)) { return false, fmt.Errorf("timestamp in future") } // Read body bodyBytes, err := io.ReadAll(r.Body) if err != nil { return false, fmt.Errorf("failed to read body: %w", err) } // Restore body for handler r.Body = io.NopCloser(strings.NewReader(string(bodyBytes))) // Verify signature - use path without query parameters (query params are not part of signature) // The runner signs with the path including query params, but we verify with just the path // This is intentional - query params are for identification, not part of the signature path := r.URL.Path expectedSig := SignRequest(r.Method, path, string(bodyBytes), secret, timestamp) return hmac.Equal([]byte(signature), []byte(expectedSig)), nil } // GetRunnerSecret retrieves the runner secret for a runner ID func (s *Secrets) GetRunnerSecret(runnerID int64) (string, error) { var secret string err := s.db.QueryRow("SELECT runner_secret FROM runners WHERE id = ?", runnerID).Scan(&secret) if err == sql.ErrNoRows { return "", fmt.Errorf("runner not found") } if err != nil { return "", fmt.Errorf("failed to get runner secret: %w", err) } if secret == "" { return "", fmt.Errorf("runner not verified") } return secret, nil } // generateSecret generates a random secret of the given length func generateSecret(length int) (string, error) { bytes := make([]byte, length) if _, err := rand.Read(bytes); err != nil { return "", err } return hex.EncodeToString(bytes), nil }