massive changes and it works
This commit is contained in:
@@ -5,21 +5,24 @@ import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
// Auth handles authentication
|
||||
type Auth struct {
|
||||
db *sql.DB
|
||||
googleConfig *oauth2.Config
|
||||
discordConfig *oauth2.Config
|
||||
sessionStore map[string]*Session
|
||||
db *sql.DB
|
||||
googleConfig *oauth2.Config
|
||||
discordConfig *oauth2.Config
|
||||
sessionStore map[string]*Session
|
||||
}
|
||||
|
||||
// Session represents a user session
|
||||
@@ -67,9 +70,95 @@ func NewAuth(db *sql.DB) (*Auth, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize admin settings on startup to ensure they persist between boots
|
||||
if err := auth.initializeSettings(); err != nil {
|
||||
log.Printf("Warning: Failed to initialize admin settings: %v", err)
|
||||
// Don't fail startup, but log the warning
|
||||
}
|
||||
|
||||
// Initialize test local user from environment variables (for testing only)
|
||||
if err := auth.initializeTestUser(); err != nil {
|
||||
log.Printf("Warning: Failed to initialize test user: %v", err)
|
||||
// Don't fail startup, but log the warning
|
||||
}
|
||||
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// initializeSettings ensures all admin settings are initialized with defaults if they don't exist
|
||||
func (a *Auth) initializeSettings() error {
|
||||
// Initialize registration_enabled setting (default: true) if it doesn't exist
|
||||
var settingCount int
|
||||
err := a.db.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check registration_enabled setting: %w", err)
|
||||
}
|
||||
if settingCount == 0 {
|
||||
_, err = a.db.Exec(
|
||||
`INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)`,
|
||||
"registration_enabled", "true",
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize registration_enabled setting: %w", err)
|
||||
}
|
||||
log.Printf("Initialized admin setting: registration_enabled = true")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// initializeTestUser creates a test local user from environment variables (for testing only)
|
||||
func (a *Auth) initializeTestUser() error {
|
||||
testEmail := os.Getenv("LOCAL_TEST_EMAIL")
|
||||
testPassword := os.Getenv("LOCAL_TEST_PASSWORD")
|
||||
|
||||
if testEmail == "" || testPassword == "" {
|
||||
// No test user configured, skip
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if user already exists
|
||||
var exists bool
|
||||
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND oauth_provider = 'local')", testEmail).Scan(&exists)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if test user exists: %w", err)
|
||||
}
|
||||
|
||||
if exists {
|
||||
// User already exists, skip creation
|
||||
log.Printf("Test user %s already exists, skipping creation", testEmail)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Hash password
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(testPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash test user password: %w", err)
|
||||
}
|
||||
|
||||
// Check if this is the first user (make them admin)
|
||||
var userCount int
|
||||
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
||||
isAdmin := userCount == 0
|
||||
|
||||
// Create test user (use email as name if no name is provided)
|
||||
testName := testEmail
|
||||
if atIndex := strings.Index(testEmail, "@"); atIndex > 0 {
|
||||
testName = testEmail[:atIndex]
|
||||
}
|
||||
|
||||
// Create test user
|
||||
_, err = a.db.Exec(
|
||||
"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?)",
|
||||
testEmail, testName, testEmail, string(hashedPassword), isAdmin,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create test user: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Created test user: %s (admin: %v)", testEmail, isAdmin)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GoogleLoginURL returns the Google OAuth login URL
|
||||
func (a *Auth) GoogleLoginURL() (string, error) {
|
||||
if a.googleConfig == nil {
|
||||
@@ -150,36 +239,119 @@ func (a *Auth) DiscordCallback(ctx context.Context, code string) (*Session, erro
|
||||
return a.getOrCreateUser("discord", userInfo.ID, userInfo.Email, userInfo.Username)
|
||||
}
|
||||
|
||||
// IsRegistrationEnabled checks if new user registration is enabled
|
||||
func (a *Auth) IsRegistrationEnabled() (bool, error) {
|
||||
var value string
|
||||
err := a.db.QueryRow("SELECT value FROM settings WHERE key = ?", "registration_enabled").Scan(&value)
|
||||
if err == sql.ErrNoRows {
|
||||
// Default to enabled if setting doesn't exist
|
||||
return true, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check registration setting: %w", err)
|
||||
}
|
||||
return value == "true", nil
|
||||
}
|
||||
|
||||
// SetRegistrationEnabled sets whether new user registration is enabled
|
||||
func (a *Auth) SetRegistrationEnabled(enabled bool) error {
|
||||
value := "false"
|
||||
if enabled {
|
||||
value = "true"
|
||||
}
|
||||
|
||||
// Check if setting exists
|
||||
var exists bool
|
||||
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM settings WHERE key = ?)", "registration_enabled").Scan(&exists)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if setting exists: %w", err)
|
||||
}
|
||||
|
||||
if exists {
|
||||
// Update existing setting
|
||||
_, err = a.db.Exec(
|
||||
"UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?",
|
||||
value, "registration_enabled",
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update setting: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Insert new setting
|
||||
_, err = a.db.Exec(
|
||||
"INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)",
|
||||
"registration_enabled", value,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert setting: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getOrCreateUser gets or creates a user in the database
|
||||
// Automatically links accounts by email across different OAuth providers and local login
|
||||
func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session, error) {
|
||||
var userID int64
|
||||
var dbEmail, dbName string
|
||||
var isAdmin bool
|
||||
var dbProvider, dbOAuthID string
|
||||
|
||||
// First, try to find by provider + oauth_id
|
||||
err := a.db.QueryRow(
|
||||
"SELECT id, email, name, is_admin FROM users WHERE oauth_provider = ? AND oauth_id = ?",
|
||||
"SELECT id, email, name, is_admin, oauth_provider, oauth_id FROM users WHERE oauth_provider = ? AND oauth_id = ?",
|
||||
provider, oauthID,
|
||||
).Scan(&userID, &dbEmail, &dbName, &isAdmin)
|
||||
).Scan(&userID, &dbEmail, &dbName, &isAdmin, &dbProvider, &dbOAuthID)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
// Check if this is the first user
|
||||
var userCount int
|
||||
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
||||
isAdmin = userCount == 0
|
||||
// Not found by provider+oauth_id, check by email for account linking
|
||||
err = a.db.QueryRow(
|
||||
"SELECT id, email, name, is_admin, oauth_provider, oauth_id FROM users WHERE email = ?",
|
||||
email,
|
||||
).Scan(&userID, &dbEmail, &dbName, &isAdmin, &dbProvider, &dbOAuthID)
|
||||
|
||||
// Create new user
|
||||
result, err := a.db.Exec(
|
||||
"INSERT INTO users (email, name, oauth_provider, oauth_id, is_admin) VALUES (?, ?, ?, ?, ?)",
|
||||
email, name, provider, oauthID, isAdmin,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
if err == sql.ErrNoRows {
|
||||
// User doesn't exist, check if registration is enabled
|
||||
registrationEnabled, err := a.IsRegistrationEnabled()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check registration setting: %w", err)
|
||||
}
|
||||
if !registrationEnabled {
|
||||
return nil, fmt.Errorf("registration is disabled")
|
||||
}
|
||||
|
||||
// Check if this is the first user
|
||||
var userCount int
|
||||
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
||||
isAdmin = userCount == 0
|
||||
|
||||
// Create new user
|
||||
err = a.db.QueryRow(
|
||||
"INSERT INTO users (email, name, oauth_provider, oauth_id, is_admin) VALUES (?, ?, ?, ?, ?) RETURNING id",
|
||||
email, name, provider, oauthID, isAdmin,
|
||||
).Scan(&userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("failed to query user by email: %w", err)
|
||||
} else {
|
||||
// User exists with same email but different provider - link accounts by updating provider info
|
||||
// This allows the user to log in with any provider that has the same email
|
||||
_, err = a.db.Exec(
|
||||
"UPDATE users SET oauth_provider = ?, oauth_id = ?, name = ? WHERE id = ?",
|
||||
provider, oauthID, name, userID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to link account: %w", err)
|
||||
}
|
||||
log.Printf("Linked account: user %d (email: %s) now accessible via %s provider", userID, email, provider)
|
||||
}
|
||||
userID, _ = result.LastInsertId()
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("failed to query user: %w", err)
|
||||
} else {
|
||||
// Update user info if changed
|
||||
// User found by provider+oauth_id, update info if changed
|
||||
if dbEmail != email || dbName != name {
|
||||
_, err = a.db.Exec(
|
||||
"UPDATE users SET email = ?, name = ? WHERE id = ?",
|
||||
@@ -238,13 +410,17 @@ func (a *Auth) Middleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("session_id")
|
||||
if err != nil {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
session, ok := a.GetSession(cookie.Value)
|
||||
if !ok {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -275,19 +451,25 @@ func (a *Auth) AdminMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
// First check authentication
|
||||
cookie, err := r.Cookie("session_id")
|
||||
if err != nil {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
session, ok := a.GetSession(cookie.Value)
|
||||
if !ok {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
// Then check admin status
|
||||
if !session.IsAdmin {
|
||||
http.Error(w, "Forbidden: Admin access required", http.StatusForbidden)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "Forbidden: Admin access required"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -300,3 +482,221 @@ func (a *Auth) AdminMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// IsLocalLoginEnabled returns whether local login is enabled
|
||||
// Local login is enabled when ENABLE_LOCAL_AUTH environment variable is set to "true"
|
||||
func (a *Auth) IsLocalLoginEnabled() bool {
|
||||
return os.Getenv("ENABLE_LOCAL_AUTH") == "true"
|
||||
}
|
||||
|
||||
// IsGoogleOAuthConfigured returns whether Google OAuth is configured
|
||||
func (a *Auth) IsGoogleOAuthConfigured() bool {
|
||||
return a.googleConfig != nil
|
||||
}
|
||||
|
||||
// IsDiscordOAuthConfigured returns whether Discord OAuth is configured
|
||||
func (a *Auth) IsDiscordOAuthConfigured() bool {
|
||||
return a.discordConfig != nil
|
||||
}
|
||||
|
||||
// LocalLogin handles local username/password authentication
|
||||
func (a *Auth) LocalLogin(username, password string) (*Session, error) {
|
||||
// Find user by email (local users use email as username)
|
||||
email := username
|
||||
var userID int64
|
||||
var dbEmail, dbName, passwordHash string
|
||||
var isAdmin bool
|
||||
|
||||
err := a.db.QueryRow(
|
||||
"SELECT id, email, name, password_hash, is_admin FROM users WHERE email = ? AND oauth_provider = 'local'",
|
||||
email,
|
||||
).Scan(&userID, &dbEmail, &dbName, &passwordHash, &isAdmin)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("invalid credentials")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query user: %w", err)
|
||||
}
|
||||
|
||||
// Verify password
|
||||
if passwordHash == "" {
|
||||
return nil, fmt.Errorf("invalid credentials")
|
||||
}
|
||||
|
||||
err = bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(password))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid credentials")
|
||||
}
|
||||
|
||||
// Create session
|
||||
session := &Session{
|
||||
UserID: userID,
|
||||
Email: dbEmail,
|
||||
Name: dbName,
|
||||
IsAdmin: isAdmin,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// RegisterLocalUser creates a new local user account
|
||||
func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error) {
|
||||
// Check if registration is enabled
|
||||
registrationEnabled, err := a.IsRegistrationEnabled()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check registration setting: %w", err)
|
||||
}
|
||||
if !registrationEnabled {
|
||||
return nil, fmt.Errorf("registration is disabled")
|
||||
}
|
||||
|
||||
// Check if user already exists
|
||||
var exists bool
|
||||
err = a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ?)", email).Scan(&exists)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if user exists: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return nil, fmt.Errorf("user with this email already exists")
|
||||
}
|
||||
|
||||
// Hash password
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
// Check if this is the first user (make them admin)
|
||||
var userCount int
|
||||
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
||||
isAdmin := userCount == 0
|
||||
|
||||
// Create user
|
||||
var userID int64
|
||||
err = a.db.QueryRow(
|
||||
"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?) RETURNING id",
|
||||
email, name, email, string(hashedPassword), isAdmin,
|
||||
).Scan(&userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
// Create session
|
||||
session := &Session{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
Name: name,
|
||||
IsAdmin: isAdmin,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// ChangePassword allows a user to change their own password
|
||||
func (a *Auth) ChangePassword(userID int64, oldPassword, newPassword string) error {
|
||||
// Get current password hash
|
||||
var passwordHash string
|
||||
err := a.db.QueryRow("SELECT password_hash FROM users WHERE id = ? AND oauth_provider = 'local'", userID).Scan(&passwordHash)
|
||||
if err == sql.ErrNoRows {
|
||||
return fmt.Errorf("user not found or not a local user")
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query user: %w", err)
|
||||
}
|
||||
|
||||
// Verify old password
|
||||
if passwordHash == "" {
|
||||
return fmt.Errorf("user has no password set")
|
||||
}
|
||||
|
||||
err = bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(oldPassword))
|
||||
if err != nil {
|
||||
return fmt.Errorf("incorrect old password")
|
||||
}
|
||||
|
||||
// Hash new password
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
// Update password
|
||||
_, err = a.db.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update password: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AdminChangePassword allows an admin to change any user's password without knowing the old password
|
||||
func (a *Auth) AdminChangePassword(targetUserID int64, newPassword string) error {
|
||||
// Verify user exists and is a local user
|
||||
var exists bool
|
||||
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ? AND oauth_provider = 'local')", targetUserID).Scan(&exists)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if user exists: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return fmt.Errorf("user not found or not a local user")
|
||||
}
|
||||
|
||||
// Hash new password
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
// Update password
|
||||
_, err = a.db.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), targetUserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update password: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFirstUserID returns the ID of the first user (user with the lowest ID)
|
||||
func (a *Auth) GetFirstUserID() (int64, error) {
|
||||
var firstUserID int64
|
||||
err := a.db.QueryRow("SELECT id FROM users ORDER BY id ASC LIMIT 1").Scan(&firstUserID)
|
||||
if err == sql.ErrNoRows {
|
||||
return 0, fmt.Errorf("no users found")
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get first user ID: %w", err)
|
||||
}
|
||||
return firstUserID, nil
|
||||
}
|
||||
|
||||
// SetUserAdminStatus allows an admin to change a user's admin status
|
||||
func (a *Auth) SetUserAdminStatus(targetUserID int64, isAdmin bool) error {
|
||||
// Verify user exists
|
||||
var exists bool
|
||||
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", targetUserID).Scan(&exists)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if user exists: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return fmt.Errorf("user not found")
|
||||
}
|
||||
|
||||
// Prevent removing admin status from the first user
|
||||
firstUserID, err := a.GetFirstUserID()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check first user: %w", err)
|
||||
}
|
||||
if targetUserID == firstUserID && !isAdmin {
|
||||
return fmt.Errorf("cannot remove admin status from the first user")
|
||||
}
|
||||
|
||||
// Update admin status
|
||||
_, err = a.db.Exec("UPDATE users SET is_admin = ? WHERE id = ?", isAdmin, targetUserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update admin status: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user