initial commit
This commit is contained in:
302
internal/auth/auth.go
Normal file
302
internal/auth/auth.go
Normal file
@@ -0,0 +1,302 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
// Auth handles authentication
|
||||
type Auth struct {
|
||||
db *sql.DB
|
||||
googleConfig *oauth2.Config
|
||||
discordConfig *oauth2.Config
|
||||
sessionStore map[string]*Session
|
||||
}
|
||||
|
||||
// Session represents a user session
|
||||
type Session struct {
|
||||
UserID int64
|
||||
Email string
|
||||
Name string
|
||||
IsAdmin bool
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// NewAuth creates a new auth instance
|
||||
func NewAuth(db *sql.DB) (*Auth, error) {
|
||||
auth := &Auth{
|
||||
db: db,
|
||||
sessionStore: make(map[string]*Session),
|
||||
}
|
||||
|
||||
// Initialize Google OAuth
|
||||
googleClientID := os.Getenv("GOOGLE_CLIENT_ID")
|
||||
googleClientSecret := os.Getenv("GOOGLE_CLIENT_SECRET")
|
||||
if googleClientID != "" && googleClientSecret != "" {
|
||||
auth.googleConfig = &oauth2.Config{
|
||||
ClientID: googleClientID,
|
||||
ClientSecret: googleClientSecret,
|
||||
RedirectURL: os.Getenv("GOOGLE_REDIRECT_URL"),
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
Endpoint: google.Endpoint,
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize Discord OAuth
|
||||
discordClientID := os.Getenv("DISCORD_CLIENT_ID")
|
||||
discordClientSecret := os.Getenv("DISCORD_CLIENT_SECRET")
|
||||
if discordClientID != "" && discordClientSecret != "" {
|
||||
auth.discordConfig = &oauth2.Config{
|
||||
ClientID: discordClientID,
|
||||
ClientSecret: discordClientSecret,
|
||||
RedirectURL: os.Getenv("DISCORD_REDIRECT_URL"),
|
||||
Scopes: []string{"identify", "email"},
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: "https://discord.com/api/oauth2/authorize",
|
||||
TokenURL: "https://discord.com/api/oauth2/token",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// GoogleLoginURL returns the Google OAuth login URL
|
||||
func (a *Auth) GoogleLoginURL() (string, error) {
|
||||
if a.googleConfig == nil {
|
||||
return "", fmt.Errorf("Google OAuth not configured")
|
||||
}
|
||||
state := uuid.New().String()
|
||||
return a.googleConfig.AuthCodeURL(state), nil
|
||||
}
|
||||
|
||||
// DiscordLoginURL returns the Discord OAuth login URL
|
||||
func (a *Auth) DiscordLoginURL() (string, error) {
|
||||
if a.discordConfig == nil {
|
||||
return "", fmt.Errorf("Discord OAuth not configured")
|
||||
}
|
||||
state := uuid.New().String()
|
||||
return a.discordConfig.AuthCodeURL(state), nil
|
||||
}
|
||||
|
||||
// GoogleCallback handles Google OAuth callback
|
||||
func (a *Auth) GoogleCallback(ctx context.Context, code string) (*Session, error) {
|
||||
if a.googleConfig == nil {
|
||||
return nil, fmt.Errorf("Google OAuth not configured")
|
||||
}
|
||||
|
||||
token, err := a.googleConfig.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange token: %w", err)
|
||||
}
|
||||
|
||||
client := a.googleConfig.Client(ctx, token)
|
||||
resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user info: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var userInfo struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode user info: %w", err)
|
||||
}
|
||||
|
||||
return a.getOrCreateUser("google", userInfo.ID, userInfo.Email, userInfo.Name)
|
||||
}
|
||||
|
||||
// DiscordCallback handles Discord OAuth callback
|
||||
func (a *Auth) DiscordCallback(ctx context.Context, code string) (*Session, error) {
|
||||
if a.discordConfig == nil {
|
||||
return nil, fmt.Errorf("Discord OAuth not configured")
|
||||
}
|
||||
|
||||
token, err := a.discordConfig.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange token: %w", err)
|
||||
}
|
||||
|
||||
client := a.discordConfig.Client(ctx, token)
|
||||
resp, err := client.Get("https://discord.com/api/users/@me")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user info: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var userInfo struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode user info: %w", err)
|
||||
}
|
||||
|
||||
return a.getOrCreateUser("discord", userInfo.ID, userInfo.Email, userInfo.Username)
|
||||
}
|
||||
|
||||
// getOrCreateUser gets or creates a user in the database
|
||||
func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session, error) {
|
||||
var userID int64
|
||||
var dbEmail, dbName string
|
||||
var isAdmin bool
|
||||
|
||||
err := a.db.QueryRow(
|
||||
"SELECT id, email, name, is_admin FROM users WHERE oauth_provider = ? AND oauth_id = ?",
|
||||
provider, oauthID,
|
||||
).Scan(&userID, &dbEmail, &dbName, &isAdmin)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
// Check if this is the first user
|
||||
var userCount int
|
||||
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
||||
isAdmin = userCount == 0
|
||||
|
||||
// Create new user
|
||||
result, err := a.db.Exec(
|
||||
"INSERT INTO users (email, name, oauth_provider, oauth_id, is_admin) VALUES (?, ?, ?, ?, ?)",
|
||||
email, name, provider, oauthID, isAdmin,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
userID, _ = result.LastInsertId()
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("failed to query user: %w", err)
|
||||
} else {
|
||||
// Update user info if changed
|
||||
if dbEmail != email || dbName != name {
|
||||
_, err = a.db.Exec(
|
||||
"UPDATE users SET email = ?, name = ? WHERE id = ?",
|
||||
email, name, userID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update user: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
session := &Session{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
Name: name,
|
||||
IsAdmin: isAdmin,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// CreateSession creates a new session and returns a session ID
|
||||
func (a *Auth) CreateSession(session *Session) string {
|
||||
sessionID := uuid.New().String()
|
||||
a.sessionStore[sessionID] = session
|
||||
return sessionID
|
||||
}
|
||||
|
||||
// GetSession retrieves a session by ID
|
||||
func (a *Auth) GetSession(sessionID string) (*Session, bool) {
|
||||
session, ok := a.sessionStore[sessionID]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
delete(a.sessionStore, sessionID)
|
||||
return nil, false
|
||||
}
|
||||
// Refresh admin status from database
|
||||
var isAdmin bool
|
||||
err := a.db.QueryRow("SELECT is_admin FROM users WHERE id = ?", session.UserID).Scan(&isAdmin)
|
||||
if err == nil {
|
||||
session.IsAdmin = isAdmin
|
||||
}
|
||||
return session, true
|
||||
}
|
||||
|
||||
// DeleteSession deletes a session
|
||||
func (a *Auth) DeleteSession(sessionID string) {
|
||||
delete(a.sessionStore, sessionID)
|
||||
}
|
||||
|
||||
// Middleware creates an authentication middleware
|
||||
func (a *Auth) Middleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("session_id")
|
||||
if err != nil {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
session, ok := a.GetSession(cookie.Value)
|
||||
if !ok {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Add user info to request context
|
||||
ctx := context.WithValue(r.Context(), "user_id", session.UserID)
|
||||
ctx = context.WithValue(ctx, "user_email", session.Email)
|
||||
ctx = context.WithValue(ctx, "user_name", session.Name)
|
||||
ctx = context.WithValue(ctx, "is_admin", session.IsAdmin)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserID gets the user ID from context
|
||||
func GetUserID(ctx context.Context) (int64, bool) {
|
||||
userID, ok := ctx.Value("user_id").(int64)
|
||||
return userID, ok
|
||||
}
|
||||
|
||||
// IsAdmin checks if the user in context is an admin
|
||||
func IsAdmin(ctx context.Context) bool {
|
||||
isAdmin, ok := ctx.Value("is_admin").(bool)
|
||||
return ok && isAdmin
|
||||
}
|
||||
|
||||
// AdminMiddleware creates an admin-only middleware
|
||||
func (a *Auth) AdminMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// First check authentication
|
||||
cookie, err := r.Cookie("session_id")
|
||||
if err != nil {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
session, ok := a.GetSession(cookie.Value)
|
||||
if !ok {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Then check admin status
|
||||
if !session.IsAdmin {
|
||||
http.Error(w, "Forbidden: Admin access required", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Add user info to request context
|
||||
ctx := context.WithValue(r.Context(), "user_id", session.UserID)
|
||||
ctx = context.WithValue(ctx, "user_email", session.Email)
|
||||
ctx = context.WithValue(ctx, "user_name", session.Name)
|
||||
ctx = context.WithValue(ctx, "is_admin", session.IsAdmin)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
244
internal/auth/secrets.go
Normal file
244
internal/auth/secrets.go
Normal file
@@ -0,0 +1,244 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Secrets handles secret and token management
|
||||
type Secrets struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewSecrets creates a new secrets manager
|
||||
func NewSecrets(db *sql.DB) (*Secrets, error) {
|
||||
s := &Secrets{db: db}
|
||||
|
||||
// Ensure manager secret exists
|
||||
if err := s.ensureManagerSecret(); err != nil {
|
||||
return nil, fmt.Errorf("failed to ensure manager secret: %w", err)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// ensureManagerSecret ensures a manager secret exists in the database
|
||||
func (s *Secrets) ensureManagerSecret() error {
|
||||
var count int
|
||||
err := s.db.QueryRow("SELECT COUNT(*) FROM manager_secrets").Scan(&count)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check manager secrets: %w", err)
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
// Generate new manager secret
|
||||
secret, err := generateSecret(32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate manager secret: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.db.Exec("INSERT INTO manager_secrets (secret) VALUES (?)", secret)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store manager secret: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetManagerSecret retrieves the current manager secret
|
||||
func (s *Secrets) GetManagerSecret() (string, error) {
|
||||
var secret string
|
||||
err := s.db.QueryRow("SELECT secret FROM manager_secrets ORDER BY created_at DESC LIMIT 1").Scan(&secret)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get manager secret: %w", err)
|
||||
}
|
||||
return secret, nil
|
||||
}
|
||||
|
||||
// GenerateRegistrationToken generates a new registration token
|
||||
func (s *Secrets) GenerateRegistrationToken(createdBy int64, expiresIn time.Duration) (string, error) {
|
||||
token, err := generateSecret(32)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate token: %w", err)
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(expiresIn)
|
||||
|
||||
_, err = s.db.Exec(
|
||||
"INSERT INTO registration_tokens (token, expires_at, created_by) VALUES (?, ?, ?)",
|
||||
token, expiresAt, createdBy,
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to store registration token: %w", err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// ValidateRegistrationToken validates a registration token
|
||||
func (s *Secrets) ValidateRegistrationToken(token string) (bool, error) {
|
||||
var used bool
|
||||
var expiresAt time.Time
|
||||
var id int64
|
||||
|
||||
err := s.db.QueryRow(
|
||||
"SELECT id, expires_at, used FROM registration_tokens WHERE token = ?",
|
||||
token,
|
||||
).Scan(&id, &expiresAt, &used)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to query token: %w", err)
|
||||
}
|
||||
|
||||
if used {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if time.Now().After(expiresAt) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Mark token as used
|
||||
_, err = s.db.Exec("UPDATE registration_tokens SET used = 1 WHERE id = ?", id)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to mark token as used: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// ListRegistrationTokens lists all registration tokens
|
||||
func (s *Secrets) ListRegistrationTokens() ([]map[string]interface{}, error) {
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, token, expires_at, used, created_at, created_by
|
||||
FROM registration_tokens
|
||||
ORDER BY created_at DESC`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query tokens: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tokens []map[string]interface{}
|
||||
for rows.Next() {
|
||||
var id, createdBy sql.NullInt64
|
||||
var token string
|
||||
var expiresAt, createdAt time.Time
|
||||
var used bool
|
||||
|
||||
err := rows.Scan(&id, &token, &expiresAt, &used, &createdAt, &createdBy)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
tokens = append(tokens, map[string]interface{}{
|
||||
"id": id.Int64,
|
||||
"token": token,
|
||||
"expires_at": expiresAt,
|
||||
"used": used,
|
||||
"created_at": createdAt,
|
||||
"created_by": createdBy.Int64,
|
||||
})
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// RevokeRegistrationToken revokes a registration token
|
||||
func (s *Secrets) RevokeRegistrationToken(tokenID int64) error {
|
||||
_, err := s.db.Exec("UPDATE registration_tokens SET used = 1 WHERE id = ?", tokenID)
|
||||
return err
|
||||
}
|
||||
|
||||
// GenerateRunnerSecret generates a unique secret for a runner
|
||||
func (s *Secrets) GenerateRunnerSecret() (string, error) {
|
||||
return generateSecret(32)
|
||||
}
|
||||
|
||||
// SignRequest signs a request with the given secret
|
||||
func SignRequest(method, path, body, secret string, timestamp time.Time) string {
|
||||
message := fmt.Sprintf("%s\n%s\n%s\n%d", method, path, body, timestamp.Unix())
|
||||
h := hmac.New(sha256.New, []byte(secret))
|
||||
h.Write([]byte(message))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// VerifyRequest verifies a signed request
|
||||
func VerifyRequest(r *http.Request, secret string, maxAge time.Duration) (bool, error) {
|
||||
signature := r.Header.Get("X-Runner-Signature")
|
||||
if signature == "" {
|
||||
return false, fmt.Errorf("missing signature")
|
||||
}
|
||||
|
||||
timestampStr := r.Header.Get("X-Runner-Timestamp")
|
||||
if timestampStr == "" {
|
||||
return false, fmt.Errorf("missing timestamp")
|
||||
}
|
||||
|
||||
var timestamp time.Time
|
||||
_, err := fmt.Sscanf(timestampStr, "%d", ×tamp)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("invalid timestamp: %w", err)
|
||||
}
|
||||
|
||||
// Check timestamp is not too old
|
||||
if time.Since(timestamp) > maxAge {
|
||||
return false, fmt.Errorf("request too old")
|
||||
}
|
||||
|
||||
// Check timestamp is not in the future (allow 1 minute clock skew)
|
||||
if timestamp.After(time.Now().Add(1 * time.Minute)) {
|
||||
return false, fmt.Errorf("timestamp in future")
|
||||
}
|
||||
|
||||
// Read body
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to read body: %w", err)
|
||||
}
|
||||
// Restore body for handler
|
||||
r.Body = io.NopCloser(strings.NewReader(string(bodyBytes)))
|
||||
|
||||
// Verify signature
|
||||
expectedSig := SignRequest(r.Method, r.URL.Path, string(bodyBytes), secret, timestamp)
|
||||
|
||||
return hmac.Equal([]byte(signature), []byte(expectedSig)), nil
|
||||
}
|
||||
|
||||
// GetRunnerSecret retrieves the runner secret for a runner ID
|
||||
func (s *Secrets) GetRunnerSecret(runnerID int64) (string, error) {
|
||||
var secret string
|
||||
err := s.db.QueryRow("SELECT runner_secret FROM runners WHERE id = ?", runnerID).Scan(&secret)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", fmt.Errorf("runner not found")
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get runner secret: %w", err)
|
||||
}
|
||||
if secret == "" {
|
||||
return "", fmt.Errorf("runner not verified")
|
||||
}
|
||||
return secret, nil
|
||||
}
|
||||
|
||||
// generateSecret generates a random secret of the given length
|
||||
func generateSecret(length int) (string, error) {
|
||||
bytes := make([]byte, length)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user