303 lines
8.2 KiB
Go
303 lines
8.2 KiB
Go
package auth
|
|
|
|
import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"net/http"
	"os"
	"sync"
	"time"

	"github.com/google/uuid"
	"golang.org/x/oauth2"
	"golang.org/x/oauth2/google"
)
|
|
|
|
// Auth handles authentication
|
|
type Auth struct {
|
|
db *sql.DB
|
|
googleConfig *oauth2.Config
|
|
discordConfig *oauth2.Config
|
|
sessionStore map[string]*Session
|
|
}
|
|
|
|
// Session is the authenticated state of a logged-in user. It is kept in the
// session store and its fields are copied into request contexts by the
// authentication middlewares.
type Session struct {
	UserID      int64     // users.id of the account
	Email, Name string    // profile details reported by the OAuth provider
	IsAdmin     bool      // grants access to admin-only routes
	ExpiresAt   time.Time // the session is rejected after this instant
}
|
|
|
|
// NewAuth creates a new auth instance
|
|
func NewAuth(db *sql.DB) (*Auth, error) {
|
|
auth := &Auth{
|
|
db: db,
|
|
sessionStore: make(map[string]*Session),
|
|
}
|
|
|
|
// Initialize Google OAuth
|
|
googleClientID := os.Getenv("GOOGLE_CLIENT_ID")
|
|
googleClientSecret := os.Getenv("GOOGLE_CLIENT_SECRET")
|
|
if googleClientID != "" && googleClientSecret != "" {
|
|
auth.googleConfig = &oauth2.Config{
|
|
ClientID: googleClientID,
|
|
ClientSecret: googleClientSecret,
|
|
RedirectURL: os.Getenv("GOOGLE_REDIRECT_URL"),
|
|
Scopes: []string{"openid", "profile", "email"},
|
|
Endpoint: google.Endpoint,
|
|
}
|
|
}
|
|
|
|
// Initialize Discord OAuth
|
|
discordClientID := os.Getenv("DISCORD_CLIENT_ID")
|
|
discordClientSecret := os.Getenv("DISCORD_CLIENT_SECRET")
|
|
if discordClientID != "" && discordClientSecret != "" {
|
|
auth.discordConfig = &oauth2.Config{
|
|
ClientID: discordClientID,
|
|
ClientSecret: discordClientSecret,
|
|
RedirectURL: os.Getenv("DISCORD_REDIRECT_URL"),
|
|
Scopes: []string{"identify", "email"},
|
|
Endpoint: oauth2.Endpoint{
|
|
AuthURL: "https://discord.com/api/oauth2/authorize",
|
|
TokenURL: "https://discord.com/api/oauth2/token",
|
|
},
|
|
}
|
|
}
|
|
|
|
return auth, nil
|
|
}
|
|
|
|
// GoogleLoginURL returns the Google OAuth login URL
|
|
func (a *Auth) GoogleLoginURL() (string, error) {
|
|
if a.googleConfig == nil {
|
|
return "", fmt.Errorf("Google OAuth not configured")
|
|
}
|
|
state := uuid.New().String()
|
|
return a.googleConfig.AuthCodeURL(state), nil
|
|
}
|
|
|
|
// DiscordLoginURL returns the Discord OAuth login URL
|
|
func (a *Auth) DiscordLoginURL() (string, error) {
|
|
if a.discordConfig == nil {
|
|
return "", fmt.Errorf("Discord OAuth not configured")
|
|
}
|
|
state := uuid.New().String()
|
|
return a.discordConfig.AuthCodeURL(state), nil
|
|
}
|
|
|
|
// GoogleCallback handles Google OAuth callback
|
|
func (a *Auth) GoogleCallback(ctx context.Context, code string) (*Session, error) {
|
|
if a.googleConfig == nil {
|
|
return nil, fmt.Errorf("Google OAuth not configured")
|
|
}
|
|
|
|
token, err := a.googleConfig.Exchange(ctx, code)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to exchange token: %w", err)
|
|
}
|
|
|
|
client := a.googleConfig.Client(ctx, token)
|
|
resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get user info: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
var userInfo struct {
|
|
ID string `json:"id"`
|
|
Email string `json:"email"`
|
|
Name string `json:"name"`
|
|
}
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
|
|
return nil, fmt.Errorf("failed to decode user info: %w", err)
|
|
}
|
|
|
|
return a.getOrCreateUser("google", userInfo.ID, userInfo.Email, userInfo.Name)
|
|
}
|
|
|
|
// DiscordCallback handles Discord OAuth callback
|
|
func (a *Auth) DiscordCallback(ctx context.Context, code string) (*Session, error) {
|
|
if a.discordConfig == nil {
|
|
return nil, fmt.Errorf("Discord OAuth not configured")
|
|
}
|
|
|
|
token, err := a.discordConfig.Exchange(ctx, code)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to exchange token: %w", err)
|
|
}
|
|
|
|
client := a.discordConfig.Client(ctx, token)
|
|
resp, err := client.Get("https://discord.com/api/users/@me")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get user info: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
var userInfo struct {
|
|
ID string `json:"id"`
|
|
Email string `json:"email"`
|
|
Username string `json:"username"`
|
|
}
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
|
|
return nil, fmt.Errorf("failed to decode user info: %w", err)
|
|
}
|
|
|
|
return a.getOrCreateUser("discord", userInfo.ID, userInfo.Email, userInfo.Username)
|
|
}
|
|
|
|
// getOrCreateUser gets or creates a user in the database
|
|
func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session, error) {
|
|
var userID int64
|
|
var dbEmail, dbName string
|
|
var isAdmin bool
|
|
|
|
err := a.db.QueryRow(
|
|
"SELECT id, email, name, is_admin FROM users WHERE oauth_provider = ? AND oauth_id = ?",
|
|
provider, oauthID,
|
|
).Scan(&userID, &dbEmail, &dbName, &isAdmin)
|
|
|
|
if err == sql.ErrNoRows {
|
|
// Check if this is the first user
|
|
var userCount int
|
|
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
|
isAdmin = userCount == 0
|
|
|
|
// Create new user
|
|
result, err := a.db.Exec(
|
|
"INSERT INTO users (email, name, oauth_provider, oauth_id, is_admin) VALUES (?, ?, ?, ?, ?)",
|
|
email, name, provider, oauthID, isAdmin,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create user: %w", err)
|
|
}
|
|
userID, _ = result.LastInsertId()
|
|
} else if err != nil {
|
|
return nil, fmt.Errorf("failed to query user: %w", err)
|
|
} else {
|
|
// Update user info if changed
|
|
if dbEmail != email || dbName != name {
|
|
_, err = a.db.Exec(
|
|
"UPDATE users SET email = ?, name = ? WHERE id = ?",
|
|
email, name, userID,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to update user: %w", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
session := &Session{
|
|
UserID: userID,
|
|
Email: email,
|
|
Name: name,
|
|
IsAdmin: isAdmin,
|
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
|
}
|
|
|
|
return session, nil
|
|
}
|
|
|
|
// CreateSession creates a new session and returns a session ID
|
|
func (a *Auth) CreateSession(session *Session) string {
|
|
sessionID := uuid.New().String()
|
|
a.sessionStore[sessionID] = session
|
|
return sessionID
|
|
}
|
|
|
|
// GetSession retrieves a session by ID
|
|
func (a *Auth) GetSession(sessionID string) (*Session, bool) {
|
|
session, ok := a.sessionStore[sessionID]
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
if time.Now().After(session.ExpiresAt) {
|
|
delete(a.sessionStore, sessionID)
|
|
return nil, false
|
|
}
|
|
// Refresh admin status from database
|
|
var isAdmin bool
|
|
err := a.db.QueryRow("SELECT is_admin FROM users WHERE id = ?", session.UserID).Scan(&isAdmin)
|
|
if err == nil {
|
|
session.IsAdmin = isAdmin
|
|
}
|
|
return session, true
|
|
}
|
|
|
|
// DeleteSession deletes a session
|
|
func (a *Auth) DeleteSession(sessionID string) {
|
|
delete(a.sessionStore, sessionID)
|
|
}
|
|
|
|
// Middleware creates an authentication middleware
|
|
func (a *Auth) Middleware(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
cookie, err := r.Cookie("session_id")
|
|
if err != nil {
|
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
session, ok := a.GetSession(cookie.Value)
|
|
if !ok {
|
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Add user info to request context
|
|
ctx := context.WithValue(r.Context(), "user_id", session.UserID)
|
|
ctx = context.WithValue(ctx, "user_email", session.Email)
|
|
ctx = context.WithValue(ctx, "user_name", session.Name)
|
|
ctx = context.WithValue(ctx, "is_admin", session.IsAdmin)
|
|
next(w, r.WithContext(ctx))
|
|
}
|
|
}
|
|
|
|
// GetUserID gets the user ID from context
|
|
func GetUserID(ctx context.Context) (int64, bool) {
|
|
userID, ok := ctx.Value("user_id").(int64)
|
|
return userID, ok
|
|
}
|
|
|
|
// IsAdmin checks if the user in context is an admin
|
|
func IsAdmin(ctx context.Context) bool {
|
|
isAdmin, ok := ctx.Value("is_admin").(bool)
|
|
return ok && isAdmin
|
|
}
|
|
|
|
// AdminMiddleware creates an admin-only middleware
|
|
func (a *Auth) AdminMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
// First check authentication
|
|
cookie, err := r.Cookie("session_id")
|
|
if err != nil {
|
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
session, ok := a.GetSession(cookie.Value)
|
|
if !ok {
|
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Then check admin status
|
|
if !session.IsAdmin {
|
|
http.Error(w, "Forbidden: Admin access required", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
// Add user info to request context
|
|
ctx := context.WithValue(r.Context(), "user_id", session.UserID)
|
|
ctx = context.WithValue(ctx, "user_email", session.Email)
|
|
ctx = context.WithValue(ctx, "user_name", session.Name)
|
|
ctx = context.WithValue(ctx, "is_admin", session.IsAdmin)
|
|
next(w, r.WithContext(ctx))
|
|
}
|
|
}
|
|
|