package auth import ( "crypto/hmac" "crypto/rand" "crypto/sha256" "database/sql" "encoding/hex" "fmt" "io" "net/http" "strings" "time" ) // Secrets handles secret and token management type Secrets struct { db *sql.DB } // NewSecrets creates a new secrets manager func NewSecrets(db *sql.DB) (*Secrets, error) { s := &Secrets{db: db} // Ensure manager secret exists if err := s.ensureManagerSecret(); err != nil { return nil, fmt.Errorf("failed to ensure manager secret: %w", err) } return s, nil } // ensureManagerSecret ensures a manager secret exists in the database func (s *Secrets) ensureManagerSecret() error { var count int err := s.db.QueryRow("SELECT COUNT(*) FROM manager_secrets").Scan(&count) if err != nil { return fmt.Errorf("failed to check manager secrets: %w", err) } if count == 0 { // Generate new manager secret secret, err := generateSecret(32) if err != nil { return fmt.Errorf("failed to generate manager secret: %w", err) } _, err = s.db.Exec("INSERT INTO manager_secrets (secret) VALUES (?)", secret) if err != nil { return fmt.Errorf("failed to store manager secret: %w", err) } } return nil } // GetManagerSecret retrieves the current manager secret func (s *Secrets) GetManagerSecret() (string, error) { var secret string err := s.db.QueryRow("SELECT secret FROM manager_secrets ORDER BY created_at DESC LIMIT 1").Scan(&secret) if err != nil { return "", fmt.Errorf("failed to get manager secret: %w", err) } return secret, nil } // GenerateRegistrationToken generates a new registration token func (s *Secrets) GenerateRegistrationToken(createdBy int64, expiresIn time.Duration) (string, error) { token, err := generateSecret(32) if err != nil { return "", fmt.Errorf("failed to generate token: %w", err) } expiresAt := time.Now().Add(expiresIn) _, err = s.db.Exec( "INSERT INTO registration_tokens (token, expires_at, created_by) VALUES (?, ?, ?)", token, expiresAt, createdBy, ) if err != nil { return "", fmt.Errorf("failed to store registration token: %w", err) } 
return token, nil } // ValidateRegistrationToken validates a registration token func (s *Secrets) ValidateRegistrationToken(token string) (bool, error) { var used bool var expiresAt time.Time var id int64 err := s.db.QueryRow( "SELECT id, expires_at, used FROM registration_tokens WHERE token = ?", token, ).Scan(&id, &expiresAt, &used) if err == sql.ErrNoRows { return false, nil } if err != nil { return false, fmt.Errorf("failed to query token: %w", err) } if used { return false, nil } if time.Now().After(expiresAt) { return false, nil } // Mark token as used _, err = s.db.Exec("UPDATE registration_tokens SET used = 1 WHERE id = ?", id) if err != nil { return false, fmt.Errorf("failed to mark token as used: %w", err) } return true, nil } // ListRegistrationTokens lists all registration tokens func (s *Secrets) ListRegistrationTokens() ([]map[string]interface{}, error) { rows, err := s.db.Query( `SELECT id, token, expires_at, used, created_at, created_by FROM registration_tokens ORDER BY created_at DESC`, ) if err != nil { return nil, fmt.Errorf("failed to query tokens: %w", err) } defer rows.Close() var tokens []map[string]interface{} for rows.Next() { var id, createdBy sql.NullInt64 var token string var expiresAt, createdAt time.Time var used bool err := rows.Scan(&id, &token, &expiresAt, &used, &createdAt, &createdBy) if err != nil { continue } tokens = append(tokens, map[string]interface{}{ "id": id.Int64, "token": token, "expires_at": expiresAt, "used": used, "created_at": createdAt, "created_by": createdBy.Int64, }) } return tokens, nil } // RevokeRegistrationToken revokes a registration token func (s *Secrets) RevokeRegistrationToken(tokenID int64) error { _, err := s.db.Exec("UPDATE registration_tokens SET used = 1 WHERE id = ?", tokenID) return err } // GenerateRunnerSecret generates a unique secret for a runner func (s *Secrets) GenerateRunnerSecret() (string, error) { return generateSecret(32) } // SignRequest signs a request with the given secret func 
SignRequest(method, path, body, secret string, timestamp time.Time) string { message := fmt.Sprintf("%s\n%s\n%s\n%d", method, path, body, timestamp.Unix()) h := hmac.New(sha256.New, []byte(secret)) h.Write([]byte(message)) return hex.EncodeToString(h.Sum(nil)) } // VerifyRequest verifies a signed request func VerifyRequest(r *http.Request, secret string, maxAge time.Duration) (bool, error) { signature := r.Header.Get("X-Runner-Signature") if signature == "" { return false, fmt.Errorf("missing signature") } timestampStr := r.Header.Get("X-Runner-Timestamp") if timestampStr == "" { return false, fmt.Errorf("missing timestamp") } var timestamp time.Time _, err := fmt.Sscanf(timestampStr, "%d", ×tamp) if err != nil { return false, fmt.Errorf("invalid timestamp: %w", err) } // Check timestamp is not too old if time.Since(timestamp) > maxAge { return false, fmt.Errorf("request too old") } // Check timestamp is not in the future (allow 1 minute clock skew) if timestamp.After(time.Now().Add(1 * time.Minute)) { return false, fmt.Errorf("timestamp in future") } // Read body bodyBytes, err := io.ReadAll(r.Body) if err != nil { return false, fmt.Errorf("failed to read body: %w", err) } // Restore body for handler r.Body = io.NopCloser(strings.NewReader(string(bodyBytes))) // Verify signature expectedSig := SignRequest(r.Method, r.URL.Path, string(bodyBytes), secret, timestamp) return hmac.Equal([]byte(signature), []byte(expectedSig)), nil } // GetRunnerSecret retrieves the runner secret for a runner ID func (s *Secrets) GetRunnerSecret(runnerID int64) (string, error) { var secret string err := s.db.QueryRow("SELECT runner_secret FROM runners WHERE id = ?", runnerID).Scan(&secret) if err == sql.ErrNoRows { return "", fmt.Errorf("runner not found") } if err != nil { return "", fmt.Errorf("failed to get runner secret: %w", err) } if secret == "" { return "", fmt.Errorf("runner not verified") } return secret, nil } // generateSecret generates a random secret of the given length 
func generateSecret(length int) (string, error) {
	// Draw the raw entropy first. Hex encoding doubles the size, so the
	// caller receives a string of 2*length lowercase hex characters.
	raw := make([]byte, length)
	_, err := rand.Read(raw)
	if err != nil {
		return "", err
	}
	secret := hex.EncodeToString(raw)
	return secret, nil
}