something
This commit is contained in:
@@ -5,10 +5,13 @@ import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"jiggablend/internal/config"
|
||||
"jiggablend/internal/database"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -17,12 +20,31 @@ import (
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
// Context key types to avoid collisions (typed keys are safer than string keys)
|
||||
type contextKey int
|
||||
|
||||
const (
|
||||
contextKeyUserID contextKey = iota
|
||||
contextKeyUserEmail
|
||||
contextKeyUserName
|
||||
contextKeyIsAdmin
|
||||
)
|
||||
|
||||
// Configuration constants
|
||||
const (
|
||||
SessionDuration = 24 * time.Hour
|
||||
SessionCleanupInterval = 1 * time.Hour
|
||||
)
|
||||
|
||||
// Auth handles authentication
|
||||
type Auth struct {
|
||||
db *sql.DB
|
||||
db *database.DB
|
||||
cfg *config.Config
|
||||
googleConfig *oauth2.Config
|
||||
discordConfig *oauth2.Config
|
||||
sessionStore map[string]*Session
|
||||
sessionCache map[string]*Session // In-memory cache for performance
|
||||
cacheMu sync.RWMutex
|
||||
stopCleanup chan struct{}
|
||||
}
|
||||
|
||||
// Session represents a user session
|
||||
@@ -35,41 +57,53 @@ type Session struct {
|
||||
}
|
||||
|
||||
// NewAuth creates a new auth instance
|
||||
func NewAuth(db *sql.DB) (*Auth, error) {
|
||||
func NewAuth(db *database.DB, cfg *config.Config) (*Auth, error) {
|
||||
auth := &Auth{
|
||||
db: db,
|
||||
sessionStore: make(map[string]*Session),
|
||||
cfg: cfg,
|
||||
sessionCache: make(map[string]*Session),
|
||||
stopCleanup: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Initialize Google OAuth
|
||||
googleClientID := os.Getenv("GOOGLE_CLIENT_ID")
|
||||
googleClientSecret := os.Getenv("GOOGLE_CLIENT_SECRET")
|
||||
// Initialize Google OAuth from database config
|
||||
googleClientID := cfg.GoogleClientID()
|
||||
googleClientSecret := cfg.GoogleClientSecret()
|
||||
if googleClientID != "" && googleClientSecret != "" {
|
||||
auth.googleConfig = &oauth2.Config{
|
||||
ClientID: googleClientID,
|
||||
ClientSecret: googleClientSecret,
|
||||
RedirectURL: os.Getenv("GOOGLE_REDIRECT_URL"),
|
||||
RedirectURL: cfg.GoogleRedirectURL(),
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
Endpoint: google.Endpoint,
|
||||
}
|
||||
log.Printf("Google OAuth configured")
|
||||
}
|
||||
|
||||
// Initialize Discord OAuth
|
||||
discordClientID := os.Getenv("DISCORD_CLIENT_ID")
|
||||
discordClientSecret := os.Getenv("DISCORD_CLIENT_SECRET")
|
||||
// Initialize Discord OAuth from database config
|
||||
discordClientID := cfg.DiscordClientID()
|
||||
discordClientSecret := cfg.DiscordClientSecret()
|
||||
if discordClientID != "" && discordClientSecret != "" {
|
||||
auth.discordConfig = &oauth2.Config{
|
||||
ClientID: discordClientID,
|
||||
ClientSecret: discordClientSecret,
|
||||
RedirectURL: os.Getenv("DISCORD_REDIRECT_URL"),
|
||||
RedirectURL: cfg.DiscordRedirectURL(),
|
||||
Scopes: []string{"identify", "email"},
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: "https://discord.com/api/oauth2/authorize",
|
||||
TokenURL: "https://discord.com/api/oauth2/token",
|
||||
},
|
||||
}
|
||||
log.Printf("Discord OAuth configured")
|
||||
}
|
||||
|
||||
// Load existing sessions from database into cache
|
||||
if err := auth.loadSessionsFromDB(); err != nil {
|
||||
log.Printf("Warning: Failed to load sessions from database: %v", err)
|
||||
}
|
||||
|
||||
// Start background cleanup goroutine
|
||||
go auth.cleanupExpiredSessions()
|
||||
|
||||
// Initialize admin settings on startup to ensure they persist between boots
|
||||
if err := auth.initializeSettings(); err != nil {
|
||||
log.Printf("Warning: Failed to initialize admin settings: %v", err)
|
||||
@@ -85,19 +119,119 @@ func NewAuth(db *sql.DB) (*Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// Close stops background goroutines
|
||||
func (a *Auth) Close() {
|
||||
close(a.stopCleanup)
|
||||
}
|
||||
|
||||
// loadSessionsFromDB loads all valid sessions from database into cache
|
||||
func (a *Auth) loadSessionsFromDB() error {
|
||||
var sessions []struct {
|
||||
sessionID string
|
||||
session Session
|
||||
}
|
||||
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
rows, err := conn.Query(
|
||||
`SELECT session_id, user_id, email, name, is_admin, expires_at
|
||||
FROM sessions WHERE expires_at > CURRENT_TIMESTAMP`,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query sessions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var s struct {
|
||||
sessionID string
|
||||
session Session
|
||||
}
|
||||
err := rows.Scan(&s.sessionID, &s.session.UserID, &s.session.Email, &s.session.Name, &s.session.IsAdmin, &s.session.ExpiresAt)
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to scan session row: %v", err)
|
||||
continue
|
||||
}
|
||||
sessions = append(sessions, s)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
a.cacheMu.Lock()
|
||||
defer a.cacheMu.Unlock()
|
||||
|
||||
for _, s := range sessions {
|
||||
a.sessionCache[s.sessionID] = &s.session
|
||||
}
|
||||
|
||||
if len(sessions) > 0 {
|
||||
log.Printf("Loaded %d active sessions from database", len(sessions))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupExpiredSessions periodically removes expired sessions from database and cache
|
||||
func (a *Auth) cleanupExpiredSessions() {
|
||||
ticker := time.NewTicker(SessionCleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// Delete expired sessions from database
|
||||
var deleted int64
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
result, err := conn.Exec(`DELETE FROM sessions WHERE expires_at < CURRENT_TIMESTAMP`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
deleted, _ = result.RowsAffected()
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to cleanup expired sessions: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Clean up cache
|
||||
a.cacheMu.Lock()
|
||||
now := time.Now()
|
||||
for sessionID, session := range a.sessionCache {
|
||||
if now.After(session.ExpiresAt) {
|
||||
delete(a.sessionCache, sessionID)
|
||||
}
|
||||
}
|
||||
a.cacheMu.Unlock()
|
||||
|
||||
if deleted > 0 {
|
||||
log.Printf("Cleaned up %d expired sessions", deleted)
|
||||
}
|
||||
case <-a.stopCleanup:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// initializeSettings ensures all admin settings are initialized with defaults if they don't exist
|
||||
func (a *Auth) initializeSettings() error {
|
||||
// Initialize registration_enabled setting (default: true) if it doesn't exist
|
||||
var settingCount int
|
||||
err := a.db.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount)
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check registration_enabled setting: %w", err)
|
||||
}
|
||||
if settingCount == 0 {
|
||||
_, err = a.db.Exec(
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec(
|
||||
`INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)`,
|
||||
"registration_enabled", "true",
|
||||
)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize registration_enabled setting: %w", err)
|
||||
}
|
||||
@@ -118,7 +252,9 @@ func (a *Auth) initializeTestUser() error {
|
||||
|
||||
// Check if user already exists
|
||||
var exists bool
|
||||
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND oauth_provider = 'local')", testEmail).Scan(&exists)
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND oauth_provider = 'local')", testEmail).Scan(&exists)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if test user exists: %w", err)
|
||||
}
|
||||
@@ -137,7 +273,12 @@ func (a *Auth) initializeTestUser() error {
|
||||
|
||||
// Check if this is the first user (make them admin)
|
||||
var userCount int
|
||||
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check user count: %w", err)
|
||||
}
|
||||
isAdmin := userCount == 0
|
||||
|
||||
// Create test user (use email as name if no name is provided)
|
||||
@@ -147,10 +288,13 @@ func (a *Auth) initializeTestUser() error {
|
||||
}
|
||||
|
||||
// Create test user
|
||||
_, err = a.db.Exec(
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec(
|
||||
"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?)",
|
||||
testEmail, testName, testEmail, string(hashedPassword), isAdmin,
|
||||
)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create test user: %w", err)
|
||||
}
|
||||
@@ -242,7 +386,9 @@ func (a *Auth) DiscordCallback(ctx context.Context, code string) (*Session, erro
|
||||
// IsRegistrationEnabled checks if new user registration is enabled
|
||||
func (a *Auth) IsRegistrationEnabled() (bool, error) {
|
||||
var value string
|
||||
err := a.db.QueryRow("SELECT value FROM settings WHERE key = ?", "registration_enabled").Scan(&value)
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT value FROM settings WHERE key = ?", "registration_enabled").Scan(&value)
|
||||
})
|
||||
if err == sql.ErrNoRows {
|
||||
// Default to enabled if setting doesn't exist
|
||||
return true, nil
|
||||
@@ -262,29 +408,31 @@ func (a *Auth) SetRegistrationEnabled(enabled bool) error {
|
||||
|
||||
// Check if setting exists
|
||||
var exists bool
|
||||
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM settings WHERE key = ?)", "registration_enabled").Scan(&exists)
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM settings WHERE key = ?)", "registration_enabled").Scan(&exists)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if setting exists: %w", err)
|
||||
}
|
||||
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
if exists {
|
||||
// Update existing setting
|
||||
_, err = a.db.Exec(
|
||||
_, err = conn.Exec(
|
||||
"UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?",
|
||||
value, "registration_enabled",
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update setting: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Insert new setting
|
||||
_, err = a.db.Exec(
|
||||
_, err = conn.Exec(
|
||||
"INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)",
|
||||
"registration_enabled", value,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert setting: %w", err)
|
||||
}
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set registration_enabled: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -299,17 +447,21 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session,
|
||||
var dbProvider, dbOAuthID string
|
||||
|
||||
// First, try to find by provider + oauth_id
|
||||
err := a.db.QueryRow(
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow(
|
||||
"SELECT id, email, name, is_admin, oauth_provider, oauth_id FROM users WHERE oauth_provider = ? AND oauth_id = ?",
|
||||
provider, oauthID,
|
||||
).Scan(&userID, &dbEmail, &dbName, &isAdmin, &dbProvider, &dbOAuthID)
|
||||
})
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
// Not found by provider+oauth_id, check by email for account linking
|
||||
err = a.db.QueryRow(
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow(
|
||||
"SELECT id, email, name, is_admin, oauth_provider, oauth_id FROM users WHERE email = ?",
|
||||
email,
|
||||
).Scan(&userID, &dbEmail, &dbName, &isAdmin, &dbProvider, &dbOAuthID)
|
||||
})
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
// User doesn't exist, check if registration is enabled
|
||||
@@ -323,14 +475,26 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session,
|
||||
|
||||
// Check if this is the first user
|
||||
var userCount int
|
||||
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check user count: %w", err)
|
||||
}
|
||||
isAdmin = userCount == 0
|
||||
|
||||
// Create new user
|
||||
err = a.db.QueryRow(
|
||||
"INSERT INTO users (email, name, oauth_provider, oauth_id, is_admin) VALUES (?, ?, ?, ?, ?) RETURNING id",
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
result, err := conn.Exec(
|
||||
"INSERT INTO users (email, name, oauth_provider, oauth_id, is_admin) VALUES (?, ?, ?, ?, ?)",
|
||||
email, name, provider, oauthID, isAdmin,
|
||||
).Scan(&userID)
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userID, err = result.LastInsertId()
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
@@ -339,10 +503,13 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session,
|
||||
} else {
|
||||
// User exists with same email but different provider - link accounts by updating provider info
|
||||
// This allows the user to log in with any provider that has the same email
|
||||
_, err = a.db.Exec(
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
_, err = conn.Exec(
|
||||
"UPDATE users SET oauth_provider = ?, oauth_id = ?, name = ? WHERE id = ?",
|
||||
provider, oauthID, name, userID,
|
||||
)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to link account: %w", err)
|
||||
}
|
||||
@@ -353,10 +520,13 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session,
|
||||
} else {
|
||||
// User found by provider+oauth_id, update info if changed
|
||||
if dbEmail != email || dbName != name {
|
||||
_, err = a.db.Exec(
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
_, err = conn.Exec(
|
||||
"UPDATE users SET email = ?, name = ? WHERE id = ?",
|
||||
email, name, userID,
|
||||
)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update user: %w", err)
|
||||
}
|
||||
@@ -368,41 +538,134 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session,
|
||||
Email: email,
|
||||
Name: name,
|
||||
IsAdmin: isAdmin,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
ExpiresAt: time.Now().Add(SessionDuration),
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// CreateSession creates a new session and returns a session ID
|
||||
// Sessions are persisted to database and cached in memory
|
||||
func (a *Auth) CreateSession(session *Session) string {
|
||||
sessionID := uuid.New().String()
|
||||
a.sessionStore[sessionID] = session
|
||||
|
||||
// Store in database first
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec(
|
||||
`INSERT INTO sessions (session_id, user_id, email, name, is_admin, expires_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`,
|
||||
sessionID, session.UserID, session.Email, session.Name, session.IsAdmin, session.ExpiresAt,
|
||||
)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to persist session to database: %v", err)
|
||||
// Continue anyway - session will work from cache but won't survive restart
|
||||
}
|
||||
|
||||
// Store in cache
|
||||
a.cacheMu.Lock()
|
||||
a.sessionCache[sessionID] = session
|
||||
a.cacheMu.Unlock()
|
||||
|
||||
return sessionID
|
||||
}
|
||||
|
||||
// GetSession retrieves a session by ID
|
||||
// First checks cache, then database if not found
|
||||
func (a *Auth) GetSession(sessionID string) (*Session, bool) {
|
||||
session, ok := a.sessionStore[sessionID]
|
||||
if !ok {
|
||||
// Check cache first
|
||||
a.cacheMu.RLock()
|
||||
session, ok := a.sessionCache[sessionID]
|
||||
a.cacheMu.RUnlock()
|
||||
|
||||
if ok {
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
a.DeleteSession(sessionID)
|
||||
return nil, false
|
||||
}
|
||||
// Refresh admin status from database
|
||||
var isAdmin bool
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT is_admin FROM users WHERE id = ?", session.UserID).Scan(&isAdmin)
|
||||
})
|
||||
if err == nil {
|
||||
session.IsAdmin = isAdmin
|
||||
}
|
||||
return session, true
|
||||
}
|
||||
|
||||
// Not in cache, check database
|
||||
session = &Session{}
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow(
|
||||
`SELECT user_id, email, name, is_admin, expires_at
|
||||
FROM sessions WHERE session_id = ?`,
|
||||
sessionID,
|
||||
).Scan(&session.UserID, &session.Email, &session.Name, &session.IsAdmin, &session.ExpiresAt)
|
||||
})
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, false
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to query session from database: %v", err)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
delete(a.sessionStore, sessionID)
|
||||
a.DeleteSession(sessionID)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Refresh admin status from database
|
||||
var isAdmin bool
|
||||
err := a.db.QueryRow("SELECT is_admin FROM users WHERE id = ?", session.UserID).Scan(&isAdmin)
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT is_admin FROM users WHERE id = ?", session.UserID).Scan(&isAdmin)
|
||||
})
|
||||
if err == nil {
|
||||
session.IsAdmin = isAdmin
|
||||
}
|
||||
|
||||
// Add to cache
|
||||
a.cacheMu.Lock()
|
||||
a.sessionCache[sessionID] = session
|
||||
a.cacheMu.Unlock()
|
||||
|
||||
return session, true
|
||||
}
|
||||
|
||||
// DeleteSession deletes a session
|
||||
// DeleteSession deletes a session from both cache and database
|
||||
func (a *Auth) DeleteSession(sessionID string) {
|
||||
delete(a.sessionStore, sessionID)
|
||||
// Delete from cache
|
||||
a.cacheMu.Lock()
|
||||
delete(a.sessionCache, sessionID)
|
||||
a.cacheMu.Unlock()
|
||||
|
||||
// Delete from database
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec("DELETE FROM sessions WHERE session_id = ?", sessionID)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to delete session from database: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// IsProductionMode returns true if running in production mode
|
||||
// This is a package-level function that checks the environment variable
|
||||
// For config-based checks, use Config.IsProductionMode()
|
||||
func IsProductionMode() bool {
|
||||
// Check environment variable first for backwards compatibility
|
||||
if os.Getenv("PRODUCTION") == "true" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsProductionModeFromConfig returns true if production mode is enabled in config
|
||||
func (a *Auth) IsProductionModeFromConfig() bool {
|
||||
return a.cfg.IsProductionMode()
|
||||
}
|
||||
|
||||
// Middleware creates an authentication middleware
|
||||
@@ -426,24 +689,24 @@ func (a *Auth) Middleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// Add user info to request context
|
||||
ctx := context.WithValue(r.Context(), "user_id", session.UserID)
|
||||
ctx = context.WithValue(ctx, "user_email", session.Email)
|
||||
ctx = context.WithValue(ctx, "user_name", session.Name)
|
||||
ctx = context.WithValue(ctx, "is_admin", session.IsAdmin)
|
||||
// Add user info to request context using typed keys
|
||||
ctx := context.WithValue(r.Context(), contextKeyUserID, session.UserID)
|
||||
ctx = context.WithValue(ctx, contextKeyUserEmail, session.Email)
|
||||
ctx = context.WithValue(ctx, contextKeyUserName, session.Name)
|
||||
ctx = context.WithValue(ctx, contextKeyIsAdmin, session.IsAdmin)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserID gets the user ID from context
|
||||
func GetUserID(ctx context.Context) (int64, bool) {
|
||||
userID, ok := ctx.Value("user_id").(int64)
|
||||
userID, ok := ctx.Value(contextKeyUserID).(int64)
|
||||
return userID, ok
|
||||
}
|
||||
|
||||
// IsAdmin checks if the user in context is an admin
|
||||
func IsAdmin(ctx context.Context) bool {
|
||||
isAdmin, ok := ctx.Value("is_admin").(bool)
|
||||
isAdmin, ok := ctx.Value(contextKeyIsAdmin).(bool)
|
||||
return ok && isAdmin
|
||||
}
|
||||
|
||||
@@ -478,19 +741,19 @@ func (a *Auth) AdminMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// Add user info to request context
|
||||
ctx := context.WithValue(r.Context(), "user_id", session.UserID)
|
||||
ctx = context.WithValue(ctx, "user_email", session.Email)
|
||||
ctx = context.WithValue(ctx, "user_name", session.Name)
|
||||
ctx = context.WithValue(ctx, "is_admin", session.IsAdmin)
|
||||
// Add user info to request context using typed keys
|
||||
ctx := context.WithValue(r.Context(), contextKeyUserID, session.UserID)
|
||||
ctx = context.WithValue(ctx, contextKeyUserEmail, session.Email)
|
||||
ctx = context.WithValue(ctx, contextKeyUserName, session.Name)
|
||||
ctx = context.WithValue(ctx, contextKeyIsAdmin, session.IsAdmin)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// IsLocalLoginEnabled returns whether local login is enabled
|
||||
// Local login is enabled when ENABLE_LOCAL_AUTH environment variable is set to "true"
|
||||
// Checks database config first, falls back to environment variable
|
||||
func (a *Auth) IsLocalLoginEnabled() bool {
|
||||
return os.Getenv("ENABLE_LOCAL_AUTH") == "true"
|
||||
return a.cfg.IsLocalAuthEnabled()
|
||||
}
|
||||
|
||||
// IsGoogleOAuthConfigured returns whether Google OAuth is configured
|
||||
@@ -511,10 +774,12 @@ func (a *Auth) LocalLogin(username, password string) (*Session, error) {
|
||||
var dbEmail, dbName, passwordHash string
|
||||
var isAdmin bool
|
||||
|
||||
err := a.db.QueryRow(
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow(
|
||||
"SELECT id, email, name, password_hash, is_admin FROM users WHERE email = ? AND oauth_provider = 'local'",
|
||||
email,
|
||||
).Scan(&userID, &dbEmail, &dbName, &passwordHash, &isAdmin)
|
||||
})
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("invalid credentials")
|
||||
@@ -539,7 +804,7 @@ func (a *Auth) LocalLogin(username, password string) (*Session, error) {
|
||||
Email: dbEmail,
|
||||
Name: dbName,
|
||||
IsAdmin: isAdmin,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
ExpiresAt: time.Now().Add(SessionDuration),
|
||||
}
|
||||
|
||||
return session, nil
|
||||
@@ -558,7 +823,9 @@ func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error)
|
||||
|
||||
// Check if user already exists
|
||||
var exists bool
|
||||
err = a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ?)", email).Scan(&exists)
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ?)", email).Scan(&exists)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if user exists: %w", err)
|
||||
}
|
||||
@@ -574,15 +841,27 @@ func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error)
|
||||
|
||||
// Check if this is the first user (make them admin)
|
||||
var userCount int
|
||||
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check user count: %w", err)
|
||||
}
|
||||
isAdmin := userCount == 0
|
||||
|
||||
// Create user
|
||||
var userID int64
|
||||
err = a.db.QueryRow(
|
||||
"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?) RETURNING id",
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
result, err := conn.Exec(
|
||||
"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?)",
|
||||
email, name, email, string(hashedPassword), isAdmin,
|
||||
).Scan(&userID)
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userID, err = result.LastInsertId()
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
@@ -593,7 +872,7 @@ func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error)
|
||||
Email: email,
|
||||
Name: name,
|
||||
IsAdmin: isAdmin,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
ExpiresAt: time.Now().Add(SessionDuration),
|
||||
}
|
||||
|
||||
return session, nil
|
||||
@@ -603,7 +882,9 @@ func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error)
|
||||
func (a *Auth) ChangePassword(userID int64, oldPassword, newPassword string) error {
|
||||
// Get current password hash
|
||||
var passwordHash string
|
||||
err := a.db.QueryRow("SELECT password_hash FROM users WHERE id = ? AND oauth_provider = 'local'", userID).Scan(&passwordHash)
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT password_hash FROM users WHERE id = ? AND oauth_provider = 'local'", userID).Scan(&passwordHash)
|
||||
})
|
||||
if err == sql.ErrNoRows {
|
||||
return fmt.Errorf("user not found or not a local user")
|
||||
}
|
||||
@@ -628,7 +909,10 @@ func (a *Auth) ChangePassword(userID int64, oldPassword, newPassword string) err
|
||||
}
|
||||
|
||||
// Update password
|
||||
_, err = a.db.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), userID)
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), userID)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update password: %w", err)
|
||||
}
|
||||
@@ -640,7 +924,9 @@ func (a *Auth) ChangePassword(userID int64, oldPassword, newPassword string) err
|
||||
func (a *Auth) AdminChangePassword(targetUserID int64, newPassword string) error {
|
||||
// Verify user exists and is a local user
|
||||
var exists bool
|
||||
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ? AND oauth_provider = 'local')", targetUserID).Scan(&exists)
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ? AND oauth_provider = 'local')", targetUserID).Scan(&exists)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if user exists: %w", err)
|
||||
}
|
||||
@@ -655,7 +941,10 @@ func (a *Auth) AdminChangePassword(targetUserID int64, newPassword string) error
|
||||
}
|
||||
|
||||
// Update password
|
||||
_, err = a.db.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), targetUserID)
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), targetUserID)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update password: %w", err)
|
||||
}
|
||||
@@ -666,7 +955,9 @@ func (a *Auth) AdminChangePassword(targetUserID int64, newPassword string) error
|
||||
// GetFirstUserID returns the ID of the first user (user with the lowest ID)
|
||||
func (a *Auth) GetFirstUserID() (int64, error) {
|
||||
var firstUserID int64
|
||||
err := a.db.QueryRow("SELECT id FROM users ORDER BY id ASC LIMIT 1").Scan(&firstUserID)
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT id FROM users ORDER BY id ASC LIMIT 1").Scan(&firstUserID)
|
||||
})
|
||||
if err == sql.ErrNoRows {
|
||||
return 0, fmt.Errorf("no users found")
|
||||
}
|
||||
@@ -680,7 +971,9 @@ func (a *Auth) GetFirstUserID() (int64, error) {
|
||||
func (a *Auth) SetUserAdminStatus(targetUserID int64, isAdmin bool) error {
|
||||
// Verify user exists
|
||||
var exists bool
|
||||
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", targetUserID).Scan(&exists)
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", targetUserID).Scan(&exists)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if user exists: %w", err)
|
||||
}
|
||||
@@ -698,7 +991,10 @@ func (a *Auth) SetUserAdminStatus(targetUserID int64, isAdmin bool) error {
|
||||
}
|
||||
|
||||
// Update admin status
|
||||
_, err = a.db.Exec("UPDATE users SET is_admin = ? WHERE id = ?", isAdmin, targetUserID)
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec("UPDATE users SET is_admin = ? WHERE id = ?", isAdmin, targetUserID)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update admin status: %w", err)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,8 @@ import (
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"os"
|
||||
"jiggablend/internal/config"
|
||||
"jiggablend/internal/database"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -14,21 +15,14 @@ import (
|
||||
|
||||
// Secrets handles API key management
|
||||
type Secrets struct {
|
||||
db *sql.DB
|
||||
db *database.DB
|
||||
cfg *config.Config
|
||||
RegistrationMu sync.Mutex // Protects concurrent runner registrations
|
||||
fixedAPIKey string // Fixed API key from environment variable (optional)
|
||||
}
|
||||
|
||||
// NewSecrets creates a new secrets manager
|
||||
func NewSecrets(db *sql.DB) (*Secrets, error) {
|
||||
s := &Secrets{db: db}
|
||||
|
||||
// Check for fixed API key from environment
|
||||
if fixedKey := os.Getenv("FIXED_API_KEY"); fixedKey != "" {
|
||||
s.fixedAPIKey = fixedKey
|
||||
}
|
||||
|
||||
return s, nil
|
||||
func NewSecrets(db *database.DB, cfg *config.Config) (*Secrets, error) {
|
||||
return &Secrets{db: db, cfg: cfg}, nil
|
||||
}
|
||||
|
||||
// APIKeyInfo represents information about an API key
|
||||
@@ -61,25 +55,34 @@ func (s *Secrets) GenerateRunnerAPIKey(createdBy int64, name, description string
|
||||
keyHash := sha256.Sum256([]byte(key))
|
||||
keyHashStr := hex.EncodeToString(keyHash[:])
|
||||
|
||||
_, err = s.db.Exec(
|
||||
var keyInfo APIKeyInfo
|
||||
err = s.db.With(func(conn *sql.DB) error {
|
||||
result, err := conn.Exec(
|
||||
`INSERT INTO runner_api_keys (key_prefix, key_hash, name, description, scope, is_active, created_by)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||
keyPrefix, keyHashStr, name, description, scope, true, createdBy,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to store API key: %w", err)
|
||||
return fmt.Errorf("failed to store API key: %w", err)
|
||||
}
|
||||
|
||||
keyID, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get inserted key ID: %w", err)
|
||||
}
|
||||
|
||||
// Get the inserted key info
|
||||
var keyInfo APIKeyInfo
|
||||
err = s.db.QueryRow(
|
||||
err = conn.QueryRow(
|
||||
`SELECT id, name, description, scope, is_active, created_at, created_by
|
||||
FROM runner_api_keys WHERE key_prefix = ?`,
|
||||
keyPrefix,
|
||||
FROM runner_api_keys WHERE id = ?`,
|
||||
keyID,
|
||||
).Scan(&keyInfo.ID, &keyInfo.Name, &keyInfo.Description, &keyInfo.Scope, &keyInfo.IsActive, &keyInfo.CreatedAt, &keyInfo.CreatedBy)
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve created API key: %w", err)
|
||||
return nil, fmt.Errorf("failed to create API key: %w", err)
|
||||
}
|
||||
|
||||
keyInfo.Key = key
|
||||
@@ -91,18 +94,25 @@ func (s *Secrets) generateAPIKey() (string, error) {
|
||||
// Generate random suffix
|
||||
randomBytes := make([]byte, 16)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
randomStr := hex.EncodeToString(randomBytes)
|
||||
|
||||
// Generate a unique prefix (jk_r followed by 1 random digit)
|
||||
prefixDigit := make([]byte, 1)
|
||||
if _, err := rand.Read(prefixDigit); err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("failed to generate prefix digit: %w", err)
|
||||
}
|
||||
|
||||
prefix := fmt.Sprintf("jk_r%d", prefixDigit[0]%10)
|
||||
return fmt.Sprintf("%s_%s", prefix, randomStr), nil
|
||||
key := fmt.Sprintf("%s_%s", prefix, randomStr)
|
||||
|
||||
// Validate generated key format
|
||||
if !strings.HasPrefix(key, "jk_r") {
|
||||
return "", fmt.Errorf("generated invalid API key format: %s", key)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// ValidateRunnerAPIKey validates an API key and returns the key ID and scope if valid
|
||||
@@ -111,8 +121,9 @@ func (s *Secrets) ValidateRunnerAPIKey(apiKey string) (int64, string, error) {
|
||||
return 0, "", fmt.Errorf("API key is required")
|
||||
}
|
||||
|
||||
// Check fixed API key first (for testing/development)
|
||||
if s.fixedAPIKey != "" && apiKey == s.fixedAPIKey {
|
||||
// Check fixed API key first (from database config)
|
||||
fixedKey := s.cfg.FixedAPIKey()
|
||||
if fixedKey != "" && apiKey == fixedKey {
|
||||
// Return a special ID for fixed API key (doesn't exist in database)
|
||||
return -1, "manager", nil
|
||||
}
|
||||
@@ -137,42 +148,50 @@ func (s *Secrets) ValidateRunnerAPIKey(apiKey string) (int64, string, error) {
|
||||
var scope string
|
||||
var isActive bool
|
||||
|
||||
err := s.db.QueryRow(
|
||||
err := s.db.With(func(conn *sql.DB) error {
|
||||
err := conn.QueryRow(
|
||||
`SELECT id, scope, is_active FROM runner_api_keys
|
||||
WHERE key_prefix = ? AND key_hash = ?`,
|
||||
keyPrefix, keyHashStr,
|
||||
).Scan(&keyID, &scope, &isActive)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return 0, "", fmt.Errorf("API key not found or invalid - please check that the key is correct and active")
|
||||
return fmt.Errorf("API key not found or invalid - please check that the key is correct and active")
|
||||
}
|
||||
if err != nil {
|
||||
return 0, "", fmt.Errorf("failed to validate API key: %w", err)
|
||||
return fmt.Errorf("failed to validate API key: %w", err)
|
||||
}
|
||||
|
||||
if !isActive {
|
||||
return 0, "", fmt.Errorf("API key is inactive")
|
||||
return fmt.Errorf("API key is inactive")
|
||||
}
|
||||
|
||||
// Update last_used_at (don't fail if this update fails)
|
||||
s.db.Exec(`UPDATE runner_api_keys SET last_used_at = ? WHERE id = ?`, time.Now(), keyID)
|
||||
conn.Exec(`UPDATE runner_api_keys SET last_used_at = ? WHERE id = ?`, time.Now(), keyID)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
|
||||
return keyID, scope, nil
|
||||
}
|
||||
|
||||
// ListRunnerAPIKeys lists all runner API keys
|
||||
func (s *Secrets) ListRunnerAPIKeys() ([]APIKeyInfo, error) {
|
||||
rows, err := s.db.Query(
|
||||
var keys []APIKeyInfo
|
||||
err := s.db.With(func(conn *sql.DB) error {
|
||||
rows, err := conn.Query(
|
||||
`SELECT id, key_prefix, name, description, scope, is_active, created_at, created_by
|
||||
FROM runner_api_keys
|
||||
ORDER BY created_at DESC`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query API keys: %w", err)
|
||||
return fmt.Errorf("failed to query API keys: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var keys []APIKeyInfo
|
||||
for rows.Next() {
|
||||
var key APIKeyInfo
|
||||
var description sql.NullString
|
||||
@@ -188,20 +207,28 @@ func (s *Secrets) ListRunnerAPIKeys() ([]APIKeyInfo, error) {
|
||||
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// RevokeRunnerAPIKey revokes (deactivates) a runner API key
|
||||
func (s *Secrets) RevokeRunnerAPIKey(keyID int64) error {
|
||||
_, err := s.db.Exec("UPDATE runner_api_keys SET is_active = false WHERE id = ?", keyID)
|
||||
return s.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec("UPDATE runner_api_keys SET is_active = false WHERE id = ?", keyID)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteRunnerAPIKey deletes a runner API key
|
||||
func (s *Secrets) DeleteRunnerAPIKey(keyID int64) error {
|
||||
_, err := s.db.Exec("DELETE FROM runner_api_keys WHERE id = ?", keyID)
|
||||
return s.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec("DELETE FROM runner_api_keys WHERE id = ?", keyID)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user