initial commit
This commit is contained in:
244
internal/auth/secrets.go
Normal file
244
internal/auth/secrets.go
Normal file
@@ -0,0 +1,244 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Secrets handles secret and token management
|
||||
type Secrets struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewSecrets creates a new secrets manager
|
||||
func NewSecrets(db *sql.DB) (*Secrets, error) {
|
||||
s := &Secrets{db: db}
|
||||
|
||||
// Ensure manager secret exists
|
||||
if err := s.ensureManagerSecret(); err != nil {
|
||||
return nil, fmt.Errorf("failed to ensure manager secret: %w", err)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// ensureManagerSecret ensures a manager secret exists in the database
|
||||
func (s *Secrets) ensureManagerSecret() error {
|
||||
var count int
|
||||
err := s.db.QueryRow("SELECT COUNT(*) FROM manager_secrets").Scan(&count)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check manager secrets: %w", err)
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
// Generate new manager secret
|
||||
secret, err := generateSecret(32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate manager secret: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.db.Exec("INSERT INTO manager_secrets (secret) VALUES (?)", secret)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store manager secret: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetManagerSecret retrieves the current manager secret
|
||||
func (s *Secrets) GetManagerSecret() (string, error) {
|
||||
var secret string
|
||||
err := s.db.QueryRow("SELECT secret FROM manager_secrets ORDER BY created_at DESC LIMIT 1").Scan(&secret)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get manager secret: %w", err)
|
||||
}
|
||||
return secret, nil
|
||||
}
|
||||
|
||||
// GenerateRegistrationToken generates a new registration token
|
||||
func (s *Secrets) GenerateRegistrationToken(createdBy int64, expiresIn time.Duration) (string, error) {
|
||||
token, err := generateSecret(32)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate token: %w", err)
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(expiresIn)
|
||||
|
||||
_, err = s.db.Exec(
|
||||
"INSERT INTO registration_tokens (token, expires_at, created_by) VALUES (?, ?, ?)",
|
||||
token, expiresAt, createdBy,
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to store registration token: %w", err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// ValidateRegistrationToken validates a registration token
|
||||
func (s *Secrets) ValidateRegistrationToken(token string) (bool, error) {
|
||||
var used bool
|
||||
var expiresAt time.Time
|
||||
var id int64
|
||||
|
||||
err := s.db.QueryRow(
|
||||
"SELECT id, expires_at, used FROM registration_tokens WHERE token = ?",
|
||||
token,
|
||||
).Scan(&id, &expiresAt, &used)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to query token: %w", err)
|
||||
}
|
||||
|
||||
if used {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if time.Now().After(expiresAt) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Mark token as used
|
||||
_, err = s.db.Exec("UPDATE registration_tokens SET used = 1 WHERE id = ?", id)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to mark token as used: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// ListRegistrationTokens lists all registration tokens
|
||||
func (s *Secrets) ListRegistrationTokens() ([]map[string]interface{}, error) {
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, token, expires_at, used, created_at, created_by
|
||||
FROM registration_tokens
|
||||
ORDER BY created_at DESC`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query tokens: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tokens []map[string]interface{}
|
||||
for rows.Next() {
|
||||
var id, createdBy sql.NullInt64
|
||||
var token string
|
||||
var expiresAt, createdAt time.Time
|
||||
var used bool
|
||||
|
||||
err := rows.Scan(&id, &token, &expiresAt, &used, &createdAt, &createdBy)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
tokens = append(tokens, map[string]interface{}{
|
||||
"id": id.Int64,
|
||||
"token": token,
|
||||
"expires_at": expiresAt,
|
||||
"used": used,
|
||||
"created_at": createdAt,
|
||||
"created_by": createdBy.Int64,
|
||||
})
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// RevokeRegistrationToken revokes a registration token
|
||||
func (s *Secrets) RevokeRegistrationToken(tokenID int64) error {
|
||||
_, err := s.db.Exec("UPDATE registration_tokens SET used = 1 WHERE id = ?", tokenID)
|
||||
return err
|
||||
}
|
||||
|
||||
// GenerateRunnerSecret generates a unique secret for a runner
|
||||
func (s *Secrets) GenerateRunnerSecret() (string, error) {
|
||||
return generateSecret(32)
|
||||
}
|
||||
|
||||
// SignRequest signs a request with the given secret
|
||||
func SignRequest(method, path, body, secret string, timestamp time.Time) string {
|
||||
message := fmt.Sprintf("%s\n%s\n%s\n%d", method, path, body, timestamp.Unix())
|
||||
h := hmac.New(sha256.New, []byte(secret))
|
||||
h.Write([]byte(message))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// VerifyRequest verifies a signed request
|
||||
func VerifyRequest(r *http.Request, secret string, maxAge time.Duration) (bool, error) {
|
||||
signature := r.Header.Get("X-Runner-Signature")
|
||||
if signature == "" {
|
||||
return false, fmt.Errorf("missing signature")
|
||||
}
|
||||
|
||||
timestampStr := r.Header.Get("X-Runner-Timestamp")
|
||||
if timestampStr == "" {
|
||||
return false, fmt.Errorf("missing timestamp")
|
||||
}
|
||||
|
||||
var timestamp time.Time
|
||||
_, err := fmt.Sscanf(timestampStr, "%d", ×tamp)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("invalid timestamp: %w", err)
|
||||
}
|
||||
|
||||
// Check timestamp is not too old
|
||||
if time.Since(timestamp) > maxAge {
|
||||
return false, fmt.Errorf("request too old")
|
||||
}
|
||||
|
||||
// Check timestamp is not in the future (allow 1 minute clock skew)
|
||||
if timestamp.After(time.Now().Add(1 * time.Minute)) {
|
||||
return false, fmt.Errorf("timestamp in future")
|
||||
}
|
||||
|
||||
// Read body
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to read body: %w", err)
|
||||
}
|
||||
// Restore body for handler
|
||||
r.Body = io.NopCloser(strings.NewReader(string(bodyBytes)))
|
||||
|
||||
// Verify signature
|
||||
expectedSig := SignRequest(r.Method, r.URL.Path, string(bodyBytes), secret, timestamp)
|
||||
|
||||
return hmac.Equal([]byte(signature), []byte(expectedSig)), nil
|
||||
}
|
||||
|
||||
// GetRunnerSecret retrieves the runner secret for a runner ID
|
||||
func (s *Secrets) GetRunnerSecret(runnerID int64) (string, error) {
|
||||
var secret string
|
||||
err := s.db.QueryRow("SELECT runner_secret FROM runners WHERE id = ?", runnerID).Scan(&secret)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", fmt.Errorf("runner not found")
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get runner secret: %w", err)
|
||||
}
|
||||
if secret == "" {
|
||||
return "", fmt.Errorf("runner not verified")
|
||||
}
|
||||
return secret, nil
|
||||
}
|
||||
|
||||
// generateSecret generates a random secret of the given length
|
||||
func generateSecret(length int) (string, error) {
|
||||
bytes := make([]byte, length)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user