something
This commit is contained in:
@@ -131,14 +131,19 @@ func (s *Server) handleVerifyRunner(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Check if runner exists
|
||||
var exists bool
|
||||
err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM runners WHERE id = ?)", runnerID).Scan(&exists)
|
||||
err = s.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM runners WHERE id = ?)", runnerID).Scan(&exists)
|
||||
})
|
||||
if err != nil || !exists {
|
||||
s.respondError(w, http.StatusNotFound, "Runner not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Mark runner as verified
|
||||
_, err = s.db.Exec("UPDATE runners SET verified = 1 WHERE id = ?", runnerID)
|
||||
err = s.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec("UPDATE runners SET verified = 1 WHERE id = ?", runnerID)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify runner: %v", err))
|
||||
return
|
||||
@@ -157,14 +162,19 @@ func (s *Server) handleDeleteRunner(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Check if runner exists
|
||||
var exists bool
|
||||
err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM runners WHERE id = ?)", runnerID).Scan(&exists)
|
||||
err = s.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM runners WHERE id = ?)", runnerID).Scan(&exists)
|
||||
})
|
||||
if err != nil || !exists {
|
||||
s.respondError(w, http.StatusNotFound, "Runner not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Delete runner
|
||||
_, err = s.db.Exec("DELETE FROM runners WHERE id = ?", runnerID)
|
||||
err = s.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec("DELETE FROM runners WHERE id = ?", runnerID)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to delete runner: %v", err))
|
||||
return
|
||||
@@ -175,17 +185,31 @@ func (s *Server) handleDeleteRunner(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// handleListRunnersAdmin lists all runners with admin details
|
||||
func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
rows, err := s.db.Query(
|
||||
var rows *sql.Rows
|
||||
err := s.db.With(func(conn *sql.DB) error {
|
||||
var err error
|
||||
rows, err = conn.Query(
|
||||
`SELECT id, name, hostname, status, last_heartbeat, capabilities,
|
||||
api_key_id, api_key_scope, priority, created_at
|
||||
FROM runners ORDER BY created_at DESC`,
|
||||
)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query runners: %v", err))
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// Get the set of currently connected runners via WebSocket
|
||||
// This is the source of truth for online status
|
||||
s.runnerConnsMu.RLock()
|
||||
connectedRunners := make(map[int64]bool)
|
||||
for runnerID := range s.runnerConns {
|
||||
connectedRunners[runnerID] = true
|
||||
}
|
||||
s.runnerConnsMu.RUnlock()
|
||||
|
||||
runners := []map[string]interface{}{}
|
||||
for rows.Next() {
|
||||
var runner types.Runner
|
||||
@@ -202,11 +226,21 @@ func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
// Override status based on actual WebSocket connection state
|
||||
// The WebSocket connection is the source of truth for runner status
|
||||
actualStatus := runner.Status
|
||||
if connectedRunners[runner.ID] {
|
||||
actualStatus = types.RunnerStatusOnline
|
||||
} else if runner.Status == types.RunnerStatusOnline {
|
||||
// Database says online but not connected via WebSocket - mark as offline
|
||||
actualStatus = types.RunnerStatusOffline
|
||||
}
|
||||
|
||||
runners = append(runners, map[string]interface{}{
|
||||
"id": runner.ID,
|
||||
"name": runner.Name,
|
||||
"hostname": runner.Hostname,
|
||||
"status": runner.Status,
|
||||
"status": actualStatus,
|
||||
"last_heartbeat": runner.LastHeartbeat,
|
||||
"capabilities": runner.Capabilities,
|
||||
"api_key_id": apiKeyID.Int64,
|
||||
@@ -228,10 +262,15 @@ func (s *Server) handleListUsers(w http.ResponseWriter, r *http.Request) {
|
||||
firstUserID = 0
|
||||
}
|
||||
|
||||
rows, err := s.db.Query(
|
||||
var rows *sql.Rows
|
||||
err = s.db.With(func(conn *sql.DB) error {
|
||||
var err error
|
||||
rows, err = conn.Query(
|
||||
`SELECT id, email, name, oauth_provider, is_admin, created_at
|
||||
FROM users ORDER BY created_at DESC`,
|
||||
)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query users: %v", err))
|
||||
return
|
||||
@@ -253,7 +292,9 @@ func (s *Server) handleListUsers(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Get job count for this user
|
||||
var jobCount int
|
||||
err = s.db.QueryRow("SELECT COUNT(*) FROM jobs WHERE user_id = ?", userID).Scan(&jobCount)
|
||||
err = s.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT COUNT(*) FROM jobs WHERE user_id = ?", userID).Scan(&jobCount)
|
||||
})
|
||||
if err != nil {
|
||||
jobCount = 0 // Default to 0 if query fails
|
||||
}
|
||||
@@ -283,18 +324,25 @@ func (s *Server) handleGetUserJobs(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Verify user exists
|
||||
var exists bool
|
||||
err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", userID).Scan(&exists)
|
||||
err = s.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", userID).Scan(&exists)
|
||||
})
|
||||
if err != nil || !exists {
|
||||
s.respondError(w, http.StatusNotFound, "User not found")
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := s.db.Query(
|
||||
var rows *sql.Rows
|
||||
err = s.db.With(func(conn *sql.DB) error {
|
||||
var err error
|
||||
rows, err = conn.Query(
|
||||
`SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format,
|
||||
allow_parallel_runners, timeout_seconds, blend_metadata, created_at, started_at, completed_at, error_message
|
||||
FROM jobs WHERE user_id = ? ORDER BY created_at DESC`,
|
||||
userID,
|
||||
)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query jobs: %v", err))
|
||||
return
|
||||
|
||||
1885
internal/api/jobs.go
1885
internal/api/jobs.go
File diff suppressed because it is too large
Load Diff
@@ -2,8 +2,6 @@ package api
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bufio"
|
||||
"bytes"
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
@@ -13,10 +11,10 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"jiggablend/pkg/executils"
|
||||
"jiggablend/pkg/scripts"
|
||||
"jiggablend/pkg/types"
|
||||
)
|
||||
@@ -44,7 +42,9 @@ func (s *Server) handleSubmitMetadata(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Verify job exists
|
||||
var jobUserID int64
|
||||
err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID)
|
||||
err = s.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID)
|
||||
})
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "Job not found")
|
||||
return
|
||||
@@ -57,30 +57,36 @@ func (s *Server) handleSubmitMetadata(w http.ResponseWriter, r *http.Request) {
|
||||
// Find the metadata extraction task for this job
|
||||
// First try to find task assigned to this runner, then fall back to any metadata task for this job
|
||||
var taskID int64
|
||||
err = s.db.QueryRow(
|
||||
err = s.db.With(func(conn *sql.DB) error {
|
||||
err := conn.QueryRow(
|
||||
`SELECT id FROM tasks WHERE job_id = ? AND task_type = ? AND runner_id = ?`,
|
||||
jobID, types.TaskTypeMetadata, runnerID,
|
||||
).Scan(&taskID)
|
||||
if err == sql.ErrNoRows {
|
||||
// Fall back to any metadata task for this job (in case assignment changed)
|
||||
err = s.db.QueryRow(
|
||||
err = conn.QueryRow(
|
||||
`SELECT id FROM tasks WHERE job_id = ? AND task_type = ? ORDER BY created_at DESC LIMIT 1`,
|
||||
jobID, types.TaskTypeMetadata,
|
||||
).Scan(&taskID)
|
||||
if err == sql.ErrNoRows {
|
||||
return fmt.Errorf("metadata extraction task not found")
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Update the task to be assigned to this runner if it wasn't already
|
||||
conn.Exec(
|
||||
`UPDATE tasks SET runner_id = ? WHERE id = ? AND runner_id IS NULL`,
|
||||
runnerID, taskID,
|
||||
)
|
||||
}
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
if err.Error() == "metadata extraction task not found" {
|
||||
s.respondError(w, http.StatusNotFound, "Metadata extraction task not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to find task: %v", err))
|
||||
return
|
||||
}
|
||||
// Update the task to be assigned to this runner if it wasn't already
|
||||
s.db.Exec(
|
||||
`UPDATE tasks SET runner_id = ? WHERE id = ? AND runner_id IS NULL`,
|
||||
runnerID, taskID,
|
||||
)
|
||||
} else if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to find task: %v", err))
|
||||
return
|
||||
}
|
||||
@@ -93,20 +99,27 @@ func (s *Server) handleSubmitMetadata(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Update job with metadata
|
||||
_, err = s.db.Exec(
|
||||
err = s.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec(
|
||||
`UPDATE jobs SET blend_metadata = ? WHERE id = ?`,
|
||||
string(metadataJSON), jobID,
|
||||
)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to update job metadata: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Mark task as completed
|
||||
_, err = s.db.Exec(
|
||||
`UPDATE tasks SET status = ?, completed_at = CURRENT_TIMESTAMP WHERE id = ?`,
|
||||
types.TaskStatusCompleted, taskID,
|
||||
)
|
||||
err = s.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec(`UPDATE tasks SET status = ? WHERE id = ?`, types.TaskStatusCompleted, taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = conn.Exec(`UPDATE tasks SET completed_at = CURRENT_TIMESTAMP WHERE id = ?`, taskID)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Failed to mark metadata task as completed: %v", err)
|
||||
} else {
|
||||
@@ -136,10 +149,12 @@ func (s *Server) handleGetJobMetadata(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify job belongs to user
|
||||
var jobUserID int64
|
||||
var blendMetadataJSON sql.NullString
|
||||
err = s.db.QueryRow(
|
||||
err = s.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow(
|
||||
`SELECT user_id, blend_metadata FROM jobs WHERE id = ?`,
|
||||
jobID,
|
||||
).Scan(&jobUserID, &blendMetadataJSON)
|
||||
})
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "Job not found")
|
||||
return
|
||||
@@ -245,64 +260,23 @@ func (s *Server) extractMetadataFromContext(jobID int64) (*types.BlendMetadata,
|
||||
return nil, fmt.Errorf("failed to get relative path for blend file: %w", err)
|
||||
}
|
||||
|
||||
// Execute Blender with Python script
|
||||
cmd := exec.Command("blender", "-b", blendFileRel, "--python", "extract_metadata.py")
|
||||
cmd.Dir = tmpDir
|
||||
// Execute Blender with Python script using executils
|
||||
result, err := executils.RunCommand(
|
||||
"blender",
|
||||
[]string{"-b", blendFileRel, "--python", "extract_metadata.py"},
|
||||
tmpDir,
|
||||
nil, // inherit environment
|
||||
jobID,
|
||||
nil, // no process tracker needed for metadata extraction
|
||||
)
|
||||
|
||||
// Capture stdout and stderr
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
stderrPipe, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
// Buffer to collect stdout for JSON parsing
|
||||
var stdoutBuffer bytes.Buffer
|
||||
|
||||
// Start the command
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("failed to start blender: %w", err)
|
||||
}
|
||||
|
||||
// Stream stdout and collect for JSON parsing
|
||||
stdoutDone := make(chan bool)
|
||||
go func() {
|
||||
defer close(stdoutDone)
|
||||
scanner := bufio.NewScanner(stdoutPipe)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
stdoutBuffer.WriteString(line)
|
||||
stdoutBuffer.WriteString("\n")
|
||||
stderrOutput := ""
|
||||
stdoutOutput := ""
|
||||
if result != nil {
|
||||
stderrOutput = strings.TrimSpace(result.Stderr)
|
||||
stdoutOutput = strings.TrimSpace(result.Stdout)
|
||||
}
|
||||
}()
|
||||
|
||||
// Capture stderr for error reporting
|
||||
var stderrBuffer bytes.Buffer
|
||||
stderrDone := make(chan bool)
|
||||
go func() {
|
||||
defer close(stderrDone)
|
||||
scanner := bufio.NewScanner(stderrPipe)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
stderrBuffer.WriteString(line)
|
||||
stderrBuffer.WriteString("\n")
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for command to complete
|
||||
err = cmd.Wait()
|
||||
|
||||
// Wait for streaming goroutines to finish
|
||||
<-stdoutDone
|
||||
<-stderrDone
|
||||
|
||||
if err != nil {
|
||||
stderrOutput := strings.TrimSpace(stderrBuffer.String())
|
||||
stdoutOutput := strings.TrimSpace(stdoutBuffer.String())
|
||||
log.Printf("Blender metadata extraction failed for job %d:", jobID)
|
||||
if stderrOutput != "" {
|
||||
log.Printf("Blender stderr: %s", stderrOutput)
|
||||
@@ -317,7 +291,7 @@ func (s *Server) extractMetadataFromContext(jobID int64) (*types.BlendMetadata,
|
||||
}
|
||||
|
||||
// Parse output (metadata is printed to stdout)
|
||||
metadataJSON := strings.TrimSpace(stdoutBuffer.String())
|
||||
metadataJSON := strings.TrimSpace(result.Stdout)
|
||||
// Extract JSON from output (Blender may print other stuff)
|
||||
jsonStart := strings.Index(metadataJSON, "{")
|
||||
jsonEnd := strings.LastIndex(metadataJSON, "}")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -5,10 +5,13 @@ import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"jiggablend/internal/config"
|
||||
"jiggablend/internal/database"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -17,12 +20,31 @@ import (
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
// Context key types to avoid collisions (typed keys are safer than string keys)
|
||||
type contextKey int
|
||||
|
||||
const (
|
||||
contextKeyUserID contextKey = iota
|
||||
contextKeyUserEmail
|
||||
contextKeyUserName
|
||||
contextKeyIsAdmin
|
||||
)
|
||||
|
||||
// Configuration constants
|
||||
const (
|
||||
SessionDuration = 24 * time.Hour
|
||||
SessionCleanupInterval = 1 * time.Hour
|
||||
)
|
||||
|
||||
// Auth handles authentication
|
||||
type Auth struct {
|
||||
db *sql.DB
|
||||
db *database.DB
|
||||
cfg *config.Config
|
||||
googleConfig *oauth2.Config
|
||||
discordConfig *oauth2.Config
|
||||
sessionStore map[string]*Session
|
||||
sessionCache map[string]*Session // In-memory cache for performance
|
||||
cacheMu sync.RWMutex
|
||||
stopCleanup chan struct{}
|
||||
}
|
||||
|
||||
// Session represents a user session
|
||||
@@ -35,41 +57,53 @@ type Session struct {
|
||||
}
|
||||
|
||||
// NewAuth creates a new auth instance
|
||||
func NewAuth(db *sql.DB) (*Auth, error) {
|
||||
func NewAuth(db *database.DB, cfg *config.Config) (*Auth, error) {
|
||||
auth := &Auth{
|
||||
db: db,
|
||||
sessionStore: make(map[string]*Session),
|
||||
cfg: cfg,
|
||||
sessionCache: make(map[string]*Session),
|
||||
stopCleanup: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Initialize Google OAuth
|
||||
googleClientID := os.Getenv("GOOGLE_CLIENT_ID")
|
||||
googleClientSecret := os.Getenv("GOOGLE_CLIENT_SECRET")
|
||||
// Initialize Google OAuth from database config
|
||||
googleClientID := cfg.GoogleClientID()
|
||||
googleClientSecret := cfg.GoogleClientSecret()
|
||||
if googleClientID != "" && googleClientSecret != "" {
|
||||
auth.googleConfig = &oauth2.Config{
|
||||
ClientID: googleClientID,
|
||||
ClientSecret: googleClientSecret,
|
||||
RedirectURL: os.Getenv("GOOGLE_REDIRECT_URL"),
|
||||
RedirectURL: cfg.GoogleRedirectURL(),
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
Endpoint: google.Endpoint,
|
||||
}
|
||||
log.Printf("Google OAuth configured")
|
||||
}
|
||||
|
||||
// Initialize Discord OAuth
|
||||
discordClientID := os.Getenv("DISCORD_CLIENT_ID")
|
||||
discordClientSecret := os.Getenv("DISCORD_CLIENT_SECRET")
|
||||
// Initialize Discord OAuth from database config
|
||||
discordClientID := cfg.DiscordClientID()
|
||||
discordClientSecret := cfg.DiscordClientSecret()
|
||||
if discordClientID != "" && discordClientSecret != "" {
|
||||
auth.discordConfig = &oauth2.Config{
|
||||
ClientID: discordClientID,
|
||||
ClientSecret: discordClientSecret,
|
||||
RedirectURL: os.Getenv("DISCORD_REDIRECT_URL"),
|
||||
RedirectURL: cfg.DiscordRedirectURL(),
|
||||
Scopes: []string{"identify", "email"},
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: "https://discord.com/api/oauth2/authorize",
|
||||
TokenURL: "https://discord.com/api/oauth2/token",
|
||||
},
|
||||
}
|
||||
log.Printf("Discord OAuth configured")
|
||||
}
|
||||
|
||||
// Load existing sessions from database into cache
|
||||
if err := auth.loadSessionsFromDB(); err != nil {
|
||||
log.Printf("Warning: Failed to load sessions from database: %v", err)
|
||||
}
|
||||
|
||||
// Start background cleanup goroutine
|
||||
go auth.cleanupExpiredSessions()
|
||||
|
||||
// Initialize admin settings on startup to ensure they persist between boots
|
||||
if err := auth.initializeSettings(); err != nil {
|
||||
log.Printf("Warning: Failed to initialize admin settings: %v", err)
|
||||
@@ -85,19 +119,119 @@ func NewAuth(db *sql.DB) (*Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// Close stops background goroutines
|
||||
func (a *Auth) Close() {
|
||||
close(a.stopCleanup)
|
||||
}
|
||||
|
||||
// loadSessionsFromDB loads all valid sessions from database into cache
|
||||
func (a *Auth) loadSessionsFromDB() error {
|
||||
var sessions []struct {
|
||||
sessionID string
|
||||
session Session
|
||||
}
|
||||
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
rows, err := conn.Query(
|
||||
`SELECT session_id, user_id, email, name, is_admin, expires_at
|
||||
FROM sessions WHERE expires_at > CURRENT_TIMESTAMP`,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query sessions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var s struct {
|
||||
sessionID string
|
||||
session Session
|
||||
}
|
||||
err := rows.Scan(&s.sessionID, &s.session.UserID, &s.session.Email, &s.session.Name, &s.session.IsAdmin, &s.session.ExpiresAt)
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to scan session row: %v", err)
|
||||
continue
|
||||
}
|
||||
sessions = append(sessions, s)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
a.cacheMu.Lock()
|
||||
defer a.cacheMu.Unlock()
|
||||
|
||||
for _, s := range sessions {
|
||||
a.sessionCache[s.sessionID] = &s.session
|
||||
}
|
||||
|
||||
if len(sessions) > 0 {
|
||||
log.Printf("Loaded %d active sessions from database", len(sessions))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupExpiredSessions periodically removes expired sessions from database and cache
|
||||
func (a *Auth) cleanupExpiredSessions() {
|
||||
ticker := time.NewTicker(SessionCleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// Delete expired sessions from database
|
||||
var deleted int64
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
result, err := conn.Exec(`DELETE FROM sessions WHERE expires_at < CURRENT_TIMESTAMP`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
deleted, _ = result.RowsAffected()
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to cleanup expired sessions: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Clean up cache
|
||||
a.cacheMu.Lock()
|
||||
now := time.Now()
|
||||
for sessionID, session := range a.sessionCache {
|
||||
if now.After(session.ExpiresAt) {
|
||||
delete(a.sessionCache, sessionID)
|
||||
}
|
||||
}
|
||||
a.cacheMu.Unlock()
|
||||
|
||||
if deleted > 0 {
|
||||
log.Printf("Cleaned up %d expired sessions", deleted)
|
||||
}
|
||||
case <-a.stopCleanup:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// initializeSettings ensures all admin settings are initialized with defaults if they don't exist
|
||||
func (a *Auth) initializeSettings() error {
|
||||
// Initialize registration_enabled setting (default: true) if it doesn't exist
|
||||
var settingCount int
|
||||
err := a.db.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount)
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check registration_enabled setting: %w", err)
|
||||
}
|
||||
if settingCount == 0 {
|
||||
_, err = a.db.Exec(
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec(
|
||||
`INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)`,
|
||||
"registration_enabled", "true",
|
||||
)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize registration_enabled setting: %w", err)
|
||||
}
|
||||
@@ -118,7 +252,9 @@ func (a *Auth) initializeTestUser() error {
|
||||
|
||||
// Check if user already exists
|
||||
var exists bool
|
||||
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND oauth_provider = 'local')", testEmail).Scan(&exists)
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND oauth_provider = 'local')", testEmail).Scan(&exists)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if test user exists: %w", err)
|
||||
}
|
||||
@@ -137,7 +273,12 @@ func (a *Auth) initializeTestUser() error {
|
||||
|
||||
// Check if this is the first user (make them admin)
|
||||
var userCount int
|
||||
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check user count: %w", err)
|
||||
}
|
||||
isAdmin := userCount == 0
|
||||
|
||||
// Create test user (use email as name if no name is provided)
|
||||
@@ -147,10 +288,13 @@ func (a *Auth) initializeTestUser() error {
|
||||
}
|
||||
|
||||
// Create test user
|
||||
_, err = a.db.Exec(
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec(
|
||||
"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?)",
|
||||
testEmail, testName, testEmail, string(hashedPassword), isAdmin,
|
||||
)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create test user: %w", err)
|
||||
}
|
||||
@@ -242,7 +386,9 @@ func (a *Auth) DiscordCallback(ctx context.Context, code string) (*Session, erro
|
||||
// IsRegistrationEnabled checks if new user registration is enabled
|
||||
func (a *Auth) IsRegistrationEnabled() (bool, error) {
|
||||
var value string
|
||||
err := a.db.QueryRow("SELECT value FROM settings WHERE key = ?", "registration_enabled").Scan(&value)
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT value FROM settings WHERE key = ?", "registration_enabled").Scan(&value)
|
||||
})
|
||||
if err == sql.ErrNoRows {
|
||||
// Default to enabled if setting doesn't exist
|
||||
return true, nil
|
||||
@@ -262,29 +408,31 @@ func (a *Auth) SetRegistrationEnabled(enabled bool) error {
|
||||
|
||||
// Check if setting exists
|
||||
var exists bool
|
||||
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM settings WHERE key = ?)", "registration_enabled").Scan(&exists)
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM settings WHERE key = ?)", "registration_enabled").Scan(&exists)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if setting exists: %w", err)
|
||||
}
|
||||
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
if exists {
|
||||
// Update existing setting
|
||||
_, err = a.db.Exec(
|
||||
_, err = conn.Exec(
|
||||
"UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?",
|
||||
value, "registration_enabled",
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update setting: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Insert new setting
|
||||
_, err = a.db.Exec(
|
||||
_, err = conn.Exec(
|
||||
"INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)",
|
||||
"registration_enabled", value,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert setting: %w", err)
|
||||
}
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set registration_enabled: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -299,17 +447,21 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session,
|
||||
var dbProvider, dbOAuthID string
|
||||
|
||||
// First, try to find by provider + oauth_id
|
||||
err := a.db.QueryRow(
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow(
|
||||
"SELECT id, email, name, is_admin, oauth_provider, oauth_id FROM users WHERE oauth_provider = ? AND oauth_id = ?",
|
||||
provider, oauthID,
|
||||
).Scan(&userID, &dbEmail, &dbName, &isAdmin, &dbProvider, &dbOAuthID)
|
||||
})
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
// Not found by provider+oauth_id, check by email for account linking
|
||||
err = a.db.QueryRow(
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow(
|
||||
"SELECT id, email, name, is_admin, oauth_provider, oauth_id FROM users WHERE email = ?",
|
||||
email,
|
||||
).Scan(&userID, &dbEmail, &dbName, &isAdmin, &dbProvider, &dbOAuthID)
|
||||
})
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
// User doesn't exist, check if registration is enabled
|
||||
@@ -323,14 +475,26 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session,
|
||||
|
||||
// Check if this is the first user
|
||||
var userCount int
|
||||
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check user count: %w", err)
|
||||
}
|
||||
isAdmin = userCount == 0
|
||||
|
||||
// Create new user
|
||||
err = a.db.QueryRow(
|
||||
"INSERT INTO users (email, name, oauth_provider, oauth_id, is_admin) VALUES (?, ?, ?, ?, ?) RETURNING id",
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
result, err := conn.Exec(
|
||||
"INSERT INTO users (email, name, oauth_provider, oauth_id, is_admin) VALUES (?, ?, ?, ?, ?)",
|
||||
email, name, provider, oauthID, isAdmin,
|
||||
).Scan(&userID)
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userID, err = result.LastInsertId()
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
@@ -339,10 +503,13 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session,
|
||||
} else {
|
||||
// User exists with same email but different provider - link accounts by updating provider info
|
||||
// This allows the user to log in with any provider that has the same email
|
||||
_, err = a.db.Exec(
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
_, err = conn.Exec(
|
||||
"UPDATE users SET oauth_provider = ?, oauth_id = ?, name = ? WHERE id = ?",
|
||||
provider, oauthID, name, userID,
|
||||
)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to link account: %w", err)
|
||||
}
|
||||
@@ -353,10 +520,13 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session,
|
||||
} else {
|
||||
// User found by provider+oauth_id, update info if changed
|
||||
if dbEmail != email || dbName != name {
|
||||
_, err = a.db.Exec(
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
_, err = conn.Exec(
|
||||
"UPDATE users SET email = ?, name = ? WHERE id = ?",
|
||||
email, name, userID,
|
||||
)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update user: %w", err)
|
||||
}
|
||||
@@ -368,41 +538,134 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session,
|
||||
Email: email,
|
||||
Name: name,
|
||||
IsAdmin: isAdmin,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
ExpiresAt: time.Now().Add(SessionDuration),
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// CreateSession creates a new session and returns a session ID
|
||||
// Sessions are persisted to database and cached in memory
|
||||
func (a *Auth) CreateSession(session *Session) string {
|
||||
sessionID := uuid.New().String()
|
||||
a.sessionStore[sessionID] = session
|
||||
|
||||
// Store in database first
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec(
|
||||
`INSERT INTO sessions (session_id, user_id, email, name, is_admin, expires_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`,
|
||||
sessionID, session.UserID, session.Email, session.Name, session.IsAdmin, session.ExpiresAt,
|
||||
)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to persist session to database: %v", err)
|
||||
// Continue anyway - session will work from cache but won't survive restart
|
||||
}
|
||||
|
||||
// Store in cache
|
||||
a.cacheMu.Lock()
|
||||
a.sessionCache[sessionID] = session
|
||||
a.cacheMu.Unlock()
|
||||
|
||||
return sessionID
|
||||
}
|
||||
|
||||
// GetSession retrieves a session by ID
|
||||
// First checks cache, then database if not found
|
||||
func (a *Auth) GetSession(sessionID string) (*Session, bool) {
|
||||
session, ok := a.sessionStore[sessionID]
|
||||
if !ok {
|
||||
// Check cache first
|
||||
a.cacheMu.RLock()
|
||||
session, ok := a.sessionCache[sessionID]
|
||||
a.cacheMu.RUnlock()
|
||||
|
||||
if ok {
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
a.DeleteSession(sessionID)
|
||||
return nil, false
|
||||
}
|
||||
// Refresh admin status from database
|
||||
var isAdmin bool
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT is_admin FROM users WHERE id = ?", session.UserID).Scan(&isAdmin)
|
||||
})
|
||||
if err == nil {
|
||||
session.IsAdmin = isAdmin
|
||||
}
|
||||
return session, true
|
||||
}
|
||||
|
||||
// Not in cache, check database
|
||||
session = &Session{}
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow(
|
||||
`SELECT user_id, email, name, is_admin, expires_at
|
||||
FROM sessions WHERE session_id = ?`,
|
||||
sessionID,
|
||||
).Scan(&session.UserID, &session.Email, &session.Name, &session.IsAdmin, &session.ExpiresAt)
|
||||
})
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, false
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to query session from database: %v", err)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
delete(a.sessionStore, sessionID)
|
||||
a.DeleteSession(sessionID)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Refresh admin status from database
|
||||
var isAdmin bool
|
||||
err := a.db.QueryRow("SELECT is_admin FROM users WHERE id = ?", session.UserID).Scan(&isAdmin)
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT is_admin FROM users WHERE id = ?", session.UserID).Scan(&isAdmin)
|
||||
})
|
||||
if err == nil {
|
||||
session.IsAdmin = isAdmin
|
||||
}
|
||||
|
||||
// Add to cache
|
||||
a.cacheMu.Lock()
|
||||
a.sessionCache[sessionID] = session
|
||||
a.cacheMu.Unlock()
|
||||
|
||||
return session, true
|
||||
}
|
||||
|
||||
// DeleteSession deletes a session
|
||||
// DeleteSession deletes a session from both cache and database
|
||||
func (a *Auth) DeleteSession(sessionID string) {
|
||||
delete(a.sessionStore, sessionID)
|
||||
// Delete from cache
|
||||
a.cacheMu.Lock()
|
||||
delete(a.sessionCache, sessionID)
|
||||
a.cacheMu.Unlock()
|
||||
|
||||
// Delete from database
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec("DELETE FROM sessions WHERE session_id = ?", sessionID)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to delete session from database: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// IsProductionMode returns true if running in production mode
|
||||
// This is a package-level function that checks the environment variable
|
||||
// For config-based checks, use Config.IsProductionMode()
|
||||
func IsProductionMode() bool {
|
||||
// Check environment variable first for backwards compatibility
|
||||
if os.Getenv("PRODUCTION") == "true" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsProductionModeFromConfig returns true if production mode is enabled in config
|
||||
func (a *Auth) IsProductionModeFromConfig() bool {
|
||||
return a.cfg.IsProductionMode()
|
||||
}
|
||||
|
||||
// Middleware creates an authentication middleware
|
||||
@@ -426,24 +689,24 @@ func (a *Auth) Middleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// Add user info to request context
|
||||
ctx := context.WithValue(r.Context(), "user_id", session.UserID)
|
||||
ctx = context.WithValue(ctx, "user_email", session.Email)
|
||||
ctx = context.WithValue(ctx, "user_name", session.Name)
|
||||
ctx = context.WithValue(ctx, "is_admin", session.IsAdmin)
|
||||
// Add user info to request context using typed keys
|
||||
ctx := context.WithValue(r.Context(), contextKeyUserID, session.UserID)
|
||||
ctx = context.WithValue(ctx, contextKeyUserEmail, session.Email)
|
||||
ctx = context.WithValue(ctx, contextKeyUserName, session.Name)
|
||||
ctx = context.WithValue(ctx, contextKeyIsAdmin, session.IsAdmin)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserID gets the user ID from context
|
||||
func GetUserID(ctx context.Context) (int64, bool) {
|
||||
userID, ok := ctx.Value("user_id").(int64)
|
||||
userID, ok := ctx.Value(contextKeyUserID).(int64)
|
||||
return userID, ok
|
||||
}
|
||||
|
||||
// IsAdmin checks if the user in context is an admin
|
||||
func IsAdmin(ctx context.Context) bool {
|
||||
isAdmin, ok := ctx.Value("is_admin").(bool)
|
||||
isAdmin, ok := ctx.Value(contextKeyIsAdmin).(bool)
|
||||
return ok && isAdmin
|
||||
}
|
||||
|
||||
@@ -478,19 +741,19 @@ func (a *Auth) AdminMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// Add user info to request context
|
||||
ctx := context.WithValue(r.Context(), "user_id", session.UserID)
|
||||
ctx = context.WithValue(ctx, "user_email", session.Email)
|
||||
ctx = context.WithValue(ctx, "user_name", session.Name)
|
||||
ctx = context.WithValue(ctx, "is_admin", session.IsAdmin)
|
||||
// Add user info to request context using typed keys
|
||||
ctx := context.WithValue(r.Context(), contextKeyUserID, session.UserID)
|
||||
ctx = context.WithValue(ctx, contextKeyUserEmail, session.Email)
|
||||
ctx = context.WithValue(ctx, contextKeyUserName, session.Name)
|
||||
ctx = context.WithValue(ctx, contextKeyIsAdmin, session.IsAdmin)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// IsLocalLoginEnabled returns whether local login is enabled
|
||||
// Local login is enabled when ENABLE_LOCAL_AUTH environment variable is set to "true"
|
||||
// Checks database config first, falls back to environment variable
|
||||
func (a *Auth) IsLocalLoginEnabled() bool {
|
||||
return os.Getenv("ENABLE_LOCAL_AUTH") == "true"
|
||||
return a.cfg.IsLocalAuthEnabled()
|
||||
}
|
||||
|
||||
// IsGoogleOAuthConfigured returns whether Google OAuth is configured
|
||||
@@ -511,10 +774,12 @@ func (a *Auth) LocalLogin(username, password string) (*Session, error) {
|
||||
var dbEmail, dbName, passwordHash string
|
||||
var isAdmin bool
|
||||
|
||||
err := a.db.QueryRow(
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow(
|
||||
"SELECT id, email, name, password_hash, is_admin FROM users WHERE email = ? AND oauth_provider = 'local'",
|
||||
email,
|
||||
).Scan(&userID, &dbEmail, &dbName, &passwordHash, &isAdmin)
|
||||
})
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("invalid credentials")
|
||||
@@ -539,7 +804,7 @@ func (a *Auth) LocalLogin(username, password string) (*Session, error) {
|
||||
Email: dbEmail,
|
||||
Name: dbName,
|
||||
IsAdmin: isAdmin,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
ExpiresAt: time.Now().Add(SessionDuration),
|
||||
}
|
||||
|
||||
return session, nil
|
||||
@@ -558,7 +823,9 @@ func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error)
|
||||
|
||||
// Check if user already exists
|
||||
var exists bool
|
||||
err = a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ?)", email).Scan(&exists)
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ?)", email).Scan(&exists)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if user exists: %w", err)
|
||||
}
|
||||
@@ -574,15 +841,27 @@ func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error)
|
||||
|
||||
// Check if this is the first user (make them admin)
|
||||
var userCount int
|
||||
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check user count: %w", err)
|
||||
}
|
||||
isAdmin := userCount == 0
|
||||
|
||||
// Create user
|
||||
var userID int64
|
||||
err = a.db.QueryRow(
|
||||
"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?) RETURNING id",
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
result, err := conn.Exec(
|
||||
"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?)",
|
||||
email, name, email, string(hashedPassword), isAdmin,
|
||||
).Scan(&userID)
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userID, err = result.LastInsertId()
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
@@ -593,7 +872,7 @@ func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error)
|
||||
Email: email,
|
||||
Name: name,
|
||||
IsAdmin: isAdmin,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
ExpiresAt: time.Now().Add(SessionDuration),
|
||||
}
|
||||
|
||||
return session, nil
|
||||
@@ -603,7 +882,9 @@ func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error)
|
||||
func (a *Auth) ChangePassword(userID int64, oldPassword, newPassword string) error {
|
||||
// Get current password hash
|
||||
var passwordHash string
|
||||
err := a.db.QueryRow("SELECT password_hash FROM users WHERE id = ? AND oauth_provider = 'local'", userID).Scan(&passwordHash)
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT password_hash FROM users WHERE id = ? AND oauth_provider = 'local'", userID).Scan(&passwordHash)
|
||||
})
|
||||
if err == sql.ErrNoRows {
|
||||
return fmt.Errorf("user not found or not a local user")
|
||||
}
|
||||
@@ -628,7 +909,10 @@ func (a *Auth) ChangePassword(userID int64, oldPassword, newPassword string) err
|
||||
}
|
||||
|
||||
// Update password
|
||||
_, err = a.db.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), userID)
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), userID)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update password: %w", err)
|
||||
}
|
||||
@@ -640,7 +924,9 @@ func (a *Auth) ChangePassword(userID int64, oldPassword, newPassword string) err
|
||||
func (a *Auth) AdminChangePassword(targetUserID int64, newPassword string) error {
|
||||
// Verify user exists and is a local user
|
||||
var exists bool
|
||||
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ? AND oauth_provider = 'local')", targetUserID).Scan(&exists)
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ? AND oauth_provider = 'local')", targetUserID).Scan(&exists)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if user exists: %w", err)
|
||||
}
|
||||
@@ -655,7 +941,10 @@ func (a *Auth) AdminChangePassword(targetUserID int64, newPassword string) error
|
||||
}
|
||||
|
||||
// Update password
|
||||
_, err = a.db.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), targetUserID)
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), targetUserID)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update password: %w", err)
|
||||
}
|
||||
@@ -666,7 +955,9 @@ func (a *Auth) AdminChangePassword(targetUserID int64, newPassword string) error
|
||||
// GetFirstUserID returns the ID of the first user (user with the lowest ID)
|
||||
func (a *Auth) GetFirstUserID() (int64, error) {
|
||||
var firstUserID int64
|
||||
err := a.db.QueryRow("SELECT id FROM users ORDER BY id ASC LIMIT 1").Scan(&firstUserID)
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT id FROM users ORDER BY id ASC LIMIT 1").Scan(&firstUserID)
|
||||
})
|
||||
if err == sql.ErrNoRows {
|
||||
return 0, fmt.Errorf("no users found")
|
||||
}
|
||||
@@ -680,7 +971,9 @@ func (a *Auth) GetFirstUserID() (int64, error) {
|
||||
func (a *Auth) SetUserAdminStatus(targetUserID int64, isAdmin bool) error {
|
||||
// Verify user exists
|
||||
var exists bool
|
||||
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", targetUserID).Scan(&exists)
|
||||
err := a.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", targetUserID).Scan(&exists)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if user exists: %w", err)
|
||||
}
|
||||
@@ -698,7 +991,10 @@ func (a *Auth) SetUserAdminStatus(targetUserID int64, isAdmin bool) error {
|
||||
}
|
||||
|
||||
// Update admin status
|
||||
_, err = a.db.Exec("UPDATE users SET is_admin = ? WHERE id = ?", isAdmin, targetUserID)
|
||||
err = a.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec("UPDATE users SET is_admin = ? WHERE id = ?", isAdmin, targetUserID)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update admin status: %w", err)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,8 @@ import (
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"os"
|
||||
"jiggablend/internal/config"
|
||||
"jiggablend/internal/database"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -14,21 +15,14 @@ import (
|
||||
|
||||
// Secrets handles API key management
|
||||
type Secrets struct {
|
||||
db *sql.DB
|
||||
db *database.DB
|
||||
cfg *config.Config
|
||||
RegistrationMu sync.Mutex // Protects concurrent runner registrations
|
||||
fixedAPIKey string // Fixed API key from environment variable (optional)
|
||||
}
|
||||
|
||||
// NewSecrets creates a new secrets manager
|
||||
func NewSecrets(db *sql.DB) (*Secrets, error) {
|
||||
s := &Secrets{db: db}
|
||||
|
||||
// Check for fixed API key from environment
|
||||
if fixedKey := os.Getenv("FIXED_API_KEY"); fixedKey != "" {
|
||||
s.fixedAPIKey = fixedKey
|
||||
}
|
||||
|
||||
return s, nil
|
||||
func NewSecrets(db *database.DB, cfg *config.Config) (*Secrets, error) {
|
||||
return &Secrets{db: db, cfg: cfg}, nil
|
||||
}
|
||||
|
||||
// APIKeyInfo represents information about an API key
|
||||
@@ -61,25 +55,34 @@ func (s *Secrets) GenerateRunnerAPIKey(createdBy int64, name, description string
|
||||
keyHash := sha256.Sum256([]byte(key))
|
||||
keyHashStr := hex.EncodeToString(keyHash[:])
|
||||
|
||||
_, err = s.db.Exec(
|
||||
var keyInfo APIKeyInfo
|
||||
err = s.db.With(func(conn *sql.DB) error {
|
||||
result, err := conn.Exec(
|
||||
`INSERT INTO runner_api_keys (key_prefix, key_hash, name, description, scope, is_active, created_by)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||
keyPrefix, keyHashStr, name, description, scope, true, createdBy,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to store API key: %w", err)
|
||||
return fmt.Errorf("failed to store API key: %w", err)
|
||||
}
|
||||
|
||||
keyID, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get inserted key ID: %w", err)
|
||||
}
|
||||
|
||||
// Get the inserted key info
|
||||
var keyInfo APIKeyInfo
|
||||
err = s.db.QueryRow(
|
||||
err = conn.QueryRow(
|
||||
`SELECT id, name, description, scope, is_active, created_at, created_by
|
||||
FROM runner_api_keys WHERE key_prefix = ?`,
|
||||
keyPrefix,
|
||||
FROM runner_api_keys WHERE id = ?`,
|
||||
keyID,
|
||||
).Scan(&keyInfo.ID, &keyInfo.Name, &keyInfo.Description, &keyInfo.Scope, &keyInfo.IsActive, &keyInfo.CreatedAt, &keyInfo.CreatedBy)
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve created API key: %w", err)
|
||||
return nil, fmt.Errorf("failed to create API key: %w", err)
|
||||
}
|
||||
|
||||
keyInfo.Key = key
|
||||
@@ -91,18 +94,25 @@ func (s *Secrets) generateAPIKey() (string, error) {
|
||||
// Generate random suffix
|
||||
randomBytes := make([]byte, 16)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
randomStr := hex.EncodeToString(randomBytes)
|
||||
|
||||
// Generate a unique prefix (jk_r followed by 1 random digit)
|
||||
prefixDigit := make([]byte, 1)
|
||||
if _, err := rand.Read(prefixDigit); err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("failed to generate prefix digit: %w", err)
|
||||
}
|
||||
|
||||
prefix := fmt.Sprintf("jk_r%d", prefixDigit[0]%10)
|
||||
return fmt.Sprintf("%s_%s", prefix, randomStr), nil
|
||||
key := fmt.Sprintf("%s_%s", prefix, randomStr)
|
||||
|
||||
// Validate generated key format
|
||||
if !strings.HasPrefix(key, "jk_r") {
|
||||
return "", fmt.Errorf("generated invalid API key format: %s", key)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// ValidateRunnerAPIKey validates an API key and returns the key ID and scope if valid
|
||||
@@ -111,8 +121,9 @@ func (s *Secrets) ValidateRunnerAPIKey(apiKey string) (int64, string, error) {
|
||||
return 0, "", fmt.Errorf("API key is required")
|
||||
}
|
||||
|
||||
// Check fixed API key first (for testing/development)
|
||||
if s.fixedAPIKey != "" && apiKey == s.fixedAPIKey {
|
||||
// Check fixed API key first (from database config)
|
||||
fixedKey := s.cfg.FixedAPIKey()
|
||||
if fixedKey != "" && apiKey == fixedKey {
|
||||
// Return a special ID for fixed API key (doesn't exist in database)
|
||||
return -1, "manager", nil
|
||||
}
|
||||
@@ -137,42 +148,50 @@ func (s *Secrets) ValidateRunnerAPIKey(apiKey string) (int64, string, error) {
|
||||
var scope string
|
||||
var isActive bool
|
||||
|
||||
err := s.db.QueryRow(
|
||||
err := s.db.With(func(conn *sql.DB) error {
|
||||
err := conn.QueryRow(
|
||||
`SELECT id, scope, is_active FROM runner_api_keys
|
||||
WHERE key_prefix = ? AND key_hash = ?`,
|
||||
keyPrefix, keyHashStr,
|
||||
).Scan(&keyID, &scope, &isActive)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return 0, "", fmt.Errorf("API key not found or invalid - please check that the key is correct and active")
|
||||
return fmt.Errorf("API key not found or invalid - please check that the key is correct and active")
|
||||
}
|
||||
if err != nil {
|
||||
return 0, "", fmt.Errorf("failed to validate API key: %w", err)
|
||||
return fmt.Errorf("failed to validate API key: %w", err)
|
||||
}
|
||||
|
||||
if !isActive {
|
||||
return 0, "", fmt.Errorf("API key is inactive")
|
||||
return fmt.Errorf("API key is inactive")
|
||||
}
|
||||
|
||||
// Update last_used_at (don't fail if this update fails)
|
||||
s.db.Exec(`UPDATE runner_api_keys SET last_used_at = ? WHERE id = ?`, time.Now(), keyID)
|
||||
conn.Exec(`UPDATE runner_api_keys SET last_used_at = ? WHERE id = ?`, time.Now(), keyID)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
|
||||
return keyID, scope, nil
|
||||
}
|
||||
|
||||
// ListRunnerAPIKeys lists all runner API keys
|
||||
func (s *Secrets) ListRunnerAPIKeys() ([]APIKeyInfo, error) {
|
||||
rows, err := s.db.Query(
|
||||
var keys []APIKeyInfo
|
||||
err := s.db.With(func(conn *sql.DB) error {
|
||||
rows, err := conn.Query(
|
||||
`SELECT id, key_prefix, name, description, scope, is_active, created_at, created_by
|
||||
FROM runner_api_keys
|
||||
ORDER BY created_at DESC`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query API keys: %w", err)
|
||||
return fmt.Errorf("failed to query API keys: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var keys []APIKeyInfo
|
||||
for rows.Next() {
|
||||
var key APIKeyInfo
|
||||
var description sql.NullString
|
||||
@@ -188,20 +207,28 @@ func (s *Secrets) ListRunnerAPIKeys() ([]APIKeyInfo, error) {
|
||||
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// RevokeRunnerAPIKey revokes (deactivates) a runner API key
|
||||
func (s *Secrets) RevokeRunnerAPIKey(keyID int64) error {
|
||||
_, err := s.db.Exec("UPDATE runner_api_keys SET is_active = false WHERE id = ?", keyID)
|
||||
return s.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec("UPDATE runner_api_keys SET is_active = false WHERE id = ?", keyID)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteRunnerAPIKey deletes a runner API key
|
||||
func (s *Secrets) DeleteRunnerAPIKey(keyID int64) error {
|
||||
_, err := s.db.Exec("DELETE FROM runner_api_keys WHERE id = ?", keyID)
|
||||
return s.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec("DELETE FROM runner_api_keys WHERE id = ?", keyID)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
|
||||
303
internal/config/config.go
Normal file
303
internal/config/config.go
Normal file
@@ -0,0 +1,303 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"jiggablend/internal/database"
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// Config keys stored in database
|
||||
const (
|
||||
KeyGoogleClientID = "google_client_id"
|
||||
KeyGoogleClientSecret = "google_client_secret"
|
||||
KeyGoogleRedirectURL = "google_redirect_url"
|
||||
KeyDiscordClientID = "discord_client_id"
|
||||
KeyDiscordClientSecret = "discord_client_secret"
|
||||
KeyDiscordRedirectURL = "discord_redirect_url"
|
||||
KeyEnableLocalAuth = "enable_local_auth"
|
||||
KeyFixedAPIKey = "fixed_api_key"
|
||||
KeyRegistrationEnabled = "registration_enabled"
|
||||
KeyProductionMode = "production_mode"
|
||||
KeyAllowedOrigins = "allowed_origins"
|
||||
)
|
||||
|
||||
// Config manages application configuration stored in the database
|
||||
type Config struct {
|
||||
db *database.DB
|
||||
}
|
||||
|
||||
// NewConfig creates a new config manager
|
||||
func NewConfig(db *database.DB) *Config {
|
||||
return &Config{db: db}
|
||||
}
|
||||
|
||||
// InitializeFromEnv loads configuration from environment variables on first run
|
||||
// Environment variables take precedence only if the config key doesn't exist in the database
|
||||
// This allows first-run setup via env vars, then subsequent runs use database values
|
||||
func (c *Config) InitializeFromEnv() error {
|
||||
envMappings := []struct {
|
||||
envKey string
|
||||
configKey string
|
||||
sensitive bool
|
||||
}{
|
||||
{"GOOGLE_CLIENT_ID", KeyGoogleClientID, false},
|
||||
{"GOOGLE_CLIENT_SECRET", KeyGoogleClientSecret, true},
|
||||
{"GOOGLE_REDIRECT_URL", KeyGoogleRedirectURL, false},
|
||||
{"DISCORD_CLIENT_ID", KeyDiscordClientID, false},
|
||||
{"DISCORD_CLIENT_SECRET", KeyDiscordClientSecret, true},
|
||||
{"DISCORD_REDIRECT_URL", KeyDiscordRedirectURL, false},
|
||||
{"ENABLE_LOCAL_AUTH", KeyEnableLocalAuth, false},
|
||||
{"FIXED_API_KEY", KeyFixedAPIKey, true},
|
||||
{"PRODUCTION", KeyProductionMode, false},
|
||||
{"ALLOWED_ORIGINS", KeyAllowedOrigins, false},
|
||||
}
|
||||
|
||||
for _, mapping := range envMappings {
|
||||
envValue := os.Getenv(mapping.envKey)
|
||||
if envValue == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if config already exists in database
|
||||
exists, err := c.Exists(mapping.configKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check config %s: %w", mapping.configKey, err)
|
||||
}
|
||||
|
||||
if !exists {
|
||||
// Store env value in database
|
||||
if err := c.Set(mapping.configKey, envValue); err != nil {
|
||||
return fmt.Errorf("failed to store config %s: %w", mapping.configKey, err)
|
||||
}
|
||||
if mapping.sensitive {
|
||||
log.Printf("Stored config from env: %s = [REDACTED]", mapping.configKey)
|
||||
} else {
|
||||
log.Printf("Stored config from env: %s = %s", mapping.configKey, envValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a config value from the database
|
||||
func (c *Config) Get(key string) (string, error) {
|
||||
var value string
|
||||
err := c.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT value FROM settings WHERE key = ?", key).Scan(&value)
|
||||
})
|
||||
if err == sql.ErrNoRows {
|
||||
return "", nil
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get config %s: %w", key, err)
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// GetWithDefault retrieves a config value or returns a default if not set
|
||||
func (c *Config) GetWithDefault(key, defaultValue string) string {
|
||||
value, err := c.Get(key)
|
||||
if err != nil || value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// GetBool retrieves a boolean config value
|
||||
func (c *Config) GetBool(key string) (bool, error) {
|
||||
value, err := c.Get(key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return value == "true" || value == "1", nil
|
||||
}
|
||||
|
||||
// GetBoolWithDefault retrieves a boolean config value or returns a default
|
||||
func (c *Config) GetBoolWithDefault(key string, defaultValue bool) bool {
|
||||
value, err := c.GetBool(key)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
// If the key doesn't exist, Get returns empty string which becomes false
|
||||
// Check if key exists to distinguish between "false" and "not set"
|
||||
exists, _ := c.Exists(key)
|
||||
if !exists {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// GetInt retrieves an integer config value
|
||||
func (c *Config) GetInt(key string) (int, error) {
|
||||
value, err := c.Get(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if value == "" {
|
||||
return 0, nil
|
||||
}
|
||||
return strconv.Atoi(value)
|
||||
}
|
||||
|
||||
// GetIntWithDefault retrieves an integer config value or returns a default
|
||||
func (c *Config) GetIntWithDefault(key string, defaultValue int) int {
|
||||
value, err := c.GetInt(key)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
exists, _ := c.Exists(key)
|
||||
if !exists {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// Set stores a config value in the database
|
||||
func (c *Config) Set(key, value string) error {
|
||||
// Use upsert pattern
|
||||
exists, err := c.Exists(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.db.With(func(conn *sql.DB) error {
|
||||
if exists {
|
||||
_, err = conn.Exec(
|
||||
"UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?",
|
||||
value, key,
|
||||
)
|
||||
} else {
|
||||
_, err = conn.Exec(
|
||||
"INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)",
|
||||
key, value,
|
||||
)
|
||||
}
|
||||
return err
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set config %s: %w", key, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetBool stores a boolean config value
|
||||
func (c *Config) SetBool(key string, value bool) error {
|
||||
strValue := "false"
|
||||
if value {
|
||||
strValue = "true"
|
||||
}
|
||||
return c.Set(key, strValue)
|
||||
}
|
||||
|
||||
// SetInt stores an integer config value
|
||||
func (c *Config) SetInt(key string, value int) error {
|
||||
return c.Set(key, strconv.Itoa(value))
|
||||
}
|
||||
|
||||
// Delete removes a config value from the database
|
||||
func (c *Config) Delete(key string) error {
|
||||
err := c.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec("DELETE FROM settings WHERE key = ?", key)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete config %s: %w", key, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists checks if a config key exists in the database
|
||||
func (c *Config) Exists(key string) (bool, error) {
|
||||
var exists bool
|
||||
err := c.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM settings WHERE key = ?)", key).Scan(&exists)
|
||||
})
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check config existence %s: %w", key, err)
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// GetAll returns all config values (for debugging/admin purposes)
|
||||
func (c *Config) GetAll() (map[string]string, error) {
|
||||
var result map[string]string
|
||||
err := c.db.With(func(conn *sql.DB) error {
|
||||
rows, err := conn.Query("SELECT key, value FROM settings")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get all config: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
result = make(map[string]string)
|
||||
for rows.Next() {
|
||||
var key, value string
|
||||
if err := rows.Scan(&key, &value); err != nil {
|
||||
return fmt.Errorf("failed to scan config row: %w", err)
|
||||
}
|
||||
result[key] = value
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// --- Convenience methods for specific config values ---
|
||||
|
||||
// GoogleClientID returns the Google OAuth client ID
|
||||
func (c *Config) GoogleClientID() string {
|
||||
return c.GetWithDefault(KeyGoogleClientID, "")
|
||||
}
|
||||
|
||||
// GoogleClientSecret returns the Google OAuth client secret
|
||||
func (c *Config) GoogleClientSecret() string {
|
||||
return c.GetWithDefault(KeyGoogleClientSecret, "")
|
||||
}
|
||||
|
||||
// GoogleRedirectURL returns the Google OAuth redirect URL
|
||||
func (c *Config) GoogleRedirectURL() string {
|
||||
return c.GetWithDefault(KeyGoogleRedirectURL, "")
|
||||
}
|
||||
|
||||
// DiscordClientID returns the Discord OAuth client ID
|
||||
func (c *Config) DiscordClientID() string {
|
||||
return c.GetWithDefault(KeyDiscordClientID, "")
|
||||
}
|
||||
|
||||
// DiscordClientSecret returns the Discord OAuth client secret
|
||||
func (c *Config) DiscordClientSecret() string {
|
||||
return c.GetWithDefault(KeyDiscordClientSecret, "")
|
||||
}
|
||||
|
||||
// DiscordRedirectURL returns the Discord OAuth redirect URL
|
||||
func (c *Config) DiscordRedirectURL() string {
|
||||
return c.GetWithDefault(KeyDiscordRedirectURL, "")
|
||||
}
|
||||
|
||||
// IsLocalAuthEnabled returns whether local authentication is enabled
|
||||
func (c *Config) IsLocalAuthEnabled() bool {
|
||||
return c.GetBoolWithDefault(KeyEnableLocalAuth, false)
|
||||
}
|
||||
|
||||
// FixedAPIKey returns the fixed API key for testing
|
||||
func (c *Config) FixedAPIKey() string {
|
||||
return c.GetWithDefault(KeyFixedAPIKey, "")
|
||||
}
|
||||
|
||||
// IsProductionMode returns whether production mode is enabled
|
||||
func (c *Config) IsProductionMode() bool {
|
||||
return c.GetBoolWithDefault(KeyProductionMode, false)
|
||||
}
|
||||
|
||||
// AllowedOrigins returns the allowed CORS origins
|
||||
func (c *Config) AllowedOrigins() string {
|
||||
return c.GetWithDefault(KeyAllowedOrigins, "")
|
||||
}
|
||||
|
||||
@@ -4,18 +4,20 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
_ "github.com/marcboeker/go-duckdb/v2"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// DB wraps the database connection
|
||||
// DB wraps the database connection with mutex protection
|
||||
type DB struct {
|
||||
*sql.DB
|
||||
db *sql.DB
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewDB creates a new database connection
|
||||
func NewDB(dbPath string) (*DB, error) {
|
||||
db, err := sql.Open("duckdb", dbPath)
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
@@ -24,7 +26,12 @@ func NewDB(dbPath string) (*DB, error) {
|
||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
database := &DB{DB: db}
|
||||
// Enable foreign keys for SQLite
|
||||
if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
|
||||
return nil, fmt.Errorf("failed to enable foreign keys: %w", err)
|
||||
}
|
||||
|
||||
database := &DB{db: db}
|
||||
if err := database.migrate(); err != nil {
|
||||
return nil, fmt.Errorf("failed to migrate database: %w", err)
|
||||
}
|
||||
@@ -32,58 +39,74 @@ func NewDB(dbPath string) (*DB, error) {
|
||||
return database, nil
|
||||
}
|
||||
|
||||
// With executes a function with mutex-protected access to the database
|
||||
// The function receives the underlying *sql.DB connection
|
||||
func (db *DB) With(fn func(*sql.DB) error) error {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
return fn(db.db)
|
||||
}
|
||||
|
||||
// WithTx executes a function within a transaction with mutex protection
|
||||
// The function receives a *sql.Tx transaction
|
||||
// If the function returns an error, the transaction is rolled back
|
||||
// If the function returns nil, the transaction is committed
|
||||
func (db *DB) WithTx(fn func(*sql.Tx) error) error {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
tx, err := db.db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
if rbErr := tx.Rollback(); rbErr != nil {
|
||||
return fmt.Errorf("transaction error: %w, rollback error: %v", err, rbErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrate runs database migrations
|
||||
func (db *DB) migrate() error {
|
||||
// Create sequences for auto-incrementing primary keys
|
||||
sequences := []string{
|
||||
`CREATE SEQUENCE IF NOT EXISTS seq_users_id START 1`,
|
||||
`CREATE SEQUENCE IF NOT EXISTS seq_jobs_id START 1`,
|
||||
`CREATE SEQUENCE IF NOT EXISTS seq_runners_id START 1`,
|
||||
`CREATE SEQUENCE IF NOT EXISTS seq_tasks_id START 1`,
|
||||
`CREATE SEQUENCE IF NOT EXISTS seq_job_files_id START 1`,
|
||||
`CREATE SEQUENCE IF NOT EXISTS seq_manager_secrets_id START 1`,
|
||||
`CREATE SEQUENCE IF NOT EXISTS seq_registration_tokens_id START 1`,
|
||||
`CREATE SEQUENCE IF NOT EXISTS seq_runner_api_keys_id START 1`,
|
||||
`CREATE SEQUENCE IF NOT EXISTS seq_task_logs_id START 1`,
|
||||
`CREATE SEQUENCE IF NOT EXISTS seq_task_steps_id START 1`,
|
||||
}
|
||||
|
||||
for _, seq := range sequences {
|
||||
if _, err := db.Exec(seq); err != nil {
|
||||
return fmt.Errorf("failed to create sequence: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// SQLite uses INTEGER PRIMARY KEY AUTOINCREMENT instead of sequences
|
||||
schema := `
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id BIGINT PRIMARY KEY DEFAULT nextval('seq_users_id'),
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
email TEXT UNIQUE NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
oauth_provider TEXT NOT NULL,
|
||||
oauth_id TEXT NOT NULL,
|
||||
password_hash TEXT,
|
||||
is_admin BOOLEAN NOT NULL DEFAULT false,
|
||||
is_admin INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(oauth_provider, oauth_id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS runner_api_keys (
|
||||
id BIGINT PRIMARY KEY DEFAULT nextval('seq_runner_api_keys_id'),
|
||||
key_prefix TEXT NOT NULL, -- First part of API key (e.g., "jk_r1")
|
||||
key_hash TEXT NOT NULL, -- SHA256 hash of full API key
|
||||
name TEXT NOT NULL, -- Human-readable name
|
||||
description TEXT, -- Optional description
|
||||
scope TEXT NOT NULL DEFAULT 'user', -- 'manager' or 'user' - manager scope allows all jobs, user scope only allows jobs from key owner
|
||||
is_active BOOLEAN NOT NULL DEFAULT true,
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
key_prefix TEXT NOT NULL,
|
||||
key_hash TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
scope TEXT NOT NULL DEFAULT 'user',
|
||||
is_active INTEGER NOT NULL DEFAULT 1,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
created_by BIGINT,
|
||||
created_by INTEGER,
|
||||
FOREIGN KEY (created_by) REFERENCES users(id),
|
||||
UNIQUE(key_prefix)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS jobs (
|
||||
id BIGINT PRIMARY KEY DEFAULT nextval('seq_jobs_id'),
|
||||
user_id BIGINT NOT NULL,
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL,
|
||||
job_type TEXT NOT NULL DEFAULT 'render',
|
||||
name TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
@@ -91,9 +114,11 @@ func (db *DB) migrate() error {
|
||||
frame_start INTEGER,
|
||||
frame_end INTEGER,
|
||||
output_format TEXT,
|
||||
allow_parallel_runners BOOLEAN NOT NULL DEFAULT true,
|
||||
allow_parallel_runners INTEGER NOT NULL DEFAULT 1,
|
||||
timeout_seconds INTEGER DEFAULT 86400,
|
||||
blend_metadata TEXT,
|
||||
retry_count INTEGER NOT NULL DEFAULT 0,
|
||||
max_retries INTEGER NOT NULL DEFAULT 3,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
started_at TIMESTAMP,
|
||||
completed_at TIMESTAMP,
|
||||
@@ -102,25 +127,25 @@ func (db *DB) migrate() error {
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS runners (
|
||||
id BIGINT PRIMARY KEY DEFAULT nextval('seq_runners_id'),
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
hostname TEXT NOT NULL,
|
||||
ip_address TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'offline',
|
||||
last_heartbeat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
capabilities TEXT,
|
||||
api_key_id BIGINT, -- Reference to the API key used for this runner
|
||||
api_key_scope TEXT NOT NULL DEFAULT 'user', -- Scope of the API key ('manager' or 'user')
|
||||
api_key_id INTEGER,
|
||||
api_key_scope TEXT NOT NULL DEFAULT 'user',
|
||||
priority INTEGER NOT NULL DEFAULT 100,
|
||||
fingerprint TEXT, -- Hardware fingerprint (NULL for fixed API keys)
|
||||
fingerprint TEXT,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (api_key_id) REFERENCES runner_api_keys(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS tasks (
|
||||
id BIGINT PRIMARY KEY DEFAULT nextval('seq_tasks_id'),
|
||||
job_id BIGINT NOT NULL,
|
||||
runner_id BIGINT,
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
job_id INTEGER NOT NULL,
|
||||
runner_id INTEGER,
|
||||
frame_start INTEGER NOT NULL,
|
||||
frame_end INTEGER NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
@@ -133,44 +158,50 @@ func (db *DB) migrate() error {
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
started_at TIMESTAMP,
|
||||
completed_at TIMESTAMP,
|
||||
error_message TEXT
|
||||
error_message TEXT,
|
||||
FOREIGN KEY (job_id) REFERENCES jobs(id),
|
||||
FOREIGN KEY (runner_id) REFERENCES runners(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS job_files (
|
||||
id BIGINT PRIMARY KEY DEFAULT nextval('seq_job_files_id'),
|
||||
job_id BIGINT NOT NULL,
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
job_id INTEGER NOT NULL,
|
||||
file_type TEXT NOT NULL,
|
||||
file_path TEXT NOT NULL,
|
||||
file_name TEXT NOT NULL,
|
||||
file_size BIGINT NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
file_size INTEGER NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (job_id) REFERENCES jobs(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS manager_secrets (
|
||||
id BIGINT PRIMARY KEY DEFAULT nextval('seq_manager_secrets_id'),
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
secret TEXT UNIQUE NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS task_logs (
|
||||
id BIGINT PRIMARY KEY DEFAULT nextval('seq_task_logs_id'),
|
||||
task_id BIGINT NOT NULL,
|
||||
runner_id BIGINT,
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
task_id INTEGER NOT NULL,
|
||||
runner_id INTEGER,
|
||||
log_level TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
step_name TEXT,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (task_id) REFERENCES tasks(id),
|
||||
FOREIGN KEY (runner_id) REFERENCES runners(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS task_steps (
|
||||
id BIGINT PRIMARY KEY DEFAULT nextval('seq_task_steps_id'),
|
||||
task_id BIGINT NOT NULL,
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
task_id INTEGER NOT NULL,
|
||||
step_name TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
started_at TIMESTAMP,
|
||||
completed_at TIMESTAMP,
|
||||
duration_ms INTEGER,
|
||||
error_message TEXT
|
||||
error_message TEXT,
|
||||
FOREIGN KEY (task_id) REFERENCES tasks(id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_jobs_user_id ON jobs(user_id);
|
||||
@@ -184,6 +215,7 @@ func (db *DB) migrate() error {
|
||||
CREATE INDEX IF NOT EXISTS idx_job_files_job_id ON job_files(job_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_runner_api_keys_prefix ON runner_api_keys(key_prefix);
|
||||
CREATE INDEX IF NOT EXISTS idx_runner_api_keys_active ON runner_api_keys(is_active);
|
||||
CREATE INDEX IF NOT EXISTS idx_runner_api_keys_created_by ON runner_api_keys(created_by);
|
||||
CREATE INDEX IF NOT EXISTS idx_runners_api_key_id ON runners(api_key_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_task_logs_task_id_created_at ON task_logs(task_id, created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_task_logs_task_id_id ON task_logs(task_id, id DESC);
|
||||
@@ -196,9 +228,28 @@ func (db *DB) migrate() error {
|
||||
value TEXT NOT NULL,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id TEXT UNIQUE NOT NULL,
|
||||
user_id INTEGER NOT NULL,
|
||||
email TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
is_admin INTEGER NOT NULL DEFAULT 0,
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_session_id ON sessions(session_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_expires_at ON sessions(expires_at);
|
||||
`
|
||||
|
||||
if _, err := db.Exec(schema); err != nil {
|
||||
if err := db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec(schema)
|
||||
return err
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to create schema: %w", err)
|
||||
}
|
||||
|
||||
@@ -212,19 +263,26 @@ func (db *DB) migrate() error {
|
||||
}
|
||||
|
||||
for _, migration := range migrations {
|
||||
// DuckDB supports IF NOT EXISTS for ALTER TABLE, so we can safely execute
|
||||
if _, err := db.Exec(migration); err != nil {
|
||||
if err := db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec(migration)
|
||||
return err
|
||||
}); err != nil {
|
||||
// Log but don't fail - column might already exist or table might not exist yet
|
||||
// This is fine for migrations that run after schema creation
|
||||
// For the file_size migration, if it fails (e.g., already BIGINT), that's fine
|
||||
log.Printf("Migration warning: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize registration_enabled setting (default: true) if it doesn't exist
|
||||
var settingCount int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount)
|
||||
err := db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount)
|
||||
})
|
||||
if err == nil && settingCount == 0 {
|
||||
_, err = db.Exec("INSERT INTO settings (key, value) VALUES (?, ?)", "registration_enabled", "true")
|
||||
err = db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec("INSERT INTO settings (key, value) VALUES (?, ?)", "registration_enabled", "true")
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
// Log but don't fail - setting might have been created by another process
|
||||
log.Printf("Note: Could not initialize registration_enabled setting: %v", err)
|
||||
@@ -234,7 +292,16 @@ func (db *DB) migrate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ping checks the database connection
|
||||
func (db *DB) Ping() error {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
return db.db.Ping()
|
||||
}
|
||||
|
||||
// Close closes the database connection
|
||||
func (db *DB) Close() error {
|
||||
return db.DB.Close()
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
return db.db.Close()
|
||||
}
|
||||
|
||||
@@ -1,31 +1,87 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
)
|
||||
|
||||
// Level represents log severity
|
||||
type Level int
|
||||
|
||||
const (
|
||||
LevelDebug Level = iota
|
||||
LevelInfo
|
||||
LevelWarn
|
||||
LevelError
|
||||
)
|
||||
|
||||
var levelNames = map[Level]string{
|
||||
LevelDebug: "DEBUG",
|
||||
LevelInfo: "INFO",
|
||||
LevelWarn: "WARN",
|
||||
LevelError: "ERROR",
|
||||
}
|
||||
|
||||
// ParseLevel parses a level string into a Level
|
||||
func ParseLevel(s string) Level {
|
||||
switch s {
|
||||
case "debug", "DEBUG":
|
||||
return LevelDebug
|
||||
case "info", "INFO":
|
||||
return LevelInfo
|
||||
case "warn", "WARN", "warning", "WARNING":
|
||||
return LevelWarn
|
||||
case "error", "ERROR":
|
||||
return LevelError
|
||||
default:
|
||||
return LevelInfo
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
defaultLogger *Logger
|
||||
once sync.Once
|
||||
currentLevel Level = LevelInfo
|
||||
)
|
||||
|
||||
// Logger wraps the standard log.Logger with file and stdout output
|
||||
// Logger wraps the standard log.Logger with optional file output and levels
|
||||
type Logger struct {
|
||||
*log.Logger
|
||||
fileWriter io.WriteCloser
|
||||
}
|
||||
|
||||
// Init initializes the default logger with both file and stdout output
|
||||
func Init(logDir, logFileName string, maxSizeMB int, maxBackups int, maxAgeDays int) error {
|
||||
// SetLevel sets the global log level
|
||||
func SetLevel(level Level) {
|
||||
currentLevel = level
|
||||
}
|
||||
|
||||
// GetLevel returns the current log level
|
||||
func GetLevel() Level {
|
||||
return currentLevel
|
||||
}
|
||||
|
||||
// InitStdout initializes the logger to only write to stdout
|
||||
func InitStdout() {
|
||||
once.Do(func() {
|
||||
log.SetOutput(os.Stdout)
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
defaultLogger = &Logger{
|
||||
Logger: log.Default(),
|
||||
fileWriter: nil,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// InitWithFile initializes the logger with both file and stdout output
|
||||
// The file is truncated on each start
|
||||
func InitWithFile(logPath string) error {
|
||||
var err error
|
||||
once.Do(func() {
|
||||
defaultLogger, err = New(logDir, logFileName, maxSizeMB, maxBackups, maxAgeDays)
|
||||
defaultLogger, err = NewWithFile(logPath)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -37,22 +93,19 @@ func Init(logDir, logFileName string, maxSizeMB int, maxBackups int, maxAgeDays
|
||||
return err
|
||||
}
|
||||
|
||||
// New creates a new logger that writes to both stdout and a log file
|
||||
func New(logDir, logFileName string, maxSizeMB int, maxBackups int, maxAgeDays int) (*Logger, error) {
|
||||
// NewWithFile creates a new logger that writes to both stdout and a log file
|
||||
// The file is truncated on each start
|
||||
func NewWithFile(logPath string) (*Logger, error) {
|
||||
// Ensure log directory exists
|
||||
logDir := filepath.Dir(logPath)
|
||||
if err := os.MkdirAll(logDir, 0755); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logPath := filepath.Join(logDir, logFileName)
|
||||
|
||||
// Create file writer with rotation
|
||||
fileWriter := &lumberjack.Logger{
|
||||
Filename: logPath,
|
||||
MaxSize: maxSizeMB, // megabytes
|
||||
MaxBackups: maxBackups, // number of backup files
|
||||
MaxAge: maxAgeDays, // days
|
||||
Compress: true, // compress old log files
|
||||
// Create/truncate the log file
|
||||
fileWriter, err := os.Create(logPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create multi-writer that writes to both stdout and file
|
||||
@@ -80,48 +133,91 @@ func GetDefault() *Logger {
|
||||
return defaultLogger
|
||||
}
|
||||
|
||||
// Printf logs a formatted message
|
||||
func Printf(format string, v ...interface{}) {
|
||||
if defaultLogger != nil {
|
||||
defaultLogger.Printf(format, v...)
|
||||
} else {
|
||||
log.Printf(format, v...)
|
||||
// logf logs a formatted message at the given level
|
||||
func logf(level Level, format string, v ...interface{}) {
|
||||
if level < currentLevel {
|
||||
return
|
||||
}
|
||||
prefix := fmt.Sprintf("[%s] ", levelNames[level])
|
||||
msg := fmt.Sprintf(format, v...)
|
||||
log.Print(prefix + msg)
|
||||
}
|
||||
|
||||
// Print logs a message
|
||||
func Print(v ...interface{}) {
|
||||
if defaultLogger != nil {
|
||||
defaultLogger.Print(v...)
|
||||
} else {
|
||||
log.Print(v...)
|
||||
// logln logs a message at the given level
|
||||
func logln(level Level, v ...interface{}) {
|
||||
if level < currentLevel {
|
||||
return
|
||||
}
|
||||
prefix := fmt.Sprintf("[%s] ", levelNames[level])
|
||||
msg := fmt.Sprint(v...)
|
||||
log.Print(prefix + msg)
|
||||
}
|
||||
|
||||
// Println logs a message with newline
|
||||
func Println(v ...interface{}) {
|
||||
if defaultLogger != nil {
|
||||
defaultLogger.Println(v...)
|
||||
} else {
|
||||
log.Println(v...)
|
||||
}
|
||||
// Debug logs a debug message
|
||||
func Debug(v ...interface{}) {
|
||||
logln(LevelDebug, v...)
|
||||
}
|
||||
|
||||
// Fatal logs a message and exits
|
||||
// Debugf logs a formatted debug message
|
||||
func Debugf(format string, v ...interface{}) {
|
||||
logf(LevelDebug, format, v...)
|
||||
}
|
||||
|
||||
// Info logs an info message
|
||||
func Info(v ...interface{}) {
|
||||
logln(LevelInfo, v...)
|
||||
}
|
||||
|
||||
// Infof logs a formatted info message
|
||||
func Infof(format string, v ...interface{}) {
|
||||
logf(LevelInfo, format, v...)
|
||||
}
|
||||
|
||||
// Warn logs a warning message
|
||||
func Warn(v ...interface{}) {
|
||||
logln(LevelWarn, v...)
|
||||
}
|
||||
|
||||
// Warnf logs a formatted warning message
|
||||
func Warnf(format string, v ...interface{}) {
|
||||
logf(LevelWarn, format, v...)
|
||||
}
|
||||
|
||||
// Error logs an error message
|
||||
func Error(v ...interface{}) {
|
||||
logln(LevelError, v...)
|
||||
}
|
||||
|
||||
// Errorf logs a formatted error message
|
||||
func Errorf(format string, v ...interface{}) {
|
||||
logf(LevelError, format, v...)
|
||||
}
|
||||
|
||||
// Fatal logs an error message and exits
|
||||
func Fatal(v ...interface{}) {
|
||||
if defaultLogger != nil {
|
||||
defaultLogger.Fatal(v...)
|
||||
} else {
|
||||
log.Fatal(v...)
|
||||
}
|
||||
logln(LevelError, v...)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Fatalf logs a formatted message and exits
|
||||
// Fatalf logs a formatted error message and exits
|
||||
func Fatalf(format string, v ...interface{}) {
|
||||
if defaultLogger != nil {
|
||||
defaultLogger.Fatalf(format, v...)
|
||||
} else {
|
||||
log.Fatalf(format, v...)
|
||||
}
|
||||
logf(LevelError, format, v...)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// --- Backwards compatibility (maps to Info level) ---
|
||||
|
||||
// Printf logs a formatted message at Info level
|
||||
func Printf(format string, v ...interface{}) {
|
||||
logf(LevelInfo, format, v...)
|
||||
}
|
||||
|
||||
// Print logs a message at Info level
|
||||
func Print(v ...interface{}) {
|
||||
logln(LevelInfo, v...)
|
||||
}
|
||||
|
||||
// Println logs a message at Info level
|
||||
func Println(v ...interface{}) {
|
||||
logln(LevelInfo, v...)
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"jiggablend/pkg/executils"
|
||||
"jiggablend/pkg/scripts"
|
||||
"jiggablend/pkg/types"
|
||||
|
||||
@@ -45,19 +46,19 @@ type Client struct {
|
||||
stopChan chan struct{}
|
||||
stepStartTimes map[string]time.Time // key: "taskID:stepName"
|
||||
stepTimesMu sync.RWMutex
|
||||
workspaceDir string // Persistent workspace directory for this runner
|
||||
runningProcs sync.Map // map[int64]*exec.Cmd - tracks running processes by task ID
|
||||
capabilities map[string]interface{} // Cached capabilities from initial probe (includes bools and numbers)
|
||||
capabilitiesMu sync.RWMutex // Protects capabilities
|
||||
hwAccelCache map[string]bool // Cached hardware acceleration detection results
|
||||
hwAccelCacheMu sync.RWMutex // Protects hwAccelCache
|
||||
vaapiDevices []string // Cached VAAPI device paths (all available devices)
|
||||
vaapiDevicesMu sync.RWMutex // Protects vaapiDevices
|
||||
allocatedDevices map[int64]string // map[taskID]device - tracks which device is allocated to which task
|
||||
allocatedDevicesMu sync.RWMutex // Protects allocatedDevices
|
||||
longRunningClient *http.Client // HTTP client for long-running operations (no timeout)
|
||||
fingerprint string // Unique hardware fingerprint for this runner
|
||||
fingerprintMu sync.RWMutex // Protects fingerprint
|
||||
workspaceDir string // Persistent workspace directory for this runner
|
||||
processTracker *executils.ProcessTracker // Tracks running processes for cleanup
|
||||
capabilities map[string]interface{} // Cached capabilities from initial probe (includes bools and numbers)
|
||||
capabilitiesMu sync.RWMutex // Protects capabilities
|
||||
hwAccelCache map[string]bool // Cached hardware acceleration detection results
|
||||
hwAccelCacheMu sync.RWMutex // Protects hwAccelCache
|
||||
vaapiDevices []string // Cached VAAPI device paths (all available devices)
|
||||
vaapiDevicesMu sync.RWMutex // Protects vaapiDevices
|
||||
allocatedDevices map[int64]string // map[taskID]device - tracks which device is allocated to which task
|
||||
allocatedDevicesMu sync.RWMutex // Protects allocatedDevices
|
||||
longRunningClient *http.Client // HTTP client for long-running operations (no timeout)
|
||||
fingerprint string // Unique hardware fingerprint for this runner
|
||||
fingerprintMu sync.RWMutex // Protects fingerprint
|
||||
}
|
||||
|
||||
// NewClient creates a new runner client
|
||||
@@ -70,6 +71,7 @@ func NewClient(managerURL, name, hostname string) *Client {
|
||||
longRunningClient: &http.Client{Timeout: 0}, // No timeout for long-running operations (context downloads, file uploads/downloads)
|
||||
stopChan: make(chan struct{}),
|
||||
stepStartTimes: make(map[string]time.Time),
|
||||
processTracker: executils.NewProcessTracker(),
|
||||
}
|
||||
// Generate fingerprint immediately
|
||||
client.generateFingerprint()
|
||||
@@ -226,12 +228,6 @@ func (c *Client) probeCapabilities() map[string]interface{} {
|
||||
c.probeGPUCapabilities(capabilities)
|
||||
} else {
|
||||
capabilities["ffmpeg"] = false
|
||||
// Set defaults when ffmpeg is not available
|
||||
capabilities["vaapi"] = false
|
||||
capabilities["vaapi_gpu_count"] = 0
|
||||
capabilities["nvenc"] = false
|
||||
capabilities["nvenc_gpu_count"] = 0
|
||||
capabilities["video_gpu_count"] = 0
|
||||
}
|
||||
|
||||
return capabilities
|
||||
@@ -256,60 +252,6 @@ func (c *Client) probeGPUCapabilities(capabilities map[string]interface{}) {
|
||||
log.Printf("Available hardware encoders: %v", getKeys(hwEncoders))
|
||||
}
|
||||
|
||||
// Check for VAAPI devices and count them
|
||||
log.Printf("Checking for VAAPI hardware acceleration...")
|
||||
|
||||
// First check if encoder is listed (more reliable than testing)
|
||||
cmd := exec.Command("ffmpeg", "-hide_banner", "-encoders")
|
||||
output, err := cmd.CombinedOutput()
|
||||
hasVAAPIEncoder := false
|
||||
if err == nil {
|
||||
encoderOutput := string(output)
|
||||
if strings.Contains(encoderOutput, "h264_vaapi") {
|
||||
hasVAAPIEncoder = true
|
||||
log.Printf("VAAPI encoder (h264_vaapi) found in ffmpeg encoders list")
|
||||
}
|
||||
}
|
||||
|
||||
if hasVAAPIEncoder {
|
||||
// Try to find and test devices
|
||||
vaapiDevices := c.findVAAPIDevices()
|
||||
capabilities["vaapi_gpu_count"] = len(vaapiDevices)
|
||||
if len(vaapiDevices) > 0 {
|
||||
capabilities["vaapi"] = true
|
||||
log.Printf("VAAPI detected: %d GPU device(s) available: %v", len(vaapiDevices), vaapiDevices)
|
||||
} else {
|
||||
capabilities["vaapi"] = false
|
||||
log.Printf("VAAPI encoder available but no working devices found")
|
||||
log.Printf(" This might indicate:")
|
||||
log.Printf(" - Missing or incorrect GPU drivers")
|
||||
log.Printf(" - Missing libva or mesa-va-drivers packages")
|
||||
log.Printf(" - Permission issues accessing /dev/dri devices")
|
||||
log.Printf(" - GPU not properly initialized")
|
||||
}
|
||||
} else {
|
||||
capabilities["vaapi"] = false
|
||||
capabilities["vaapi_gpu_count"] = 0
|
||||
log.Printf("VAAPI encoder not available in ffmpeg")
|
||||
log.Printf(" This might indicate:")
|
||||
log.Printf(" - FFmpeg was not compiled with VAAPI support")
|
||||
log.Printf(" - Missing libva development libraries during FFmpeg compilation")
|
||||
}
|
||||
|
||||
// Check for NVENC (NVIDIA) - try to detect multiple GPUs
|
||||
log.Printf("Checking for NVENC hardware acceleration...")
|
||||
if c.checkEncoderAvailable("h264_nvenc") {
|
||||
capabilities["nvenc"] = true
|
||||
// Try to detect actual GPU count using nvidia-smi if available
|
||||
nvencCount := c.detectNVENCCount()
|
||||
capabilities["nvenc_gpu_count"] = nvencCount
|
||||
log.Printf("NVENC detected: %d GPU(s)", nvencCount)
|
||||
} else {
|
||||
capabilities["nvenc"] = false
|
||||
capabilities["nvenc_gpu_count"] = 0
|
||||
log.Printf("NVENC encoder not available")
|
||||
}
|
||||
|
||||
// Check for other hardware encoders (for completeness)
|
||||
log.Printf("Checking for other hardware encoders...")
|
||||
if c.checkEncoderAvailable("h264_qsv") {
|
||||
@@ -368,73 +310,6 @@ func (c *Client) probeGPUCapabilities(capabilities map[string]interface{}) {
|
||||
capabilities["mediacodec"] = false
|
||||
capabilities["mediacodec_gpu_count"] = 0
|
||||
}
|
||||
|
||||
// Calculate total GPU count for video encoding
|
||||
// Priority: VAAPI > NVENC > QSV > VideoToolbox > AMF > others
|
||||
vaapiCount := 0
|
||||
if count, ok := capabilities["vaapi_gpu_count"].(int); ok {
|
||||
vaapiCount = count
|
||||
}
|
||||
nvencCount := 0
|
||||
if count, ok := capabilities["nvenc_gpu_count"].(int); ok {
|
||||
nvencCount = count
|
||||
}
|
||||
qsvCount := 0
|
||||
if count, ok := capabilities["qsv_gpu_count"].(int); ok {
|
||||
qsvCount = count
|
||||
}
|
||||
videotoolboxCount := 0
|
||||
if count, ok := capabilities["videotoolbox_gpu_count"].(int); ok {
|
||||
videotoolboxCount = count
|
||||
}
|
||||
amfCount := 0
|
||||
if count, ok := capabilities["amf_gpu_count"].(int); ok {
|
||||
amfCount = count
|
||||
}
|
||||
|
||||
// Total GPU count - use the best available (they can't be used simultaneously)
|
||||
totalGPUs := vaapiCount
|
||||
if totalGPUs == 0 {
|
||||
totalGPUs = nvencCount
|
||||
}
|
||||
if totalGPUs == 0 {
|
||||
totalGPUs = qsvCount
|
||||
}
|
||||
if totalGPUs == 0 {
|
||||
totalGPUs = videotoolboxCount
|
||||
}
|
||||
if totalGPUs == 0 {
|
||||
totalGPUs = amfCount
|
||||
}
|
||||
capabilities["video_gpu_count"] = totalGPUs
|
||||
|
||||
if totalGPUs > 0 {
|
||||
log.Printf("Total video GPU count: %d", totalGPUs)
|
||||
} else {
|
||||
log.Printf("No hardware-accelerated video encoding GPUs detected (will use software encoding)")
|
||||
}
|
||||
}
|
||||
|
||||
// detectNVENCCount tries to detect the actual number of NVIDIA GPUs using nvidia-smi
|
||||
func (c *Client) detectNVENCCount() int {
|
||||
// Try to use nvidia-smi to count GPUs
|
||||
cmd := exec.Command("nvidia-smi", "--list-gpus")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err == nil {
|
||||
// Count lines that contain "GPU" (each GPU is listed on a separate line)
|
||||
lines := strings.Split(string(output), "\n")
|
||||
count := 0
|
||||
for _, line := range lines {
|
||||
if strings.Contains(line, "GPU") {
|
||||
count++
|
||||
}
|
||||
}
|
||||
if count > 0 {
|
||||
return count
|
||||
}
|
||||
}
|
||||
// Fallback to 1 if nvidia-smi is not available
|
||||
return 1
|
||||
}
|
||||
|
||||
// getKeys returns all keys from a map as a slice (helper function)
|
||||
@@ -926,29 +801,13 @@ func (c *Client) sendLog(taskID int64, logLevel types.LogLevel, message, stepNam
|
||||
// KillAllProcesses kills all running processes tracked by this client
|
||||
func (c *Client) KillAllProcesses() {
|
||||
log.Printf("Killing all running processes...")
|
||||
var killedCount int
|
||||
c.runningProcs.Range(func(key, value interface{}) bool {
|
||||
taskID := key.(int64)
|
||||
cmd := value.(*exec.Cmd)
|
||||
if cmd.Process != nil {
|
||||
log.Printf("Killing process for task %d (PID: %d)", taskID, cmd.Process.Pid)
|
||||
// Try graceful kill first (SIGTERM)
|
||||
if err := cmd.Process.Signal(os.Interrupt); err != nil {
|
||||
log.Printf("Failed to send SIGINT to process %d: %v", cmd.Process.Pid, err)
|
||||
}
|
||||
// Give it a moment to clean up
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
// Force kill if still running
|
||||
if err := cmd.Process.Kill(); err != nil {
|
||||
log.Printf("Failed to kill process %d: %v", cmd.Process.Pid, err)
|
||||
} else {
|
||||
killedCount++
|
||||
}
|
||||
}
|
||||
// Release any allocated device for this task
|
||||
c.releaseVAAPIDevice(taskID)
|
||||
return true
|
||||
})
|
||||
killedCount := c.processTracker.KillAll()
|
||||
// Release all allocated VAAPI devices
|
||||
c.allocatedDevicesMu.Lock()
|
||||
for taskID := range c.allocatedDevices {
|
||||
delete(c.allocatedDevices, taskID)
|
||||
}
|
||||
c.allocatedDevicesMu.Unlock()
|
||||
log.Printf("Killed %d process(es)", killedCount)
|
||||
}
|
||||
|
||||
@@ -1272,55 +1131,33 @@ func (c *Client) processTask(task map[string]interface{}, jobName string, output
|
||||
|
||||
// Run Blender with GPU enabled via Python script
|
||||
// Use -s (start) and -e (end) for frame ranges, or -f for single frame
|
||||
// Use Blender's automatic frame numbering with #### pattern
|
||||
var cmd *exec.Cmd
|
||||
args := []string{"-b", blendFile, "--python", scriptPath}
|
||||
if enableExecution {
|
||||
args = append(args, "--enable-autoexec")
|
||||
}
|
||||
// Always render frames individually for precise control over file naming
|
||||
// This avoids Blender's automatic frame numbering quirks
|
||||
for frame := frameStart; frame <= frameEnd; frame++ {
|
||||
// Create temp output pattern for this frame
|
||||
tempPattern := filepath.Join(outputDir, fmt.Sprintf("temp_frame.%s", strings.ToLower(renderFormat)))
|
||||
tempAbsPattern, _ := filepath.Abs(tempPattern)
|
||||
|
||||
// Build args for this specific frame
|
||||
frameArgs := []string{"-b", blendFile, "--python", scriptPath}
|
||||
if enableExecution {
|
||||
frameArgs = append(frameArgs, "--enable-autoexec")
|
||||
}
|
||||
frameArgs = append(frameArgs, "-o", tempAbsPattern, "-f", fmt.Sprintf("%d", frame))
|
||||
// Output pattern uses #### which Blender will replace with frame numbers
|
||||
outputPattern := filepath.Join(outputDir, fmt.Sprintf("frame_####.%s", strings.ToLower(renderFormat)))
|
||||
outputAbsPattern, _ := filepath.Abs(outputPattern)
|
||||
args = append(args, "-o", outputAbsPattern)
|
||||
|
||||
c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Rendering frame %d...", frame), "render_blender")
|
||||
|
||||
frameCmd := exec.Command("blender", frameArgs...)
|
||||
frameCmd.Dir = workDir
|
||||
frameCmd.Env = os.Environ()
|
||||
|
||||
// Run this frame
|
||||
if output, err := frameCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("blender failed on frame %d: %v (output: %s)", frame, err, string(output))
|
||||
c.sendLog(taskID, types.LogLevelError, errMsg, "render_blender")
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
|
||||
// Immediately rename the temp file to the proper frame-numbered name
|
||||
finalName := fmt.Sprintf("frame_%04d.%s", frame, strings.ToLower(renderFormat))
|
||||
finalPath := filepath.Join(outputDir, finalName)
|
||||
tempPath := filepath.Join(outputDir, fmt.Sprintf("temp_frame.%s", strings.ToLower(renderFormat)))
|
||||
|
||||
if err := os.Rename(tempPath, finalPath); err != nil {
|
||||
errMsg := fmt.Sprintf("failed to rename temp file for frame %d: %v", frame, err)
|
||||
c.sendLog(taskID, types.LogLevelError, errMsg, "render_blender")
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
|
||||
c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Completed frame %d -> %s", frame, finalName), "render_blender")
|
||||
if frameStart == frameEnd {
|
||||
// Single frame
|
||||
args = append(args, "-f", fmt.Sprintf("%d", frameStart))
|
||||
c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Rendering frame %d...", frameStart), "render_blender")
|
||||
} else {
|
||||
// Frame range
|
||||
args = append(args, "-s", fmt.Sprintf("%d", frameStart), "-e", fmt.Sprintf("%d", frameEnd))
|
||||
c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Rendering frames %d-%d...", frameStart, frameEnd), "render_blender")
|
||||
}
|
||||
|
||||
// Skip the rest of the function since we handled all frames above
|
||||
c.sendStepUpdate(taskID, "render_blender", types.StepStatusCompleted, "")
|
||||
return nil
|
||||
// Create and run Blender command
|
||||
cmd = exec.Command("blender", args...)
|
||||
cmd.Dir = workDir
|
||||
cmd.Env = os.Environ()
|
||||
|
||||
// Blender will handle headless rendering automatically
|
||||
// We preserve the environment to allow GPU access if available
|
||||
|
||||
@@ -1350,8 +1187,8 @@ func (c *Client) processTask(task map[string]interface{}, jobName string, output
|
||||
}
|
||||
|
||||
// Register process for cleanup on shutdown
|
||||
c.runningProcs.Store(taskID, cmd)
|
||||
defer c.runningProcs.Delete(taskID)
|
||||
c.processTracker.Track(taskID, cmd)
|
||||
defer c.processTracker.Untrack(taskID)
|
||||
|
||||
// Stream stdout line by line
|
||||
stdoutDone := make(chan bool)
|
||||
@@ -1396,15 +1233,23 @@ func (c *Client) processTask(task map[string]interface{}, jobName string, output
|
||||
<-stdoutDone
|
||||
<-stderrDone
|
||||
if err != nil {
|
||||
errMsg := fmt.Sprintf("blender failed: %v", err)
|
||||
var errMsg string
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
if exitErr.ExitCode() == 137 {
|
||||
errMsg = "Blender was killed due to excessive memory usage (OOM)"
|
||||
} else {
|
||||
errMsg = fmt.Sprintf("blender failed: %v", err)
|
||||
}
|
||||
} else {
|
||||
errMsg = fmt.Sprintf("blender failed: %v", err)
|
||||
}
|
||||
c.sendLog(taskID, types.LogLevelError, errMsg, "render_blender")
|
||||
c.sendStepUpdate(taskID, "render_blender", types.StepStatusFailed, errMsg)
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
|
||||
// For frame ranges, we rendered each frame individually with temp naming
|
||||
// The files are already properly named during the individual frame rendering
|
||||
// No additional renaming needed
|
||||
// Blender has rendered frames with automatic numbering using the #### pattern
|
||||
// Files will be named like frame_0001.png, frame_0002.png, etc.
|
||||
|
||||
// Find rendered output file(s)
|
||||
// For frame ranges, we'll find all frames in the upload step
|
||||
@@ -1748,8 +1593,8 @@ func (c *Client) processVideoGenerationTask(task map[string]interface{}, jobID i
|
||||
// Extract frame number pattern (e.g., frame_2470.exr -> frame_%04d.exr)
|
||||
baseName := filepath.Base(firstFrame)
|
||||
// Find the numeric part and replace it with %04d pattern
|
||||
// Use regex to find digits (including negative) after underscore and before extension
|
||||
re := regexp.MustCompile(`_(-?\d+)\.`)
|
||||
// Use regex to find digits (positive only, negative frames not supported) after underscore and before extension
|
||||
re := regexp.MustCompile(`_(\d+)\.`)
|
||||
var pattern string
|
||||
var startNumber int
|
||||
frameNumStr := re.FindStringSubmatch(baseName)
|
||||
@@ -1763,6 +1608,7 @@ func (c *Client) processVideoGenerationTask(task map[string]interface{}, jobID i
|
||||
startNumber = extractFrameNumber(baseName)
|
||||
pattern = strings.Replace(baseName, fmt.Sprintf("%d", startNumber), "%04d", 1)
|
||||
}
|
||||
// Pattern path should be in workDir where the frame files are downloaded
|
||||
patternPath := filepath.Join(workDir, pattern)
|
||||
|
||||
// Allocate a VAAPI device for this task (if available)
|
||||
@@ -1891,8 +1737,8 @@ func (c *Client) processVideoGenerationTask(task map[string]interface{}, jobID i
|
||||
}
|
||||
|
||||
// Register process for cleanup on shutdown
|
||||
c.runningProcs.Store(taskID, cmd)
|
||||
defer c.runningProcs.Delete(taskID)
|
||||
c.processTracker.Track(taskID, cmd)
|
||||
defer c.processTracker.Untrack(taskID)
|
||||
|
||||
// Stream stdout line by line
|
||||
stdoutDone := make(chan bool)
|
||||
@@ -1959,26 +1805,25 @@ func (c *Client) processVideoGenerationTask(task map[string]interface{}, jobID i
|
||||
<-stderrDone
|
||||
|
||||
if err != nil {
|
||||
var errMsg string
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
if exitErr.ExitCode() == 137 {
|
||||
errMsg = "FFmpeg was killed due to excessive memory usage (OOM)"
|
||||
} else {
|
||||
errMsg = fmt.Sprintf("ffmpeg encoding failed: %v", err)
|
||||
}
|
||||
} else {
|
||||
errMsg = fmt.Sprintf("ffmpeg encoding failed: %v", err)
|
||||
}
|
||||
// Check for size-related errors and provide helpful messages
|
||||
if sizeErr := c.checkFFmpegSizeError("ffmpeg encoding failed"); sizeErr != nil {
|
||||
if sizeErr := c.checkFFmpegSizeError(errMsg); sizeErr != nil {
|
||||
c.sendLog(taskID, types.LogLevelError, sizeErr.Error(), "generate_video")
|
||||
c.sendStepUpdate(taskID, "generate_video", types.StepStatusFailed, sizeErr.Error())
|
||||
return sizeErr
|
||||
}
|
||||
|
||||
// Try alternative method with concat demuxer
|
||||
c.sendLog(taskID, types.LogLevelWarn, "Primary ffmpeg encoding failed, trying concat method...", "generate_video")
|
||||
err = c.generateMP4WithConcat(frameFiles, outputMP4, workDir, allocatedDevice, outputFormat, codec, pixFmt, useAlpha, useHardware, frameRate)
|
||||
if err != nil {
|
||||
// Check for size errors in concat method too
|
||||
if sizeErr := c.checkFFmpegSizeError(err.Error()); sizeErr != nil {
|
||||
c.sendLog(taskID, types.LogLevelError, sizeErr.Error(), "generate_video")
|
||||
c.sendStepUpdate(taskID, "generate_video", types.StepStatusFailed, sizeErr.Error())
|
||||
return sizeErr
|
||||
}
|
||||
c.sendStepUpdate(taskID, "generate_video", types.StepStatusFailed, err.Error())
|
||||
return err
|
||||
}
|
||||
c.sendLog(taskID, types.LogLevelError, errMsg, "generate_video")
|
||||
c.sendStepUpdate(taskID, "generate_video", types.StepStatusFailed, errMsg)
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
|
||||
// Check if MP4 was created
|
||||
@@ -2771,7 +2616,7 @@ func (c *Client) testGenericEncoder(encoder string) bool {
|
||||
|
||||
// generateMP4WithConcat uses ffmpeg concat demuxer as fallback
|
||||
// device parameter is optional - if provided, it will be used for VAAPI encoding
|
||||
func (c *Client) generateMP4WithConcat(frameFiles []string, outputMP4, workDir string, device string, outputFormat string, codec string, pixFmt string, useAlpha bool, useHardware bool, frameRate float64) error {
|
||||
func (c *Client) generateMP4WithConcat(taskID int, frameFiles []string, outputMP4, workDir string, device string, outputFormat string, codec string, pixFmt string, useAlpha bool, useHardware bool, frameRate float64) error {
|
||||
// Create file list for ffmpeg concat demuxer
|
||||
listFile := filepath.Join(workDir, "frames.txt")
|
||||
listFileHandle, err := os.Create(listFile)
|
||||
@@ -2907,11 +2752,23 @@ func (c *Client) generateMP4WithConcat(frameFiles []string, outputMP4, workDir s
|
||||
<-stderrDone
|
||||
|
||||
if err != nil {
|
||||
var errMsg string
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
if exitErr.ExitCode() == 137 {
|
||||
errMsg = "FFmpeg was killed due to excessive memory usage (OOM)"
|
||||
} else {
|
||||
errMsg = fmt.Sprintf("ffmpeg concat failed: %v", err)
|
||||
}
|
||||
} else {
|
||||
errMsg = fmt.Sprintf("ffmpeg concat failed: %v", err)
|
||||
}
|
||||
// Check for size-related errors
|
||||
if sizeErr := c.checkFFmpegSizeError("ffmpeg concat failed"); sizeErr != nil {
|
||||
if sizeErr := c.checkFFmpegSizeError(errMsg); sizeErr != nil {
|
||||
return sizeErr
|
||||
}
|
||||
return fmt.Errorf("ffmpeg concat failed: %w", err)
|
||||
c.sendLog(int64(taskID), types.LogLevelError, errMsg, "generate_video")
|
||||
c.sendStepUpdate(int64(taskID), "generate_video", types.StepStatusFailed, errMsg)
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(outputMP4); os.IsNotExist(err) {
|
||||
@@ -3695,8 +3552,8 @@ sys.stdout.flush()
|
||||
}
|
||||
|
||||
// Register process for cleanup on shutdown
|
||||
c.runningProcs.Store(taskID, cmd)
|
||||
defer c.runningProcs.Delete(taskID)
|
||||
c.processTracker.Track(taskID, cmd)
|
||||
defer c.processTracker.Untrack(taskID)
|
||||
|
||||
// Stream stdout line by line and collect for JSON parsing
|
||||
stdoutDone := make(chan bool)
|
||||
@@ -3743,7 +3600,16 @@ sys.stdout.flush()
|
||||
<-stdoutDone
|
||||
<-stderrDone
|
||||
if err != nil {
|
||||
errMsg := fmt.Sprintf("blender metadata extraction failed: %v", err)
|
||||
var errMsg string
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
if exitErr.ExitCode() == 137 {
|
||||
errMsg = "Blender metadata extraction was killed due to excessive memory usage (OOM)"
|
||||
} else {
|
||||
errMsg = fmt.Sprintf("blender metadata extraction failed: %v", err)
|
||||
}
|
||||
} else {
|
||||
errMsg = fmt.Sprintf("blender metadata extraction failed: %v", err)
|
||||
}
|
||||
c.sendLog(taskID, types.LogLevelError, errMsg, "extract_metadata")
|
||||
c.sendStepUpdate(taskID, "extract_metadata", types.StepStatusFailed, errMsg)
|
||||
return errors.New(errMsg)
|
||||
|
||||
@@ -378,43 +378,47 @@ func (s *Storage) CreateJobContext(jobID int64) (string, error) {
|
||||
return "", fmt.Errorf("failed to open file %s: %w", filePath, err)
|
||||
}
|
||||
|
||||
info, err := file.Stat()
|
||||
// Use a function closure to ensure file is closed even on error
|
||||
err = func() error {
|
||||
defer file.Close()
|
||||
|
||||
info, err := file.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to stat file %s: %w", filePath, err)
|
||||
}
|
||||
|
||||
// Get relative path for tar header
|
||||
relPath, err := filepath.Rel(jobPath, filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get relative path for %s: %w", filePath, err)
|
||||
}
|
||||
|
||||
// Normalize path separators for tar (use forward slashes)
|
||||
tarPath := filepath.ToSlash(relPath)
|
||||
|
||||
// Create tar header
|
||||
header, err := tar.FileInfoHeader(info, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create tar header for %s: %w", filePath, err)
|
||||
}
|
||||
header.Name = tarPath
|
||||
|
||||
// Write header
|
||||
if err := tarWriter.WriteHeader(header); err != nil {
|
||||
return fmt.Errorf("failed to write tar header for %s: %w", filePath, err)
|
||||
}
|
||||
|
||||
// Copy file contents using streaming
|
||||
if _, err := io.Copy(tarWriter, file); err != nil {
|
||||
return fmt.Errorf("failed to write file %s to tar: %w", filePath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
file.Close()
|
||||
return "", fmt.Errorf("failed to stat file %s: %w", filePath, err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Get relative path for tar header
|
||||
relPath, err := filepath.Rel(jobPath, filePath)
|
||||
if err != nil {
|
||||
file.Close()
|
||||
return "", fmt.Errorf("failed to get relative path for %s: %w", filePath, err)
|
||||
}
|
||||
|
||||
// Normalize path separators for tar (use forward slashes)
|
||||
tarPath := filepath.ToSlash(relPath)
|
||||
|
||||
// Create tar header
|
||||
header, err := tar.FileInfoHeader(info, "")
|
||||
if err != nil {
|
||||
file.Close()
|
||||
return "", fmt.Errorf("failed to create tar header for %s: %w", filePath, err)
|
||||
}
|
||||
header.Name = tarPath
|
||||
|
||||
// Write header
|
||||
if err := tarWriter.WriteHeader(header); err != nil {
|
||||
file.Close()
|
||||
return "", fmt.Errorf("failed to write tar header for %s: %w", filePath, err)
|
||||
}
|
||||
|
||||
// Copy file contents using streaming
|
||||
if _, err := io.Copy(tarWriter, file); err != nil {
|
||||
file.Close()
|
||||
return "", fmt.Errorf("failed to write file %s to tar: %w", filePath, err)
|
||||
}
|
||||
|
||||
file.Close()
|
||||
}
|
||||
|
||||
// Ensure all data is flushed
|
||||
@@ -550,42 +554,47 @@ func (s *Storage) CreateJobContextFromDir(sourceDir string, jobID int64, exclude
|
||||
return "", fmt.Errorf("failed to open file %s: %w", filePath, err)
|
||||
}
|
||||
|
||||
info, err := file.Stat()
|
||||
// Use a function closure to ensure file is closed even on error
|
||||
err = func() error {
|
||||
defer file.Close()
|
||||
|
||||
info, err := file.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to stat file %s: %w", filePath, err)
|
||||
}
|
||||
|
||||
// Get relative path and strip common prefix if present
|
||||
relPath := relPaths[i]
|
||||
tarPath := filepath.ToSlash(relPath)
|
||||
|
||||
// Strip common prefix if found
|
||||
if commonPrefix != "" && strings.HasPrefix(tarPath, commonPrefix) {
|
||||
tarPath = strings.TrimPrefix(tarPath, commonPrefix)
|
||||
}
|
||||
|
||||
// Create tar header
|
||||
header, err := tar.FileInfoHeader(info, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create tar header for %s: %w", filePath, err)
|
||||
}
|
||||
header.Name = tarPath
|
||||
|
||||
// Write header
|
||||
if err := tarWriter.WriteHeader(header); err != nil {
|
||||
return fmt.Errorf("failed to write tar header for %s: %w", filePath, err)
|
||||
}
|
||||
|
||||
// Copy file contents using streaming
|
||||
if _, err := io.Copy(tarWriter, file); err != nil {
|
||||
return fmt.Errorf("failed to write file %s to tar: %w", filePath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
file.Close()
|
||||
return "", fmt.Errorf("failed to stat file %s: %w", filePath, err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Get relative path and strip common prefix if present
|
||||
relPath := relPaths[i]
|
||||
tarPath := filepath.ToSlash(relPath)
|
||||
|
||||
// Strip common prefix if found
|
||||
if commonPrefix != "" && strings.HasPrefix(tarPath, commonPrefix) {
|
||||
tarPath = strings.TrimPrefix(tarPath, commonPrefix)
|
||||
}
|
||||
|
||||
// Create tar header
|
||||
header, err := tar.FileInfoHeader(info, "")
|
||||
if err != nil {
|
||||
file.Close()
|
||||
return "", fmt.Errorf("failed to create tar header for %s: %w", filePath, err)
|
||||
}
|
||||
header.Name = tarPath
|
||||
|
||||
// Write header
|
||||
if err := tarWriter.WriteHeader(header); err != nil {
|
||||
file.Close()
|
||||
return "", fmt.Errorf("failed to write tar header for %s: %w", filePath, err)
|
||||
}
|
||||
|
||||
// Copy file contents using streaming
|
||||
if _, err := io.Copy(tarWriter, file); err != nil {
|
||||
file.Close()
|
||||
return "", fmt.Errorf("failed to write file %s to tar: %w", filePath, err)
|
||||
}
|
||||
|
||||
file.Close()
|
||||
}
|
||||
|
||||
// Ensure all data is flushed
|
||||
|
||||
Reference in New Issue
Block a user