package auth

import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/google/uuid"
	"golang.org/x/crypto/bcrypt"
	"golang.org/x/oauth2"
	"golang.org/x/oauth2/google"
)

// Auth handles authentication: Google and Discord OAuth login, local
// password accounts stored in the users table, and an in-memory session
// store keyed by random session IDs.
type Auth struct {
	db            *sql.DB
	googleConfig  *oauth2.Config // nil when Google OAuth is not configured
	discordConfig *oauth2.Config // nil when Discord OAuth is not configured

	// mu guards sessionStore. net/http serves each request on its own
	// goroutine, and concurrent unsynchronized map writes are a fatal
	// runtime error in Go, so every access to the map must hold mu.
	mu           sync.Mutex
	sessionStore map[string]*Session
}

// Session represents an authenticated user session.
type Session struct {
	UserID    int64
	Email     string
	Name      string
	IsAdmin   bool
	ExpiresAt time.Time // GetSession rejects and evicts the session after this instant
}

// NewAuth creates a new Auth backed by db.
//
// Google and Discord OAuth are enabled only when both the corresponding
// *_CLIENT_ID and *_CLIENT_SECRET environment variables are set; redirect
// URLs come from GOOGLE_REDIRECT_URL and DISCORD_REDIRECT_URL. It also seeds
// default admin settings and, when LOCAL_TEST_EMAIL and LOCAL_TEST_PASSWORD
// are set, a local test user. Failures in those seeding steps are logged
// rather than returned, so the returned error is currently always nil.
func NewAuth(db *sql.DB) (*Auth, error) {
	auth := &Auth{
		db:           db,
		sessionStore: make(map[string]*Session),
	}

	// Initialize Google OAuth
	googleClientID := os.Getenv("GOOGLE_CLIENT_ID")
	googleClientSecret := os.Getenv("GOOGLE_CLIENT_SECRET")
	if googleClientID != "" && googleClientSecret != "" {
		auth.googleConfig = &oauth2.Config{
			ClientID:     googleClientID,
			ClientSecret: googleClientSecret,
			RedirectURL:  os.Getenv("GOOGLE_REDIRECT_URL"),
			Scopes:       []string{"openid", "profile", "email"},
			Endpoint:     google.Endpoint,
		}
	}

	// Initialize Discord OAuth
	discordClientID := os.Getenv("DISCORD_CLIENT_ID")
	discordClientSecret := os.Getenv("DISCORD_CLIENT_SECRET")
	if discordClientID != "" && discordClientSecret != "" {
		auth.discordConfig = &oauth2.Config{
			ClientID:     discordClientID,
			ClientSecret: discordClientSecret,
			RedirectURL:  os.Getenv("DISCORD_REDIRECT_URL"),
			Scopes:       []string{"identify", "email"},
			Endpoint: oauth2.Endpoint{
				AuthURL:  "https://discord.com/api/oauth2/authorize",
				TokenURL: "https://discord.com/api/oauth2/token",
			},
		}
	}

	// Seed admin settings on startup so they persist between boots.
	// Deliberately best-effort: log the failure but don't abort startup.
	if err := auth.initializeSettings(); err != nil {
		log.Printf("Warning: Failed to initialize admin settings: %v", err)
	}

	// Seed a local test user from environment variables (for testing only).
	// Also best-effort: log the failure but don't abort startup.
	if err := auth.initializeTestUser(); err != nil {
		log.Printf("Warning: Failed to initialize test user: %v", err)
	}

	return auth, nil
}

// initializeSettings inserts default admin settings that are missing from the
// settings table. Existing values are never overwritten, so an admin's choice
// survives restarts. Currently the only setting is registration_enabled,
// which defaults to "true".
func (a *Auth) initializeSettings() error {
	var settingCount int
	err := a.db.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount)
	if err != nil {
		return fmt.Errorf("failed to check registration_enabled setting: %w", err)
	}

	if settingCount == 0 {
		_, err = a.db.Exec(
			`INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)`,
			"registration_enabled", "true",
		)
		if err != nil {
			return fmt.Errorf("failed to initialize registration_enabled setting: %w", err)
		}
		log.Printf("Initialized admin setting: registration_enabled = true")
	}

	return nil
}

// initializeTestUser creates a local user from LOCAL_TEST_EMAIL and
// LOCAL_TEST_PASSWORD (for testing only). It does nothing when either
// variable is unset or when a local user with that email already exists.
// The user is made an admin when the users table is empty.
func (a *Auth) initializeTestUser() error {
	testEmail := os.Getenv("LOCAL_TEST_EMAIL")
	testPassword := os.Getenv("LOCAL_TEST_PASSWORD")
	if testEmail == "" || testPassword == "" {
		// No test user configured, skip
		return nil
	}

	var exists bool
	err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND oauth_provider = 'local')", testEmail).Scan(&exists)
	if err != nil {
		return fmt.Errorf("failed to check if test user exists: %w", err)
	}
	if exists {
		log.Printf("Test user %s already exists, skipping creation", testEmail)
		return nil
	}

	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(testPassword), bcrypt.DefaultCost)
	if err != nil {
		return fmt.Errorf("failed to hash test user password: %w", err)
	}

	// The first user becomes an admin. The count error must be checked: if it
	// were ignored, userCount would stay 0 and a failed query would silently
	// grant admin rights.
	var userCount int
	if err := a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount); err != nil {
		return fmt.Errorf("failed to count users: %w", err)
	}
	isAdmin := userCount == 0

	// Use the local part of the email as the display name, falling back to the
	// whole email when there is no "@" or nothing precedes it.
	testName := testEmail
	if local, _, found := strings.Cut(testEmail, "@"); found && local != "" {
		testName = local
	}

	_, err = a.db.Exec(
		"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?)",
		testEmail, testName, testEmail, string(hashedPassword), isAdmin,
	)
	if err != nil {
		return fmt.Errorf("failed to create test user: %w", err)
	}

	log.Printf("Created test user: %s (admin: %v)", testEmail, isAdmin)
	return nil
}

// GoogleLoginURL returns the Google OAuth login URL.
//
// NOTE(review): the random state is not stored anywhere and GoogleCallback
// does not receive it, so the flow has no CSRF protection. The state still
// needs to be bound to the browser (e.g. a short-lived cookie) and checked
// before GoogleCallback is called.
func (a *Auth) GoogleLoginURL() (string, error) {
	if a.googleConfig == nil {
		return "", fmt.Errorf("Google OAuth not configured")
	}
	state := uuid.New().String()
	return a.googleConfig.AuthCodeURL(state), nil
}

// DiscordLoginURL returns the Discord OAuth login URL.
//
// NOTE(review): as with GoogleLoginURL, the state is not stored or verified.
func (a *Auth) DiscordLoginURL() (string, error) {
	if a.discordConfig == nil {
		return "", fmt.Errorf("Discord OAuth not configured")
	}
	state := uuid.New().String()
	return a.discordConfig.AuthCodeURL(state), nil
}

// fetchOAuthUserInfo GETs endpoint with client, bound to ctx, and decodes the
// JSON response body into dst. A non-200 status is returned as an error
// instead of being decoded, so an error payload can never be mistaken for an
// empty (zero-valued) user profile.
func fetchOAuthUserInfo(ctx context.Context, client *http.Client, endpoint string, dst any) error {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return fmt.Errorf("failed to build user info request: %w", err)
	}

	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("failed to get user info: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("failed to get user info: unexpected status %s", resp.Status)
	}
	if err := json.NewDecoder(resp.Body).Decode(dst); err != nil {
		return fmt.Errorf("failed to decode user info: %w", err)
	}
	return nil
}

// GoogleCallback exchanges an OAuth authorization code for a token, fetches
// the Google profile and returns a session for the matching (possibly newly
// created or linked) user. Accounts are linked by email, so the email must
// be present and verified by Google.
func (a *Auth) GoogleCallback(ctx context.Context, code string) (*Session, error) {
	if a.googleConfig == nil {
		return nil, fmt.Errorf("Google OAuth not configured")
	}

	token, err := a.googleConfig.Exchange(ctx, code)
	if err != nil {
		return nil, fmt.Errorf("failed to exchange token: %w", err)
	}

	var userInfo struct {
		ID            string `json:"id"`
		Email         string `json:"email"`
		VerifiedEmail bool   `json:"verified_email"`
		Name          string `json:"name"`
	}
	client := a.googleConfig.Client(ctx, token)
	if err := fetchOAuthUserInfo(ctx, client, "https://www.googleapis.com/oauth2/v2/userinfo", &userInfo); err != nil {
		return nil, err
	}

	// Linking by an unverified email would let anyone who registers the
	// victim's address with the provider take over the victim's account.
	if !userInfo.VerifiedEmail {
		return nil, fmt.Errorf("google account email is not verified")
	}

	return a.getOrCreateUser("google", userInfo.ID, userInfo.Email, userInfo.Name)
}

// DiscordCallback exchanges an OAuth authorization code for a token, fetches
// the Discord profile and returns a session for the matching (possibly newly
// created or linked) user. Accounts are linked by email, so the email must
// be present and verified by Discord.
func (a *Auth) DiscordCallback(ctx context.Context, code string) (*Session, error) {
	if a.discordConfig == nil {
		return nil, fmt.Errorf("Discord OAuth not configured")
	}

	token, err := a.discordConfig.Exchange(ctx, code)
	if err != nil {
		return nil, fmt.Errorf("failed to exchange token: %w", err)
	}

	var userInfo struct {
		ID       string `json:"id"`
		Email    string `json:"email"`
		Verified bool   `json:"verified"`
		Username string `json:"username"`
	}
	client := a.discordConfig.Client(ctx, token)
	if err := fetchOAuthUserInfo(ctx, client, "https://discord.com/api/users/@me", &userInfo); err != nil {
		return nil, err
	}

	// See GoogleCallback: unverified emails must not be used for linking.
	if !userInfo.Verified {
		return nil, fmt.Errorf("discord account email is not verified")
	}

	return a.getOrCreateUser("discord", userInfo.ID, userInfo.Email, userInfo.Username)
}

// IsRegistrationEnabled reports whether new user registration is enabled.
// A missing setting counts as enabled.
func (a *Auth) IsRegistrationEnabled() (bool, error) {
	var value string
	err := a.db.QueryRow("SELECT value FROM settings WHERE key = ?", "registration_enabled").Scan(&value)
	if err == sql.ErrNoRows {
		// Default to enabled if setting doesn't exist
		return true, nil
	}
	if err != nil {
		return false, fmt.Errorf("failed to check registration setting: %w", err)
	}
	return value == "true", nil
}

// SetRegistrationEnabled sets whether new user registration is enabled,
// updating the existing setting row or inserting it if missing.
func (a *Auth) SetRegistrationEnabled(enabled bool) error {
	value := "false"
	if enabled {
		value = "true"
	}

	var exists bool
	err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM settings WHERE key = ?)", "registration_enabled").Scan(&exists)
	if err != nil {
		return fmt.Errorf("failed to check if setting exists: %w", err)
	}

	if exists {
		_, err = a.db.Exec(
			"UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?",
			value, "registration_enabled",
		)
		if err != nil {
			return fmt.Errorf("failed to update setting: %w", err)
		}
		return nil
	}

	_, err = a.db.Exec(
		"INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)",
		"registration_enabled", value,
	)
	if err != nil {
		return fmt.Errorf("failed to insert setting: %w", err)
	}
	return nil
}

// getOrCreateUser gets or creates a user in the database and returns a fresh
// 24-hour session for it.
//
// Lookup order: (provider, oauthID) first; if that identity is unknown the
// user is linked by email across OAuth providers and local accounts, or
// created when registration is enabled.
func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session, error) {
	// An empty ID or email would make the lookups below match the wrong row
	// (or any row with an empty email), so refuse it outright.
	if oauthID == "" || email == "" {
		return nil, fmt.Errorf("%s user info is missing an ID or email", provider)
	}

	var userID int64
	var dbEmail, dbName string
	var isAdmin bool
	var dbProvider, dbOAuthID string

	err := a.db.QueryRow(
		"SELECT id, email, name, is_admin, oauth_provider, oauth_id FROM users WHERE oauth_provider = ? AND oauth_id = ?",
		provider, oauthID,
	).Scan(&userID, &dbEmail, &dbName, &isAdmin, &dbProvider, &dbOAuthID)

	switch {
	case err == nil:
		// Known identity: refresh email/name if the provider reports new values.
		if dbEmail != email || dbName != name {
			if _, err := a.db.Exec("UPDATE users SET email = ?, name = ? WHERE id = ?", email, name, userID); err != nil {
				return nil, fmt.Errorf("failed to update user: %w", err)
			}
		}
	case err == sql.ErrNoRows:
		// Unknown identity: link by email or create a new user.
		userID, isAdmin, err = a.linkOrCreateUser(provider, oauthID, email, name)
		if err != nil {
			return nil, err
		}
	default:
		return nil, fmt.Errorf("failed to query user: %w", err)
	}

	return &Session{
		UserID:    userID,
		Email:     email,
		Name:      name,
		IsAdmin:   isAdmin,
		ExpiresAt: time.Now().Add(24 * time.Hour),
	}, nil
}

// linkOrCreateUser handles an OAuth identity with no row of its own. It links
// the identity to an existing user with the same email, or creates a new user
// when registration is enabled. It returns the user's ID and admin flag.
func (a *Auth) linkOrCreateUser(provider, oauthID, email, name string) (int64, bool, error) {
	var userID int64
	var isAdmin bool
	var dbProvider string

	err := a.db.QueryRow(
		"SELECT id, is_admin, oauth_provider FROM users WHERE email = ?",
		email,
	).Scan(&userID, &isAdmin, &dbProvider)

	if err == sql.ErrNoRows {
		registrationEnabled, err := a.IsRegistrationEnabled()
		if err != nil {
			return 0, false, fmt.Errorf("failed to check registration setting: %w", err)
		}
		if !registrationEnabled {
			return 0, false, fmt.Errorf("registration is disabled")
		}

		// The first user becomes an admin; a failed count must not be read as 0.
		var userCount int
		if err := a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount); err != nil {
			return 0, false, fmt.Errorf("failed to count users: %w", err)
		}
		isAdmin = userCount == 0

		err = a.db.QueryRow(
			"INSERT INTO users (email, name, oauth_provider, oauth_id, is_admin) VALUES (?, ?, ?, ?, ?) RETURNING id",
			email, name, provider, oauthID, isAdmin,
		).Scan(&userID)
		if err != nil {
			return 0, false, fmt.Errorf("failed to create user: %w", err)
		}
		return userID, isAdmin, nil
	}
	if err != nil {
		return 0, false, fmt.Errorf("failed to query user by email: %w", err)
	}

	// Same email, different identity: link the accounts. Local accounts keep
	// oauth_provider = 'local' and only have their name refreshed. LocalLogin,
	// ChangePassword and AdminChangePassword all select on that value, so
	// overwriting it would silently disable password login for the user.
	if dbProvider == "local" {
		_, err = a.db.Exec("UPDATE users SET name = ? WHERE id = ?", name, userID)
	} else {
		_, err = a.db.Exec(
			"UPDATE users SET oauth_provider = ?, oauth_id = ?, name = ? WHERE id = ?",
			provider, oauthID, name, userID,
		)
	}
	if err != nil {
		return 0, false, fmt.Errorf("failed to link account: %w", err)
	}

	log.Printf("Linked account: user %d (email: %s) now accessible via %s provider", userID, email, provider)
	return userID, isAdmin, nil
}

// CreateSession stores a copy of session under a new random ID and returns
// that ID. Storing a copy means later changes the caller makes to *session
// cannot race with concurrent readers of the store.
func (a *Auth) CreateSession(session *Session) string {
	sessionID := uuid.New().String()
	stored := *session

	a.mu.Lock()
	a.sessionStore[sessionID] = &stored
	a.mu.Unlock()

	return sessionID
}

// GetSession returns a snapshot of the session stored under sessionID.
//
// It returns false for unknown or expired sessions (expired ones are
// evicted). It also returns false when the session's user no longer exists,
// and that session is revoked. The admin flag is refreshed from the database
// on every call. If that query fails for another reason, the cached flag is
// kept. The returned *Session is a copy; mutating it does not change the
// stored session.
func (a *Auth) GetSession(sessionID string) (*Session, bool) {
	a.mu.Lock()
	stored, ok := a.sessionStore[sessionID]
	if ok && time.Now().After(stored.ExpiresAt) {
		delete(a.sessionStore, sessionID)
		ok = false
	}
	var snapshot Session
	if ok {
		snapshot = *stored
	}
	a.mu.Unlock()

	if !ok {
		return nil, false
	}

	// The query runs outside the lock so a slow database does not serialize
	// every request.
	var isAdmin bool
	err := a.db.QueryRow("SELECT is_admin FROM users WHERE id = ?", snapshot.UserID).Scan(&isAdmin)
	switch {
	case err == sql.ErrNoRows:
		// The user row is gone: revoke instead of keeping the session alive
		// with stale privileges until it expires.
		a.DeleteSession(sessionID)
		return nil, false
	case err == nil:
		snapshot.IsAdmin = isAdmin
		a.mu.Lock()
		if current, ok := a.sessionStore[sessionID]; ok {
			current.IsAdmin = isAdmin
		}
		a.mu.Unlock()
	}

	return &snapshot, true
}

// DeleteSession deletes a session. Deleting an unknown ID is a no-op.
func (a *Auth) DeleteSession(sessionID string) {
	a.mu.Lock()
	delete(a.sessionStore, sessionID)
	a.mu.Unlock()
}

// Middleware creates an authentication middleware. Requests without a valid
// "session_id" cookie get a 401 JSON error. Otherwise the user's ID, email,
// name and admin flag are added to the request context before calling next.
func (a *Auth) Middleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		cookie, err := r.Cookie("session_id")
		if err != nil {
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusUnauthorized)
			// Best-effort body: the status line is already written.
			_ = json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
			return
		}

		session, ok := a.GetSession(cookie.Value)
		if !ok {
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusUnauthorized)
			_ = json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
			return
		}

		// NOTE(review): plain string keys can collide with other packages'
		// context values (staticcheck SA1029). They are kept because
		// GetUserID and IsAdmin look values up by these exact strings.
		ctx := context.WithValue(r.Context(), "user_id", session.UserID)
		ctx = context.WithValue(ctx, "user_email", session.Email)
		ctx = context.WithValue(ctx, "user_name", session.Name)
		ctx = context.WithValue(ctx, "is_admin", session.IsAdmin)

		next(w, r.WithContext(ctx))
	}
}
// GetUserID gets the user ID from context func GetUserID(ctx context.Context) (int64, bool) { userID, ok := ctx.Value("user_id").(int64) return userID, ok } // IsAdmin checks if the user in context is an admin func IsAdmin(ctx context.Context) bool { isAdmin, ok := ctx.Value("is_admin").(bool) return ok && isAdmin } // AdminMiddleware creates an admin-only middleware func (a *Auth) AdminMiddleware(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // First check authentication cookie, err := r.Cookie("session_id") if err != nil { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"}) return } session, ok := a.GetSession(cookie.Value) if !ok { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"}) return } // Then check admin status if !session.IsAdmin { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusForbidden) json.NewEncoder(w).Encode(map[string]string{"error": "Forbidden: Admin access required"}) return } // Add user info to request context ctx := context.WithValue(r.Context(), "user_id", session.UserID) ctx = context.WithValue(ctx, "user_email", session.Email) ctx = context.WithValue(ctx, "user_name", session.Name) ctx = context.WithValue(ctx, "is_admin", session.IsAdmin) next(w, r.WithContext(ctx)) } } // IsLocalLoginEnabled returns whether local login is enabled // Local login is enabled when ENABLE_LOCAL_AUTH environment variable is set to "true" func (a *Auth) IsLocalLoginEnabled() bool { return os.Getenv("ENABLE_LOCAL_AUTH") == "true" } // IsGoogleOAuthConfigured returns whether Google OAuth is configured func (a *Auth) IsGoogleOAuthConfigured() bool { return a.googleConfig != nil } // IsDiscordOAuthConfigured returns whether Discord OAuth is configured func (a *Auth) 
IsDiscordOAuthConfigured() bool { return a.discordConfig != nil } // LocalLogin handles local username/password authentication func (a *Auth) LocalLogin(username, password string) (*Session, error) { // Find user by email (local users use email as username) email := username var userID int64 var dbEmail, dbName, passwordHash string var isAdmin bool err := a.db.QueryRow( "SELECT id, email, name, password_hash, is_admin FROM users WHERE email = ? AND oauth_provider = 'local'", email, ).Scan(&userID, &dbEmail, &dbName, &passwordHash, &isAdmin) if err == sql.ErrNoRows { return nil, fmt.Errorf("invalid credentials") } if err != nil { return nil, fmt.Errorf("failed to query user: %w", err) } // Verify password if passwordHash == "" { return nil, fmt.Errorf("invalid credentials") } err = bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(password)) if err != nil { return nil, fmt.Errorf("invalid credentials") } // Create session session := &Session{ UserID: userID, Email: dbEmail, Name: dbName, IsAdmin: isAdmin, ExpiresAt: time.Now().Add(24 * time.Hour), } return session, nil } // RegisterLocalUser creates a new local user account func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error) { // Check if registration is enabled registrationEnabled, err := a.IsRegistrationEnabled() if err != nil { return nil, fmt.Errorf("failed to check registration setting: %w", err) } if !registrationEnabled { return nil, fmt.Errorf("registration is disabled") } // Check if user already exists var exists bool err = a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ?)", email).Scan(&exists) if err != nil { return nil, fmt.Errorf("failed to check if user exists: %w", err) } if exists { return nil, fmt.Errorf("user with this email already exists") } // Hash password hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { return nil, fmt.Errorf("failed to hash password: %w", err) } // Check if this is 
the first user (make them admin) var userCount int a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount) isAdmin := userCount == 0 // Create user var userID int64 err = a.db.QueryRow( "INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?) RETURNING id", email, name, email, string(hashedPassword), isAdmin, ).Scan(&userID) if err != nil { return nil, fmt.Errorf("failed to create user: %w", err) } // Create session session := &Session{ UserID: userID, Email: email, Name: name, IsAdmin: isAdmin, ExpiresAt: time.Now().Add(24 * time.Hour), } return session, nil } // ChangePassword allows a user to change their own password func (a *Auth) ChangePassword(userID int64, oldPassword, newPassword string) error { // Get current password hash var passwordHash string err := a.db.QueryRow("SELECT password_hash FROM users WHERE id = ? AND oauth_provider = 'local'", userID).Scan(&passwordHash) if err == sql.ErrNoRows { return fmt.Errorf("user not found or not a local user") } if err != nil { return fmt.Errorf("failed to query user: %w", err) } // Verify old password if passwordHash == "" { return fmt.Errorf("user has no password set") } err = bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(oldPassword)) if err != nil { return fmt.Errorf("incorrect old password") } // Hash new password hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost) if err != nil { return fmt.Errorf("failed to hash password: %w", err) } // Update password _, err = a.db.Exec("UPDATE users SET password_hash = ? 
WHERE id = ?", string(hashedPassword), userID) if err != nil { return fmt.Errorf("failed to update password: %w", err) } return nil } // AdminChangePassword allows an admin to change any user's password without knowing the old password func (a *Auth) AdminChangePassword(targetUserID int64, newPassword string) error { // Verify user exists and is a local user var exists bool err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ? AND oauth_provider = 'local')", targetUserID).Scan(&exists) if err != nil { return fmt.Errorf("failed to check if user exists: %w", err) } if !exists { return fmt.Errorf("user not found or not a local user") } // Hash new password hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost) if err != nil { return fmt.Errorf("failed to hash password: %w", err) } // Update password _, err = a.db.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), targetUserID) if err != nil { return fmt.Errorf("failed to update password: %w", err) } return nil } // GetFirstUserID returns the ID of the first user (user with the lowest ID) func (a *Auth) GetFirstUserID() (int64, error) { var firstUserID int64 err := a.db.QueryRow("SELECT id FROM users ORDER BY id ASC LIMIT 1").Scan(&firstUserID) if err == sql.ErrNoRows { return 0, fmt.Errorf("no users found") } if err != nil { return 0, fmt.Errorf("failed to get first user ID: %w", err) } return firstUserID, nil } // SetUserAdminStatus allows an admin to change a user's admin status func (a *Auth) SetUserAdminStatus(targetUserID int64, isAdmin bool) error { // Verify user exists var exists bool err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", targetUserID).Scan(&exists) if err != nil { return fmt.Errorf("failed to check if user exists: %w", err) } if !exists { return fmt.Errorf("user not found") } // Prevent removing admin status from the first user firstUserID, err := a.GetFirstUserID() if err != nil { return 
fmt.Errorf("failed to check first user: %w", err) } if targetUserID == firstUserID && !isAdmin { return fmt.Errorf("cannot remove admin status from the first user") } // Update admin status _, err = a.db.Exec("UPDATE users SET is_admin = ? WHERE id = ?", isAdmin, targetUserID) if err != nil { return fmt.Errorf("failed to update admin status: %w", err) } return nil }