Files
jiggablend/internal/auth/auth.go

708 lines
22 KiB
Go

package auth
import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/google/uuid"
	"golang.org/x/crypto/bcrypt"
	"golang.org/x/oauth2"
	"golang.org/x/oauth2/google"
)
// Auth bundles everything this package needs to authenticate users: the
// database handle, the optional OAuth provider configurations and the
// in-memory session table.
type Auth struct {
// db is the handle used for every users/settings query in this package.
db *sql.DB
// googleConfig stays nil unless NewAuth found both GOOGLE_CLIENT_ID and
// GOOGLE_CLIENT_SECRET in the environment.
googleConfig *oauth2.Config
// discordConfig stays nil unless NewAuth found both DISCORD_CLIENT_ID and
// DISCORD_CLIENT_SECRET in the environment.
discordConfig *oauth2.Config
// sessionStore maps session IDs to live sessions. It exists only in process
// memory. The map itself carries no lock; its accessors are responsible for
// any synchronization.
sessionStore map[string]*Session
}
// Session is the per-login record kept in Auth.sessionStore and handed to the
// middleware after a successful lookup.
type Session struct {
// UserID is the users.id of the authenticated account.
UserID int64
// Email is the address the session was created with.
Email string
// Name is the display name the session was created with.
Name string
// IsAdmin mirrors users.is_admin; GetSession re-reads it from the database
// on each successful lookup.
IsAdmin bool
// ExpiresAt is the instant after which GetSession treats the session as
// invalid and removes it.
ExpiresAt time.Time
}
// NewAuth builds an Auth around db with an empty session store.
//
// OAuth providers are configured from the environment: a provider is enabled
// only when both its client ID and client secret are non-empty. Afterwards the
// admin settings and the optional test user are initialized; a failure in
// either step is logged as a warning and does not abort construction, so the
// returned error is always nil.
func NewAuth(db *sql.DB) (*Auth, error) {
	a := &Auth{
		db:           db,
		sessionStore: make(map[string]*Session),
	}

	// Google OAuth.
	if id, secret := os.Getenv("GOOGLE_CLIENT_ID"), os.Getenv("GOOGLE_CLIENT_SECRET"); id != "" && secret != "" {
		a.googleConfig = &oauth2.Config{
			ClientID:     id,
			ClientSecret: secret,
			RedirectURL:  os.Getenv("GOOGLE_REDIRECT_URL"),
			Scopes:       []string{"openid", "profile", "email"},
			Endpoint:     google.Endpoint,
		}
	}

	// Discord OAuth (no canned endpoint in the oauth2 module, so spell it out).
	if id, secret := os.Getenv("DISCORD_CLIENT_ID"), os.Getenv("DISCORD_CLIENT_SECRET"); id != "" && secret != "" {
		a.discordConfig = &oauth2.Config{
			ClientID:     id,
			ClientSecret: secret,
			RedirectURL:  os.Getenv("DISCORD_REDIRECT_URL"),
			Scopes:       []string{"identify", "email"},
			Endpoint: oauth2.Endpoint{
				AuthURL:  "https://discord.com/api/oauth2/authorize",
				TokenURL: "https://discord.com/api/oauth2/token",
			},
		}
	}

	// Both start-up steps are best-effort: warn and keep going.
	if err := a.initializeSettings(); err != nil {
		log.Printf("Warning: Failed to initialize admin settings: %v", err)
	}
	if err := a.initializeTestUser(); err != nil {
		log.Printf("Warning: Failed to initialize test user: %v", err)
	}

	return a, nil
}
// initializeSettings seeds the settings table with defaults for any admin
// setting that has no row yet. Currently that is only registration_enabled,
// which defaults to "true". Existing rows are never modified.
func (a *Auth) initializeSettings() error {
	const key = "registration_enabled"

	var rows int
	if err := a.db.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", key).Scan(&rows); err != nil {
		return fmt.Errorf("failed to check registration_enabled setting: %w", err)
	}
	if rows != 0 {
		// Already present: keep whatever value an admin has chosen.
		return nil
	}

	if _, err := a.db.Exec(
		`INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)`,
		key, "true",
	); err != nil {
		return fmt.Errorf("failed to initialize registration_enabled setting: %w", err)
	}
	log.Printf("Initialized admin setting: registration_enabled = true")
	return nil
}
// initializeTestUser creates a local test account from the LOCAL_TEST_EMAIL
// and LOCAL_TEST_PASSWORD environment variables (intended for testing only).
//
// It is a no-op when either variable is empty or when a local user with that
// email already exists. The account's name is the part of the email before
// "@" (or the whole email when there is no usable local part), and it is made
// an admin only when the users table is empty.
//
// Returns an error when any database step or the password hashing fails.
func (a *Auth) initializeTestUser() error {
	testEmail := os.Getenv("LOCAL_TEST_EMAIL")
	testPassword := os.Getenv("LOCAL_TEST_PASSWORD")
	if testEmail == "" || testPassword == "" {
		// No test user configured, skip.
		return nil
	}

	// Check if the user already exists.
	var exists bool
	err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND oauth_provider = 'local')", testEmail).Scan(&exists)
	if err != nil {
		return fmt.Errorf("failed to check if test user exists: %w", err)
	}
	if exists {
		log.Printf("Test user %s already exists, skipping creation", testEmail)
		return nil
	}

	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(testPassword), bcrypt.DefaultCost)
	if err != nil {
		return fmt.Errorf("failed to hash test user password: %w", err)
	}

	// The first user ever created becomes admin. The count error must be
	// checked: if it were ignored, a failed query would leave userCount at 0
	// and silently grant admin rights.
	var userCount int
	if err := a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount); err != nil {
		return fmt.Errorf("failed to count users: %w", err)
	}
	isAdmin := userCount == 0

	// Use the local part of the email as the display name when there is one.
	testName := testEmail
	if atIndex := strings.Index(testEmail, "@"); atIndex > 0 {
		testName = testEmail[:atIndex]
	}

	// Local users store their email in oauth_id as well.
	_, err = a.db.Exec(
		"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?)",
		testEmail, testName, testEmail, string(hashedPassword), isAdmin,
	)
	if err != nil {
		return fmt.Errorf("failed to create test user: %w", err)
	}
	log.Printf("Created test user: %s (admin: %v)", testEmail, isAdmin)
	return nil
}
// GoogleLoginURL returns the URL a browser should be sent to in order to start
// the Google OAuth flow, or an error when Google OAuth is not configured.
//
// NOTE(review): the random state value is embedded in the URL but is neither
// stored nor returned here, and GoogleCallback receives only the code, so
// nothing visible in this file can verify state on the way back — confirm the
// HTTP layer handles this.
func (a *Auth) GoogleLoginURL() (string, error) {
	cfg := a.googleConfig
	if cfg == nil {
		return "", fmt.Errorf("Google OAuth not configured")
	}
	return cfg.AuthCodeURL(uuid.New().String()), nil
}
// DiscordLoginURL returns the URL a browser should be sent to in order to
// start the Discord OAuth flow, or an error when Discord OAuth is not
// configured.
//
// NOTE(review): the random state value is embedded in the URL but is neither
// stored nor returned here, and DiscordCallback receives only the code, so
// nothing visible in this file can verify state on the way back — confirm the
// HTTP layer handles this.
func (a *Auth) DiscordLoginURL() (string, error) {
	cfg := a.discordConfig
	if cfg == nil {
		return "", fmt.Errorf("Discord OAuth not configured")
	}
	return cfg.AuthCodeURL(uuid.New().String()), nil
}
// GoogleCallback completes the Google OAuth flow: it exchanges code for a
// token, fetches the user's profile and resolves it to a local account via
// getOrCreateUser.
//
// The profile is rejected when the userinfo endpoint answers with a non-200
// status, when it carries no id or email, or when Google reports the email as
// unverified. These checks matter because getOrCreateUser links accounts by
// email: an empty or unverified address must never reach it.
//
// Returns the new session (not yet registered in the session store) or an
// error describing which step failed.
func (a *Auth) GoogleCallback(ctx context.Context, code string) (*Session, error) {
	if a.googleConfig == nil {
		return nil, fmt.Errorf("Google OAuth not configured")
	}

	token, err := a.googleConfig.Exchange(ctx, code)
	if err != nil {
		return nil, fmt.Errorf("failed to exchange token: %w", err)
	}

	client := a.googleConfig.Client(ctx, token)
	resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo")
	if err != nil {
		return nil, fmt.Errorf("failed to get user info: %w", err)
	}
	defer resp.Body.Close()

	// An error response is still valid JSON; without this check it would
	// decode into an all-empty userInfo and be treated as a real profile.
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("failed to get user info: unexpected status %s", resp.Status)
	}

	var userInfo struct {
		ID            string `json:"id"`
		Email         string `json:"email"`
		VerifiedEmail bool   `json:"verified_email"`
		Name          string `json:"name"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
		return nil, fmt.Errorf("failed to decode user info: %w", err)
	}
	if userInfo.ID == "" || userInfo.Email == "" {
		return nil, fmt.Errorf("user info from Google is missing id or email")
	}
	if !userInfo.VerifiedEmail {
		return nil, fmt.Errorf("email on Google account is not verified")
	}

	return a.getOrCreateUser("google", userInfo.ID, userInfo.Email, userInfo.Name)
}
// DiscordCallback completes the Discord OAuth flow: it exchanges code for a
// token, fetches the user's profile and resolves it to a local account via
// getOrCreateUser.
//
// The profile is rejected when the endpoint answers with a non-200 status,
// when it carries no id or email, or when Discord reports the email as
// unverified. These checks matter because getOrCreateUser links accounts by
// email: an empty or unverified address must never reach it.
//
// Returns the new session (not yet registered in the session store) or an
// error describing which step failed.
func (a *Auth) DiscordCallback(ctx context.Context, code string) (*Session, error) {
	if a.discordConfig == nil {
		return nil, fmt.Errorf("Discord OAuth not configured")
	}

	token, err := a.discordConfig.Exchange(ctx, code)
	if err != nil {
		return nil, fmt.Errorf("failed to exchange token: %w", err)
	}

	client := a.discordConfig.Client(ctx, token)
	resp, err := client.Get("https://discord.com/api/users/@me")
	if err != nil {
		return nil, fmt.Errorf("failed to get user info: %w", err)
	}
	defer resp.Body.Close()

	// An error response is still valid JSON; without this check it would
	// decode into an all-empty userInfo and be treated as a real profile.
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("failed to get user info: unexpected status %s", resp.Status)
	}

	var userInfo struct {
		ID       string `json:"id"`
		Email    string `json:"email"` // JSON null decodes to ""
		Verified bool   `json:"verified"`
		Username string `json:"username"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
		return nil, fmt.Errorf("failed to decode user info: %w", err)
	}
	if userInfo.ID == "" || userInfo.Email == "" {
		return nil, fmt.Errorf("user info from Discord is missing id or email")
	}
	if !userInfo.Verified {
		return nil, fmt.Errorf("email on Discord account is not verified")
	}

	return a.getOrCreateUser("discord", userInfo.ID, userInfo.Email, userInfo.Username)
}
// IsRegistrationEnabled reports whether new accounts may be created. It reads
// the registration_enabled row from the settings table; a missing row counts
// as enabled, and only the exact value "true" counts as enabled otherwise.
// Any other query failure is returned as an error together with false.
func (a *Auth) IsRegistrationEnabled() (bool, error) {
	var stored string
	err := a.db.QueryRow("SELECT value FROM settings WHERE key = ?", "registration_enabled").Scan(&stored)
	switch {
	case err == sql.ErrNoRows:
		// No row yet: registration defaults to enabled.
		return true, nil
	case err != nil:
		return false, fmt.Errorf("failed to check registration setting: %w", err)
	}
	return stored == "true", nil
}
// SetRegistrationEnabled persists the registration_enabled setting as the
// string "true" or "false". The row is updated when it already exists and
// inserted otherwise; updated_at is set to the current timestamp either way.
func (a *Auth) SetRegistrationEnabled(enabled bool) error {
	const key = "registration_enabled"
	value := fmt.Sprintf("%t", enabled) // "true" or "false"

	var present bool
	if err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM settings WHERE key = ?)", key).Scan(&present); err != nil {
		return fmt.Errorf("failed to check if setting exists: %w", err)
	}

	if !present {
		if _, err := a.db.Exec(
			"INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)",
			key, value,
		); err != nil {
			return fmt.Errorf("failed to insert setting: %w", err)
		}
		return nil
	}

	if _, err := a.db.Exec(
		"UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?",
		value, key,
	); err != nil {
		return fmt.Errorf("failed to update setting: %w", err)
	}
	return nil
}
// getOrCreateUser resolves an externally authenticated identity to a row in
// the users table and returns a fresh 24-hour session for it.
//
// Resolution order:
//  1. A row matching (provider, oauthID): its email/name are refreshed when
//     they differ from the supplied values.
//  2. Otherwise a row matching email: that row's oauth_provider, oauth_id and
//     name are overwritten with the supplied values. Note that this replaces
//     the previously stored provider rather than adding a second one, so
//     lookups that filter on the old provider (for example
//     oauth_provider = 'local') no longer match the row.
//  3. Otherwise a new row is inserted, provided registration is enabled. The
//     very first user in the table is created as admin.
//
// The returned session carries the supplied email and name and the admin flag
// from the database (or the first-user rule for a new row). It is not added
// to the session store here.
func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session, error) {
	var (
		userID                int64
		dbEmail, dbName       string
		isAdmin               bool
		dbProvider, dbOAuthID string
	)

	err := a.db.QueryRow(
		"SELECT id, email, name, is_admin, oauth_provider, oauth_id FROM users WHERE oauth_provider = ? AND oauth_id = ?",
		provider, oauthID,
	).Scan(&userID, &dbEmail, &dbName, &isAdmin, &dbProvider, &dbOAuthID)

	switch {
	case err == nil:
		// Known identity: keep the stored profile in sync with the provider.
		if dbEmail != email || dbName != name {
			if _, err := a.db.Exec(
				"UPDATE users SET email = ?, name = ? WHERE id = ?",
				email, name, userID,
			); err != nil {
				return nil, fmt.Errorf("failed to update user: %w", err)
			}
		}

	case err == sql.ErrNoRows:
		// Unknown identity: fall back to matching on email.
		err = a.db.QueryRow(
			"SELECT id, email, name, is_admin, oauth_provider, oauth_id FROM users WHERE email = ?",
			email,
		).Scan(&userID, &dbEmail, &dbName, &isAdmin, &dbProvider, &dbOAuthID)

		switch {
		case err == nil:
			// Same email under a different provider: point the row at the
			// identity that just logged in.
			if _, err := a.db.Exec(
				"UPDATE users SET oauth_provider = ?, oauth_id = ?, name = ? WHERE id = ?",
				provider, oauthID, name, userID,
			); err != nil {
				return nil, fmt.Errorf("failed to link account: %w", err)
			}
			log.Printf("Linked account: user %d (email: %s) now accessible via %s provider", userID, email, provider)

		case err == sql.ErrNoRows:
			// Brand-new user: only allowed while registration is open.
			registrationEnabled, err := a.IsRegistrationEnabled()
			if err != nil {
				return nil, fmt.Errorf("failed to check registration setting: %w", err)
			}
			if !registrationEnabled {
				return nil, fmt.Errorf("registration is disabled")
			}

			// The first user becomes admin. The count error must be checked:
			// if it were ignored, a failed query would leave userCount at 0
			// and silently grant admin rights.
			var userCount int
			if err := a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount); err != nil {
				return nil, fmt.Errorf("failed to count users: %w", err)
			}
			isAdmin = userCount == 0

			if err := a.db.QueryRow(
				"INSERT INTO users (email, name, oauth_provider, oauth_id, is_admin) VALUES (?, ?, ?, ?, ?) RETURNING id",
				email, name, provider, oauthID, isAdmin,
			).Scan(&userID); err != nil {
				return nil, fmt.Errorf("failed to create user: %w", err)
			}

		default:
			return nil, fmt.Errorf("failed to query user by email: %w", err)
		}

	default:
		return nil, fmt.Errorf("failed to query user: %w", err)
	}

	return &Session{
		UserID:    userID,
		Email:     email,
		Name:      name,
		IsAdmin:   isAdmin,
		ExpiresAt: time.Now().Add(24 * time.Hour),
	}, nil
}
// sessionStoreMu serializes every access that CreateSession, GetSession and
// DeleteSession make to Auth.sessionStore. These methods are reached from HTTP
// handlers, which run concurrently, and unsynchronized concurrent use of a Go
// map can abort the process. The lock is package-level so that guarding the
// map does not require changing the Auth struct layout.
var sessionStoreMu sync.Mutex

// CreateSession registers session in the in-memory store under a freshly
// generated random ID and returns that ID.
func (a *Auth) CreateSession(session *Session) string {
	sessionID := uuid.New().String()

	sessionStoreMu.Lock()
	a.sessionStore[sessionID] = session
	sessionStoreMu.Unlock()

	return sessionID
}

// GetSession looks up a session by ID.
//
// It returns (nil, false) when the ID is unknown or the session has expired;
// an expired session is removed from the store as a side effect. On success
// the session's IsAdmin flag is refreshed from the users table; if that query
// fails the previously stored flag is kept.
//
// The returned pointer is the stored session itself, not a copy, so callers
// should treat it as read-only.
func (a *Auth) GetSession(sessionID string) (*Session, bool) {
	sessionStoreMu.Lock()
	session, ok := a.sessionStore[sessionID]
	if ok && time.Now().After(session.ExpiresAt) {
		delete(a.sessionStore, sessionID)
		ok = false
	}
	sessionStoreMu.Unlock()

	if !ok {
		return nil, false
	}

	// Refresh admin status from the database. The query runs outside the lock
	// so a slow database does not stall every other session lookup.
	var isAdmin bool
	if err := a.db.QueryRow("SELECT is_admin FROM users WHERE id = ?", session.UserID).Scan(&isAdmin); err == nil {
		sessionStoreMu.Lock()
		session.IsAdmin = isAdmin
		sessionStoreMu.Unlock()
	}
	return session, true
}

// DeleteSession removes the session with the given ID from the store. Deleting
// an unknown ID is a no-op.
func (a *Auth) DeleteSession(sessionID string) {
	sessionStoreMu.Lock()
	delete(a.sessionStore, sessionID)
	sessionStoreMu.Unlock()
}
// Middleware wraps next so that it only runs for requests carrying a valid
// session_id cookie. Requests without the cookie, or whose cookie does not
// resolve to a live session, receive a 401 JSON response and are logged.
//
// On success the session's user ID, email, name and admin flag are placed in
// the request context under the string keys "user_id", "user_email",
// "user_name" and "is_admin".
func (a *Auth) Middleware(next http.HandlerFunc) http.HandlerFunc {
	// unauthorized writes the shared 401 JSON response.
	unauthorized := func(w http.ResponseWriter) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusUnauthorized)
		json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
	}

	return func(w http.ResponseWriter, r *http.Request) {
		cookie, err := r.Cookie("session_id")
		if err != nil {
			log.Printf("Authentication failed: missing session cookie for %s %s", r.Method, r.URL.Path)
			unauthorized(w)
			return
		}

		session, ok := a.GetSession(cookie.Value)
		if !ok {
			log.Printf("Authentication failed: invalid session cookie for %s %s", r.Method, r.URL.Path)
			unauthorized(w)
			return
		}

		// Expose the authenticated user to downstream handlers.
		ctx := r.Context()
		ctx = context.WithValue(ctx, "user_id", session.UserID)
		ctx = context.WithValue(ctx, "user_email", session.Email)
		ctx = context.WithValue(ctx, "user_name", session.Name)
		ctx = context.WithValue(ctx, "is_admin", session.IsAdmin)
		next(w, r.WithContext(ctx))
	}
}
// GetUserID extracts the authenticated user's ID from ctx. The second result
// is false (and the ID is 0) when ctx holds no "user_id" value or the value is
// not an int64.
func GetUserID(ctx context.Context) (int64, bool) {
	if id, ok := ctx.Value("user_id").(int64); ok {
		return id, true
	}
	return 0, false
}
// IsAdmin reports whether ctx marks the current user as an admin. It is true
// only when ctx holds an "is_admin" value of type bool that is itself true.
func IsAdmin(ctx context.Context) bool {
	flag, ok := ctx.Value("is_admin").(bool)
	if !ok {
		return false
	}
	return flag
}
// AdminMiddleware wraps next so that it only runs for authenticated admins.
//
// A missing or invalid session_id cookie yields a 401 JSON response; a valid
// session whose IsAdmin flag is false yields a 403 JSON response. Every
// rejection is logged. On success the session's user ID, email, name and admin
// flag are placed in the request context under the string keys "user_id",
// "user_email", "user_name" and "is_admin".
func (a *Auth) AdminMiddleware(next http.HandlerFunc) http.HandlerFunc {
	// deny writes a JSON error body with the given status code.
	deny := func(w http.ResponseWriter, status int, message string) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(map[string]string{"error": message})
	}

	return func(w http.ResponseWriter, r *http.Request) {
		// Authentication first.
		cookie, err := r.Cookie("session_id")
		if err != nil {
			log.Printf("Admin authentication failed: missing session cookie for %s %s", r.Method, r.URL.Path)
			deny(w, http.StatusUnauthorized, "Unauthorized")
			return
		}

		session, ok := a.GetSession(cookie.Value)
		if !ok {
			log.Printf("Admin authentication failed: invalid session cookie for %s %s", r.Method, r.URL.Path)
			deny(w, http.StatusUnauthorized, "Unauthorized")
			return
		}

		// Then authorization.
		if !session.IsAdmin {
			log.Printf("Admin access denied: user %d (email: %s) attempted to access admin endpoint %s %s", session.UserID, session.Email, r.Method, r.URL.Path)
			deny(w, http.StatusForbidden, "Forbidden: Admin access required")
			return
		}

		// Expose the authenticated user to downstream handlers.
		ctx := r.Context()
		ctx = context.WithValue(ctx, "user_id", session.UserID)
		ctx = context.WithValue(ctx, "user_email", session.Email)
		ctx = context.WithValue(ctx, "user_name", session.Name)
		ctx = context.WithValue(ctx, "is_admin", session.IsAdmin)
		next(w, r.WithContext(ctx))
	}
}
// IsLocalLoginEnabled reports whether email/password login is switched on,
// which is the case only when the ENABLE_LOCAL_AUTH environment variable is
// exactly "true". The variable is read on every call.
func (a *Auth) IsLocalLoginEnabled() bool {
	flag := os.Getenv("ENABLE_LOCAL_AUTH")
	return flag == "true"
}
// IsGoogleOAuthConfigured reports whether a Google OAuth configuration is
// present on this Auth (that is, googleConfig is non-nil).
func (a *Auth) IsGoogleOAuthConfigured() bool {
	if a.googleConfig == nil {
		return false
	}
	return true
}
// IsDiscordOAuthConfigured reports whether a Discord OAuth configuration is
// present on this Auth (that is, discordConfig is non-nil).
func (a *Auth) IsDiscordOAuthConfigured() bool {
	if a.discordConfig == nil {
		return false
	}
	return true
}
// LocalLogin authenticates a local account. username is matched against the
// email column of users whose oauth_provider is 'local', and password is
// verified against the stored bcrypt hash.
//
// An unknown user, an empty stored hash and a wrong password all produce the
// same "invalid credentials" error. On success a 24-hour session populated
// from the database row is returned; it is not added to the session store
// here.
func (a *Auth) LocalLogin(username, password string) (*Session, error) {
	var (
		id         int64
		mail       string
		display    string
		storedHash string
		admin      bool
	)

	// Local users log in with their email as the username.
	row := a.db.QueryRow(
		"SELECT id, email, name, password_hash, is_admin FROM users WHERE email = ? AND oauth_provider = 'local'",
		username,
	)
	switch err := row.Scan(&id, &mail, &display, &storedHash, &admin); {
	case err == sql.ErrNoRows:
		return nil, fmt.Errorf("invalid credentials")
	case err != nil:
		return nil, fmt.Errorf("failed to query user: %w", err)
	}

	// An account without a hash can never be logged into with a password.
	if storedHash == "" || bcrypt.CompareHashAndPassword([]byte(storedHash), []byte(password)) != nil {
		return nil, fmt.Errorf("invalid credentials")
	}

	return &Session{
		UserID:    id,
		Email:     mail,
		Name:      display,
		IsAdmin:   admin,
		ExpiresAt: time.Now().Add(24 * time.Hour),
	}, nil
}
// RegisterLocalUser creates a new local (email/password) account and returns a
// 24-hour session for it. The session is not added to the session store here.
//
// Registration fails when the registration_enabled setting is off or when any
// user, local or OAuth, already has this email. The password is stored as a
// bcrypt hash, and the very first user in the table is created as admin.
func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error) {
	registrationEnabled, err := a.IsRegistrationEnabled()
	if err != nil {
		return nil, fmt.Errorf("failed to check registration setting: %w", err)
	}
	if !registrationEnabled {
		return nil, fmt.Errorf("registration is disabled")
	}

	// Emails are unique across all providers, not just local accounts.
	var exists bool
	err = a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ?)", email).Scan(&exists)
	if err != nil {
		return nil, fmt.Errorf("failed to check if user exists: %w", err)
	}
	if exists {
		return nil, fmt.Errorf("user with this email already exists")
	}

	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return nil, fmt.Errorf("failed to hash password: %w", err)
	}

	// The first user becomes admin. The count error must be checked: if it
	// were ignored, a failed query would leave userCount at 0 and silently
	// grant admin rights to whoever is registering.
	var userCount int
	if err := a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount); err != nil {
		return nil, fmt.Errorf("failed to count users: %w", err)
	}
	isAdmin := userCount == 0

	// Local users store their email in oauth_id as well.
	var userID int64
	err = a.db.QueryRow(
		"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?) RETURNING id",
		email, name, email, string(hashedPassword), isAdmin,
	).Scan(&userID)
	if err != nil {
		return nil, fmt.Errorf("failed to create user: %w", err)
	}

	return &Session{
		UserID:    userID,
		Email:     email,
		Name:      name,
		IsAdmin:   isAdmin,
		ExpiresAt: time.Now().Add(24 * time.Hour),
	}, nil
}
// ChangePassword replaces a local user's password after verifying the old one.
//
// It fails when userID does not refer to a local account, when that account
// has no password hash stored, or when oldPassword does not match the stored
// hash. The new password is stored as a bcrypt hash.
func (a *Auth) ChangePassword(userID int64, oldPassword, newPassword string) error {
	var currentHash string
	row := a.db.QueryRow("SELECT password_hash FROM users WHERE id = ? AND oauth_provider = 'local'", userID)
	switch err := row.Scan(&currentHash); {
	case err == sql.ErrNoRows:
		return fmt.Errorf("user not found or not a local user")
	case err != nil:
		return fmt.Errorf("failed to query user: %w", err)
	}

	// The caller must prove knowledge of the current password.
	if currentHash == "" {
		return fmt.Errorf("user has no password set")
	}
	if bcrypt.CompareHashAndPassword([]byte(currentHash), []byte(oldPassword)) != nil {
		return fmt.Errorf("incorrect old password")
	}

	replacement, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
	if err != nil {
		return fmt.Errorf("failed to hash password: %w", err)
	}

	if _, err := a.db.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(replacement), userID); err != nil {
		return fmt.Errorf("failed to update password: %w", err)
	}
	return nil
}
// AdminChangePassword sets a new password for a local user without requiring
// the old one. It fails when targetUserID does not refer to a local account.
// This method performs no authorization check of its own; the caller is
// responsible for restricting it to admins.
func (a *Auth) AdminChangePassword(targetUserID int64, newPassword string) error {
	var isLocalUser bool
	lookup := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ? AND oauth_provider = 'local')", targetUserID)
	if err := lookup.Scan(&isLocalUser); err != nil {
		return fmt.Errorf("failed to check if user exists: %w", err)
	}
	if !isLocalUser {
		return fmt.Errorf("user not found or not a local user")
	}

	replacement, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
	if err != nil {
		return fmt.Errorf("failed to hash password: %w", err)
	}

	if _, err := a.db.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(replacement), targetUserID); err != nil {
		return fmt.Errorf("failed to update password: %w", err)
	}
	return nil
}
// GetFirstUserID returns the smallest id in the users table. It returns an
// error when the table is empty or the query fails.
func (a *Auth) GetFirstUserID() (int64, error) {
	var lowest int64
	row := a.db.QueryRow("SELECT id FROM users ORDER BY id ASC LIMIT 1")
	switch err := row.Scan(&lowest); {
	case err == sql.ErrNoRows:
		return 0, fmt.Errorf("no users found")
	case err != nil:
		return 0, fmt.Errorf("failed to get first user ID: %w", err)
	}
	return lowest, nil
}
// SetUserAdminStatus sets the is_admin flag of the given user.
//
// It fails when the user does not exist, and it refuses to revoke admin rights
// from the first user (the one with the lowest id). This method performs no
// authorization check of its own; the caller is responsible for restricting
// it to admins.
func (a *Auth) SetUserAdminStatus(targetUserID int64, isAdmin bool) error {
	var found bool
	if err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", targetUserID).Scan(&found); err != nil {
		return fmt.Errorf("failed to check if user exists: %w", err)
	}
	if !found {
		return fmt.Errorf("user not found")
	}

	// The first user must always remain an admin.
	firstID, err := a.GetFirstUserID()
	if err != nil {
		return fmt.Errorf("failed to check first user: %w", err)
	}
	if !isAdmin && targetUserID == firstID {
		return fmt.Errorf("cannot remove admin status from the first user")
	}

	if _, err := a.db.Exec("UPDATE users SET is_admin = ? WHERE id = ?", isAdmin, targetUserID); err != nil {
		return fmt.Errorf("failed to update admin status: %w", err)
	}
	return nil
}