jiggablend/internal/manager/manager.go
Commit b51b96a618 (Justin Harms, 2026-01-02): Refactor job status handling to prevent race conditions
- Removed redundant error handling in handleListJobTasks.
- Introduced per-job mutexes in Manager to serialize updateJobStatusFromTasks calls, ensuring thread safety during concurrent task completions.
- Added methods to manage job status update mutexes, including creation and cleanup after job completion or failure.
- Improved error handling in handleGetJobStatusForRunner by consolidating error checks.

package api
import (
"compress/gzip"
"database/sql"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"time"
authpkg "jiggablend/internal/auth"
"jiggablend/internal/config"
"jiggablend/internal/database"
"jiggablend/internal/storage"
"jiggablend/pkg/types"
"jiggablend/web"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/cors"
"github.com/gorilla/websocket"
)
// Configuration constants
const (
// WebSocket timeouts
WSReadDeadline = 90 * time.Second
WSPingInterval = 30 * time.Second
WSWriteDeadline = 10 * time.Second
// Task timeouts
RenderTimeout = 60 * 60 // 1 hour for frame rendering
VideoEncodeTimeout = 60 * 60 * 24 // 24 hours for encoding
// Limits
MaxUploadSize = 50 << 30 // 50 GB
RunnerHeartbeatTimeout = 90 * time.Second
TaskDistributionInterval = 10 * time.Second
ProgressUpdateThrottle = 2 * time.Second
// Cookie settings
SessionCookieMaxAge = 86400 // 24 hours
)
// Manager represents the manager server
type Manager struct {
db *database.DB
cfg *config.Config
auth *authpkg.Auth
secrets *authpkg.Secrets
storage *storage.Storage
router *chi.Mux
// WebSocket connections
wsUpgrader websocket.Upgrader
// DEPRECATED: Old frontend WebSocket connection maps (kept for backwards compatibility)
// These will be removed in a future release. Use clientConns instead.
frontendConns map[string]*websocket.Conn // key: "jobId:taskId"
frontendConnsMu sync.RWMutex
frontendConnsWriteMu map[string]*sync.Mutex
frontendConnsWriteMuMu sync.RWMutex
jobListConns map[int64]*websocket.Conn
jobListConnsMu sync.RWMutex
jobConns map[string]*websocket.Conn
jobConnsMu sync.RWMutex
jobConnsWriteMu map[string]*sync.Mutex
jobConnsWriteMuMu sync.RWMutex
// Per-job runner WebSocket connections (polling-based flow)
// Key is "job-{jobId}-task-{taskId}"
runnerJobConns map[string]*websocket.Conn
runnerJobConnsMu sync.RWMutex
runnerJobConnsWriteMu map[string]*sync.Mutex
runnerJobConnsWriteMuMu sync.RWMutex
// Throttling for progress updates (per job)
progressUpdateTimes map[int64]time.Time // key: jobID
progressUpdateTimesMu sync.RWMutex
// Throttling for task status updates (per task)
taskUpdateTimes map[int64]time.Time // key: taskID
taskUpdateTimesMu sync.RWMutex
// Per-job mutexes to serialize updateJobStatusFromTasks calls and prevent race conditions
jobStatusUpdateMu map[int64]*sync.Mutex // key: jobID
jobStatusUpdateMuMu sync.RWMutex
// Client WebSocket connections (new unified WebSocket)
// Key is "userID:connID" to support multiple tabs per user
clientConns map[string]*ClientConnection
clientConnsMu sync.RWMutex
connIDCounter uint64 // Atomic counter for generating unique connection IDs
// Upload session tracking
uploadSessions map[string]*UploadSession // sessionId -> session info
uploadSessionsMu sync.RWMutex
// Verbose WebSocket logging (set to true to enable detailed WebSocket logs)
verboseWSLogging bool
// Server start time for health checks
startTime time.Time
}
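// Illustrative sketch (hypothetical names, not the accessor methods the commit
// above refers to): the per-job mutexes in jobStatusUpdateMu are meant to be
// created lazily on first use and dropped once a job reaches a terminal state,
// so that concurrent task completions for the same job serialize their
// updateJobStatusFromTasks calls without the map growing without bound.
func (s *Manager) exampleJobStatusMutex(jobID int64) *sync.Mutex {
    // Fast path: a mutex already exists for this job.
    s.jobStatusUpdateMuMu.RLock()
    mu, ok := s.jobStatusUpdateMu[jobID]
    s.jobStatusUpdateMuMu.RUnlock()
    if ok {
        return mu
    }
    // Slow path: create it under the write lock, re-checking in case another
    // goroutine created it in the meantime.
    s.jobStatusUpdateMuMu.Lock()
    defer s.jobStatusUpdateMuMu.Unlock()
    if mu, ok := s.jobStatusUpdateMu[jobID]; ok {
        return mu
    }
    mu = &sync.Mutex{}
    s.jobStatusUpdateMu[jobID] = mu
    return mu
}
// exampleReleaseJobStatusMutex drops the per-job mutex after the job completes
// or fails (illustrative only; the real cleanup method lives elsewhere).
func (s *Manager) exampleReleaseJobStatusMutex(jobID int64) {
    s.jobStatusUpdateMuMu.Lock()
    defer s.jobStatusUpdateMuMu.Unlock()
    delete(s.jobStatusUpdateMu, jobID)
}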
// ClientConnection represents a client WebSocket connection with subscriptions
type ClientConnection struct {
Conn *websocket.Conn
UserID int64
ConnID string // Unique connection ID (userID:connID)
IsAdmin bool
Subscriptions map[string]bool // channel -> subscribed
SubsMu sync.RWMutex // Protects Subscriptions map
WriteMu *sync.Mutex
}
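// Illustrative sketch (hypothetical helper, not part of the broadcast code in
// this file): gorilla/websocket connections do not support concurrent writers,
// so every send must hold WriteMu, and Subscriptions is consulted under SubsMu
// to decide whether this connection is interested in the channel at all.
func (c *ClientConnection) exampleSendIfSubscribed(channel string, payload interface{}) error {
    c.SubsMu.RLock()
    subscribed := c.Subscriptions[channel]
    c.SubsMu.RUnlock()
    if !subscribed {
        return nil
    }
    c.WriteMu.Lock()
    defer c.WriteMu.Unlock()
    _ = c.Conn.SetWriteDeadline(time.Now().Add(WSWriteDeadline))
    return c.Conn.WriteJSON(map[string]interface{}{
        "channel": channel,
        "data":    payload,
    })
}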
// UploadSession tracks upload and processing progress
type UploadSession struct {
SessionID string
UserID int64
Progress float64
Status string // "uploading", "processing", "extracting_metadata", "creating_context", "completed", "error"
Message string
CreatedAt time.Time
}
// NewManager creates a new manager server
func NewManager(db *database.DB, cfg *config.Config, auth *authpkg.Auth, storage *storage.Storage) (*Manager, error) {
secrets, err := authpkg.NewSecrets(db, cfg)
if err != nil {
return nil, fmt.Errorf("failed to initialize secrets: %w", err)
}
s := &Manager{
db: db,
cfg: cfg,
auth: auth,
secrets: secrets,
storage: storage,
router: chi.NewRouter(),
startTime: time.Now(),
wsUpgrader: websocket.Upgrader{
CheckOrigin: checkWebSocketOrigin,
ReadBufferSize: 1024,
WriteBufferSize: 1024,
},
// DEPRECATED: Initialize old frontend WebSocket maps for backward compatibility
frontendConns: make(map[string]*websocket.Conn),
frontendConnsWriteMu: make(map[string]*sync.Mutex),
jobListConns: make(map[int64]*websocket.Conn),
jobConns: make(map[string]*websocket.Conn),
jobConnsWriteMu: make(map[string]*sync.Mutex),
progressUpdateTimes: make(map[int64]time.Time),
taskUpdateTimes: make(map[int64]time.Time),
clientConns: make(map[string]*ClientConnection),
uploadSessions: make(map[string]*UploadSession),
// Per-job runner WebSocket connections
runnerJobConns: make(map[string]*websocket.Conn),
runnerJobConnsWriteMu: make(map[string]*sync.Mutex),
// Per-job mutexes for serializing status updates
jobStatusUpdateMu: make(map[int64]*sync.Mutex),
}
// Check for required external tools
if err := s.checkRequiredTools(); err != nil {
return nil, err
}
s.setupMiddleware()
s.setupRoutes()
s.StartBackgroundTasks()
// On startup, check for runners that are marked online but not actually connected
// This handles the case where the manager restarted and lost track of connections
go s.recoverRunnersOnStartup()
return s, nil
}
// checkRequiredTools verifies that required external tools are available
func (s *Manager) checkRequiredTools() error {
// Check for zstd (required for zstd-compressed blend files)
if err := exec.Command("zstd", "--version").Run(); err != nil {
return fmt.Errorf("zstd not found - required for compressed blend file support. Install with: apt install zstd")
}
log.Printf("Found zstd for compressed blend file support")
// Check for xz (required for decompressing blender archives)
if err := exec.Command("xz", "--version").Run(); err != nil {
return fmt.Errorf("xz not found - required for decompressing blender archives. Install with: apt install xz-utils")
}
log.Printf("Found xz for blender archive decompression")
return nil
}
// checkWebSocketOrigin validates WebSocket connection origins
// In production mode, only allows same-origin connections or configured allowed origins
func checkWebSocketOrigin(r *http.Request) bool {
origin := r.Header.Get("Origin")
if origin == "" {
// No origin header - allow (could be non-browser client like runner)
return true
}
// In development mode, allow all origins
// Note: This function doesn't have access to the Manager, so it uses authpkg.IsProductionMode(),
// which checks the environment variable. The Manager's own setup uses s.cfg.IsProductionMode() for consistency.
if !authpkg.IsProductionMode() {
return true
}
// In production, check against allowed origins
allowedOrigins := os.Getenv("ALLOWED_ORIGINS")
if allowedOrigins == "" {
// Default to same-origin only
host := r.Host
return strings.HasSuffix(origin, "://"+host) || strings.HasSuffix(origin, "://"+strings.Split(host, ":")[0])
}
// Check against configured allowed origins
for _, allowed := range strings.Split(allowedOrigins, ",") {
allowed = strings.TrimSpace(allowed)
if allowed == "*" {
return true
}
if origin == allowed {
return true
}
}
log.Printf("WebSocket origin rejected: %s (allowed: %s)", origin, allowedOrigins)
return false
}
// RateLimiter provides simple in-memory rate limiting per IP
type RateLimiter struct {
requests map[string][]time.Time
mu sync.RWMutex
limit int // max requests
window time.Duration // time window
}
// NewRateLimiter creates a new rate limiter
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
rl := &RateLimiter{
requests: make(map[string][]time.Time),
limit: limit,
window: window,
}
// Start cleanup goroutine
go rl.cleanup()
return rl
}
// Allow checks if a request from the given IP is allowed
func (rl *RateLimiter) Allow(ip string) bool {
rl.mu.Lock()
defer rl.mu.Unlock()
now := time.Now()
cutoff := now.Add(-rl.window)
// Get existing requests and filter old ones
reqs := rl.requests[ip]
validReqs := make([]time.Time, 0, len(reqs))
for _, t := range reqs {
if t.After(cutoff) {
validReqs = append(validReqs, t)
}
}
// Check if under limit
if len(validReqs) >= rl.limit {
rl.requests[ip] = validReqs
return false
}
// Add this request
validReqs = append(validReqs, now)
rl.requests[ip] = validReqs
return true
}
// cleanup periodically removes old entries
func (rl *RateLimiter) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
for range ticker.C {
rl.mu.Lock()
cutoff := time.Now().Add(-rl.window)
for ip, reqs := range rl.requests {
validReqs := make([]time.Time, 0, len(reqs))
for _, t := range reqs {
if t.After(cutoff) {
validReqs = append(validReqs, t)
}
}
if len(validReqs) == 0 {
delete(rl.requests, ip)
} else {
rl.requests[ip] = validReqs
}
}
rl.mu.Unlock()
}
}
// Global rate limiters for different endpoint types
var (
// General API rate limiter: 100 requests per minute per IP
apiRateLimiter = NewRateLimiter(100, time.Minute)
// Auth rate limiter: 10 requests per minute per IP (stricter for login attempts)
authRateLimiter = NewRateLimiter(10, time.Minute)
)
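// exampleRateLimiterUsage illustrates the sliding-window semantics of the
// limiters above: a burst of up to `limit` requests from one IP is allowed,
// and further requests are rejected until the oldest timestamps age out of
// the window. (Illustrative only; rateLimitMiddleware below is what the
// server actually wires in.)
func exampleRateLimiterUsage() {
    rl := NewRateLimiter(2, time.Minute)
    log.Println(rl.Allow("10.0.0.1")) // true
    log.Println(rl.Allow("10.0.0.1")) // true
    log.Println(rl.Allow("10.0.0.1")) // false: third request inside the window
}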
// rateLimitMiddleware applies rate limiting based on client IP
func rateLimitMiddleware(limiter *RateLimiter) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get client IP (handle proxied requests)
ip := r.RemoteAddr
if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
// Take the first IP in the chain
if idx := strings.Index(forwarded, ","); idx != -1 {
ip = strings.TrimSpace(forwarded[:idx])
} else {
ip = strings.TrimSpace(forwarded)
}
} else if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
ip = strings.TrimSpace(realIP)
}
if !limiter.Allow(ip) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Retry-After", "60")
w.WriteHeader(http.StatusTooManyRequests)
json.NewEncoder(w).Encode(map[string]string{
"error": "Rate limit exceeded. Please try again later.",
})
return
}
next.ServeHTTP(w, r)
})
}
}
// setupMiddleware configures middleware
func (s *Manager) setupMiddleware() {
s.router.Use(middleware.Logger)
s.router.Use(middleware.Recoverer)
// Note: Timeout middleware is NOT applied globally to avoid conflicts with WebSocket connections
// WebSocket connections are long-lived and should not have HTTP timeouts
// Check production mode from config
isProduction := s.cfg.IsProductionMode()
// Add rate limiting (applied in production mode only, or when explicitly enabled)
if isProduction || os.Getenv("ENABLE_RATE_LIMITING") == "true" {
s.router.Use(rateLimitMiddleware(apiRateLimiter))
log.Printf("Rate limiting enabled: 100 requests/minute per IP")
}
// Add gzip compression for JSON responses
s.router.Use(gzipMiddleware)
// Configure CORS based on environment
corsOptions := cors.Options{
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"},
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "Range", "If-None-Match"},
ExposedHeaders: []string{"Link", "Content-Range", "Accept-Ranges", "Content-Length", "ETag"},
AllowCredentials: true,
MaxAge: 300,
}
// In production, restrict CORS origins
if isProduction {
allowedOrigins := s.cfg.AllowedOrigins()
if allowedOrigins != "" {
corsOptions.AllowedOrigins = strings.Split(allowedOrigins, ",")
for i := range corsOptions.AllowedOrigins {
corsOptions.AllowedOrigins[i] = strings.TrimSpace(corsOptions.AllowedOrigins[i])
}
} else {
// Default to no origins in production if not configured
// This effectively disables cross-origin requests
corsOptions.AllowedOrigins = []string{}
}
log.Printf("Production mode: CORS restricted to origins: %v", corsOptions.AllowedOrigins)
} else {
// Development mode: allow all origins
corsOptions.AllowedOrigins = []string{"*"}
}
s.router.Use(cors.Handler(corsOptions))
}
// gzipMiddleware compresses responses with gzip if client supports it
func gzipMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip compression for WebSocket upgrades
if strings.ToLower(r.Header.Get("Upgrade")) == "websocket" {
next.ServeHTTP(w, r)
return
}
// Check if client accepts gzip
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
next.ServeHTTP(w, r)
return
}
// Create gzip writer
gz := gzip.NewWriter(w)
defer gz.Close()
w.Header().Set("Content-Encoding", "gzip")
w.Header().Set("Vary", "Accept-Encoding")
// Wrap response writer
gzw := &gzipResponseWriter{Writer: gz, ResponseWriter: w}
next.ServeHTTP(gzw, r)
})
}
// gzipResponseWriter wraps http.ResponseWriter to add gzip compression
type gzipResponseWriter struct {
io.Writer
http.ResponseWriter
}
func (w *gzipResponseWriter) Write(b []byte) (int, error) {
return w.Writer.Write(b)
}
func (w *gzipResponseWriter) WriteHeader(statusCode int) {
// Content-Length is intentionally left unset: the compressed size isn't known up front,
// so the response falls back to chunked transfer encoding
w.ResponseWriter.WriteHeader(statusCode)
}
// setupRoutes configures routes
func (s *Manager) setupRoutes() {
// Health check endpoint (unauthenticated)
s.router.Get("/api/health", s.handleHealthCheck)
// Public routes (with stricter rate limiting for auth endpoints)
s.router.Route("/api/auth", func(r chi.Router) {
// Apply stricter rate limiting to auth endpoints in production
if s.cfg.IsProductionMode() || os.Getenv("ENABLE_RATE_LIMITING") == "true" {
r.Use(rateLimitMiddleware(authRateLimiter))
}
r.Get("/providers", s.handleGetAuthProviders)
r.Get("/google/login", s.handleGoogleLogin)
r.Get("/google/callback", s.handleGoogleCallback)
r.Get("/discord/login", s.handleDiscordLogin)
r.Get("/discord/callback", s.handleDiscordCallback)
r.Get("/local/available", s.handleLocalLoginAvailable)
r.Post("/local/register", s.handleLocalRegister)
r.Post("/local/login", s.handleLocalLogin)
r.Post("/logout", s.handleLogout)
r.Get("/me", s.handleGetMe)
r.Post("/change-password", s.handleChangePassword)
})
// Protected routes
s.router.Route("/api/jobs", func(r chi.Router) {
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(s.auth.Middleware(next.ServeHTTP))
})
r.Post("/", s.handleCreateJob)
r.Post("/upload", s.handleUploadFileForJobCreation) // Upload before job creation
r.Get("/", s.handleListJobs)
r.Get("/summary", s.handleListJobsSummary)
r.Post("/batch", s.handleBatchGetJobs)
r.Get("/{id}", s.handleGetJob)
r.Delete("/{id}", s.handleCancelJob)
r.Post("/{id}/delete", s.handleDeleteJob)
r.Post("/{id}/upload", s.handleUploadJobFile)
r.Get("/{id}/files", s.handleListJobFiles)
r.Get("/{id}/files/count", s.handleGetJobFilesCount)
r.Get("/{id}/context", s.handleListContextArchive)
r.Get("/{id}/files/{fileId}/download", s.handleDownloadJobFile)
r.Get("/{id}/files/{fileId}/preview-exr", s.handlePreviewEXR)
r.Get("/{id}/video", s.handleStreamVideo)
r.Get("/{id}/metadata", s.handleGetJobMetadata)
r.Get("/{id}/tasks", s.handleListJobTasks)
r.Get("/{id}/tasks/summary", s.handleListJobTasksSummary)
r.Post("/{id}/tasks/batch", s.handleBatchGetTasks)
r.Get("/{id}/tasks/{taskId}/logs", s.handleGetTaskLogs)
r.Get("/{id}/tasks/{taskId}/steps", s.handleGetTaskSteps)
r.Post("/{id}/tasks/{taskId}/retry", s.handleRetryTask)
// WebSocket route for unified client WebSocket
r.With(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Apply authentication middleware first
s.auth.Middleware(func(w http.ResponseWriter, r *http.Request) {
// No timeout middleware is applied here so the WebSocket connection can remain open
next.ServeHTTP(w, r)
})(w, r)
})
}).Get("/ws", s.handleClientWebSocket)
})
// Admin routes
s.router.Route("/api/admin", func(r chi.Router) {
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(s.auth.AdminMiddleware(next.ServeHTTP))
})
r.Route("/runners", func(r chi.Router) {
r.Route("/api-keys", func(r chi.Router) {
r.Post("/", s.handleGenerateRunnerAPIKey)
r.Get("/", s.handleListRunnerAPIKeys)
r.Patch("/{id}/revoke", s.handleRevokeRunnerAPIKey)
r.Delete("/{id}", s.handleDeleteRunnerAPIKey)
})
r.Get("/", s.handleListRunnersAdmin)
r.Post("/{id}/verify", s.handleVerifyRunner)
r.Delete("/{id}", s.handleDeleteRunner)
})
r.Route("/users", func(r chi.Router) {
r.Get("/", s.handleListUsers)
r.Get("/{id}/jobs", s.handleGetUserJobs)
r.Post("/{id}/admin", s.handleSetUserAdminStatus)
})
r.Route("/settings", func(r chi.Router) {
r.Get("/registration", s.handleGetRegistrationEnabled)
r.Post("/registration", s.handleSetRegistrationEnabled)
})
})
// Runner API
s.router.Route("/api/runner", func(r chi.Router) {
// Registration doesn't require auth (uses token)
r.With(middleware.Timeout(60*time.Second)).Post("/register", s.handleRegisterRunner)
// Polling-based endpoints (auth handled in handlers)
r.Get("/workers/{id}/next-job", s.handleNextJob)
// Per-job endpoints with job_token auth (no middleware, auth in handler)
r.Get("/jobs/{jobId}/ws", s.handleRunnerJobWebSocket)
r.Get("/jobs/{jobId}/context.tar", s.handleDownloadJobContextWithToken)
r.Post("/jobs/{jobId}/upload", s.handleUploadFileWithToken)
// Runner API endpoints (uses API key auth)
r.Group(func(r chi.Router) {
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(s.runnerAuthMiddleware(next.ServeHTTP))
})
r.Get("/blender/download", s.handleDownloadBlender)
r.Get("/jobs/{jobId}/files", s.handleGetJobFilesForRunner)
r.Get("/jobs/{jobId}/metadata", s.handleGetJobMetadataForRunner)
r.Get("/files/{jobId}/{fileName}", s.handleDownloadFileForRunner)
})
})
// Blender versions API (public, for job submission page)
s.router.Get("/api/blender/versions", s.handleGetBlenderVersions)
// Serve static files (embedded React app with SPA fallback)
s.router.Handle("/*", web.SPAHandler())
}
// ServeHTTP implements http.Handler
func (s *Manager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.router.ServeHTTP(w, r)
}
// JSON response helpers
func (s *Manager) respondJSON(w http.ResponseWriter, status int, data interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
if err := json.NewEncoder(w).Encode(data); err != nil {
log.Printf("Failed to encode JSON response: %v", err)
}
}
func (s *Manager) respondError(w http.ResponseWriter, status int, message string) {
s.respondJSON(w, status, map[string]string{"error": message})
}
// createSessionCookie creates a secure session cookie with appropriate flags for the environment
func createSessionCookie(sessionID string) *http.Cookie {
cookie := &http.Cookie{
Name: "session_id",
Value: sessionID,
Path: "/",
MaxAge: SessionCookieMaxAge,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
// In production mode, set Secure flag to require HTTPS
if authpkg.IsProductionMode() {
cookie.Secure = true
}
return cookie
}
// handleHealthCheck returns server health status
func (s *Manager) handleHealthCheck(w http.ResponseWriter, r *http.Request) {
// Check database connectivity
dbHealthy := true
if err := s.db.Ping(); err != nil {
dbHealthy = false
log.Printf("Health check: database ping failed: %v", err)
}
// Count online runners (based on recent heartbeat)
var runnerCount int
s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
`SELECT COUNT(*) FROM runners WHERE status = ?`,
types.RunnerStatusOnline,
).Scan(&runnerCount)
})
// Count connected clients
s.clientConnsMu.RLock()
clientCount := len(s.clientConns)
s.clientConnsMu.RUnlock()
// Calculate uptime
uptime := time.Since(s.startTime)
// Get memory stats
var memStats runtime.MemStats
runtime.ReadMemStats(&memStats)
status := "healthy"
statusCode := http.StatusOK
if !dbHealthy {
status = "degraded"
statusCode = http.StatusServiceUnavailable
}
response := map[string]interface{}{
"status": status,
"uptime_seconds": int64(uptime.Seconds()),
"database": dbHealthy,
"connected_runners": runnerCount,
"connected_clients": clientCount,
"memory": map[string]interface{}{
"alloc_mb": memStats.Alloc / 1024 / 1024,
"total_alloc_mb": memStats.TotalAlloc / 1024 / 1024,
"sys_mb": memStats.Sys / 1024 / 1024,
"num_gc": memStats.NumGC,
},
"timestamp": time.Now().Unix(),
}
s.respondJSON(w, statusCode, response)
}
// Auth handlers
func (s *Manager) handleGoogleLogin(w http.ResponseWriter, r *http.Request) {
url, err := s.auth.GoogleLoginURL()
if err != nil {
s.respondError(w, http.StatusInternalServerError, err.Error())
return
}
http.Redirect(w, r, url, http.StatusFound)
}
func (s *Manager) handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
if code == "" {
s.respondError(w, http.StatusBadRequest, "Missing code parameter")
return
}
session, err := s.auth.GoogleCallback(r.Context(), code)
if err != nil {
// If registration is disabled, redirect back to login with error
if err.Error() == "registration is disabled" {
http.Redirect(w, r, "/?error=registration_disabled", http.StatusFound)
return
}
s.respondError(w, http.StatusInternalServerError, err.Error())
return
}
sessionID := s.auth.CreateSession(session)
http.SetCookie(w, createSessionCookie(sessionID))
http.Redirect(w, r, "/", http.StatusFound)
}
func (s *Manager) handleDiscordLogin(w http.ResponseWriter, r *http.Request) {
url, err := s.auth.DiscordLoginURL()
if err != nil {
s.respondError(w, http.StatusInternalServerError, err.Error())
return
}
http.Redirect(w, r, url, http.StatusFound)
}
func (s *Manager) handleDiscordCallback(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
if code == "" {
s.respondError(w, http.StatusBadRequest, "Missing code parameter")
return
}
session, err := s.auth.DiscordCallback(r.Context(), code)
if err != nil {
// If registration is disabled, redirect back to login with error
if err.Error() == "registration is disabled" {
http.Redirect(w, r, "/?error=registration_disabled", http.StatusFound)
return
}
s.respondError(w, http.StatusInternalServerError, err.Error())
return
}
sessionID := s.auth.CreateSession(session)
http.SetCookie(w, createSessionCookie(sessionID))
http.Redirect(w, r, "/", http.StatusFound)
}
func (s *Manager) handleLogout(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie("session_id")
if err == nil {
s.auth.DeleteSession(cookie.Value)
}
// Create an expired cookie to clear the session
expiredCookie := &http.Cookie{
Name: "session_id",
Value: "",
Path: "/",
MaxAge: -1,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
// Use s.cfg.IsProductionMode() for consistency with other server methods
if s.cfg.IsProductionMode() {
expiredCookie.Secure = true
}
http.SetCookie(w, expiredCookie)
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Logged out"})
}
func (s *Manager) handleGetMe(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie("session_id")
if err != nil {
log.Printf("Authentication failed: missing session cookie in /auth/me")
s.respondError(w, http.StatusUnauthorized, "Not authenticated")
return
}
session, ok := s.auth.GetSession(cookie.Value)
if !ok {
log.Printf("Authentication failed: invalid session cookie in /auth/me")
s.respondError(w, http.StatusUnauthorized, "Invalid session")
return
}
s.respondJSON(w, http.StatusOK, map[string]interface{}{
"id": session.UserID,
"email": session.Email,
"name": session.Name,
"is_admin": session.IsAdmin,
})
}
func (s *Manager) handleGetAuthProviders(w http.ResponseWriter, r *http.Request) {
s.respondJSON(w, http.StatusOK, map[string]bool{
"google": s.auth.IsGoogleOAuthConfigured(),
"discord": s.auth.IsDiscordOAuthConfigured(),
"local": s.auth.IsLocalLoginEnabled(),
})
}
func (s *Manager) handleLocalLoginAvailable(w http.ResponseWriter, r *http.Request) {
s.respondJSON(w, http.StatusOK, map[string]bool{
"available": s.auth.IsLocalLoginEnabled(),
})
}
func (s *Manager) handleLocalRegister(w http.ResponseWriter, r *http.Request) {
var req struct {
Email string `json:"email"`
Name string `json:"name"`
Password string `json:"password"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
return
}
if req.Email == "" || req.Name == "" || req.Password == "" {
s.respondError(w, http.StatusBadRequest, "Email, name, and password are required")
return
}
if len(req.Password) < 8 {
s.respondError(w, http.StatusBadRequest, "Password must be at least 8 characters long")
return
}
session, err := s.auth.RegisterLocalUser(req.Email, req.Name, req.Password)
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
sessionID := s.auth.CreateSession(session)
http.SetCookie(w, createSessionCookie(sessionID))
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
"message": "Registration successful",
"user": map[string]interface{}{
"id": session.UserID,
"email": session.Email,
"name": session.Name,
"is_admin": session.IsAdmin,
},
})
}
func (s *Manager) handleLocalLogin(w http.ResponseWriter, r *http.Request) {
var req struct {
Username string `json:"username"`
Password string `json:"password"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
return
}
if req.Username == "" || req.Password == "" {
s.respondError(w, http.StatusBadRequest, "Username and password are required")
return
}
session, err := s.auth.LocalLogin(req.Username, req.Password)
if err != nil {
log.Printf("Authentication failed: invalid credentials for username '%s'", req.Username)
s.respondError(w, http.StatusUnauthorized, "Invalid credentials")
return
}
sessionID := s.auth.CreateSession(session)
http.SetCookie(w, createSessionCookie(sessionID))
s.respondJSON(w, http.StatusOK, map[string]interface{}{
"message": "Login successful",
"user": map[string]interface{}{
"id": session.UserID,
"email": session.Email,
"name": session.Name,
"is_admin": session.IsAdmin,
},
})
}
func (s *Manager) handleChangePassword(w http.ResponseWriter, r *http.Request) {
userID, err := getUserID(r)
if err != nil {
s.respondError(w, http.StatusUnauthorized, err.Error())
return
}
var req struct {
OldPassword string `json:"old_password"`
NewPassword string `json:"new_password"`
TargetUserID *int64 `json:"target_user_id,omitempty"` // For admin to change other users' passwords
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
return
}
if req.NewPassword == "" {
s.respondError(w, http.StatusBadRequest, "New password is required")
return
}
if len(req.NewPassword) < 8 {
s.respondError(w, http.StatusBadRequest, "Password must be at least 8 characters long")
return
}
isAdmin := authpkg.IsAdmin(r.Context())
// If target_user_id is provided and user is admin, allow changing other user's password
if req.TargetUserID != nil && isAdmin {
if err := s.auth.AdminChangePassword(*req.TargetUserID, req.NewPassword); err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Password changed successfully"})
return
}
// Otherwise, user is changing their own password (requires old password)
if req.OldPassword == "" {
s.respondError(w, http.StatusBadRequest, "Old password is required")
return
}
if err := s.auth.ChangePassword(userID, req.OldPassword, req.NewPassword); err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Password changed successfully"})
}
// Helper to get user ID from context
func getUserID(r *http.Request) (int64, error) {
userID, ok := authpkg.GetUserID(r.Context())
if !ok {
return 0, fmt.Errorf("user ID not found in context")
}
return userID, nil
}
// Helper to parse ID from URL
func parseID(r *http.Request, param string) (int64, error) {
idStr := chi.URLParam(r, param)
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
return 0, fmt.Errorf("invalid ID: %s", idStr)
}
return id, nil
}
// StartBackgroundTasks starts background goroutines for error recovery
func (s *Manager) StartBackgroundTasks() {
go s.recoverStuckTasks()
go s.cleanupOldRenderJobs()
go s.cleanupOldTempDirectories()
go s.cleanupOldOfflineRunners()
go s.cleanupOldUploadSessions()
}
// recoverRunnersOnStartup marks runners as offline on startup
// In the polling model, runners will update their status when they poll for jobs
func (s *Manager) recoverRunnersOnStartup() {
log.Printf("Recovering runners on startup: marking all as offline...")
// Mark all runners as offline - they'll be marked online when they poll
var runnersAffected int64
err := s.db.With(func(conn *sql.DB) error {
result, err := conn.Exec(
`UPDATE runners SET status = ? WHERE status = ?`,
types.RunnerStatusOffline, types.RunnerStatusOnline,
)
if err != nil {
return err
}
runnersAffected, _ = result.RowsAffected()
return nil
})
if err != nil {
log.Printf("Failed to mark runners as offline on startup: %v", err)
return
}
if runnersAffected > 0 {
log.Printf("Marked %d runners as offline on startup", runnersAffected)
}
// Reset any running tasks that were assigned to runners
// They will be picked up by runners when they poll
var tasksAffected int64
err = s.db.With(func(conn *sql.DB) error {
result, err := conn.Exec(
`UPDATE tasks SET runner_id = NULL, status = ?, started_at = NULL
WHERE status = ?`,
types.TaskStatusPending, types.TaskStatusRunning,
)
if err != nil {
return err
}
tasksAffected, _ = result.RowsAffected()
return nil
})
if err != nil {
log.Printf("Failed to reset running tasks on startup: %v", err)
return
}
if tasksAffected > 0 {
log.Printf("Reset %d running tasks to pending on startup", tasksAffected)
}
}
// recoverStuckTasks periodically checks for dead runners and stuck tasks
func (s *Manager) recoverStuckTasks() {
ticker := time.NewTicker(TaskDistributionInterval)
defer ticker.Stop()
for range ticker.C {
func() {
defer func() {
if r := recover(); r != nil {
log.Printf("Panic in recoverStuckTasks: %v", r)
}
}()
// Find dead runners (no heartbeat for configured timeout)
// In polling model, heartbeat is updated when runner polls for jobs
var deadRunnerIDs []int64
cutoffTime := time.Now().Add(-RunnerHeartbeatTimeout)
err := s.db.With(func(conn *sql.DB) error {
rows, err := conn.Query(
`SELECT id FROM runners
WHERE last_heartbeat < ?
AND status = ?`,
cutoffTime, types.RunnerStatusOnline,
)
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var runnerID int64
if err := rows.Scan(&runnerID); err == nil {
deadRunnerIDs = append(deadRunnerIDs, runnerID)
}
}
return nil
})
if err != nil {
log.Printf("Failed to query dead runners: %v", err)
return
}
// Reset tasks assigned to dead runners
for _, runnerID := range deadRunnerIDs {
s.resetRunnerTasks(runnerID)
// Mark runner as offline
s.db.With(func(conn *sql.DB) error {
_, _ = conn.Exec(
`UPDATE runners SET status = ? WHERE id = ?`,
types.RunnerStatusOffline, runnerID,
)
return nil
})
}
// Check for task timeouts
s.recoverTaskTimeouts()
}()
}
}
// recoverTaskTimeouts handles tasks that have exceeded their timeout
// Timeouts are treated as runner failures (not task failures) and retry indefinitely
func (s *Manager) recoverTaskTimeouts() {
// Find tasks running longer than their timeout
var tasks []struct {
taskID int64
jobID int64
runnerID sql.NullInt64
timeoutSeconds sql.NullInt64
startedAt time.Time
}
err := s.db.With(func(conn *sql.DB) error {
rows, err := conn.Query(
`SELECT t.id, t.job_id, t.runner_id, t.timeout_seconds, t.started_at
FROM tasks t
WHERE t.status = ?
AND t.started_at IS NOT NULL
AND (t.completed_at IS NULL OR t.completed_at < datetime('now', '-30 seconds'))
AND (t.timeout_seconds IS NULL OR
(julianday('now') - julianday(t.started_at)) * 86400 > t.timeout_seconds)`,
types.TaskStatusRunning,
)
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var task struct {
taskID int64
jobID int64
runnerID sql.NullInt64
timeoutSeconds sql.NullInt64
startedAt time.Time
}
err := rows.Scan(&task.taskID, &task.jobID, &task.runnerID, &task.timeoutSeconds, &task.startedAt)
if err != nil {
log.Printf("Failed to scan task row in recoverTaskTimeouts: %v", err)
continue
}
tasks = append(tasks, task)
}
return nil
})
if err != nil {
log.Printf("Failed to query timed out tasks: %v", err)
return
}
for _, task := range tasks {
taskID := task.taskID
jobID := task.jobID
timeoutSeconds := task.timeoutSeconds
startedAt := task.startedAt
// Fall back to a 5-minute default when the task has no timeout_seconds set
// (frame tasks and video encodes normally carry their own timeout)
timeout := 300 // 5 minutes default
if timeoutSeconds.Valid {
timeout = int(timeoutSeconds.Int64)
}
// Check if actually timed out
if time.Since(startedAt).Seconds() < float64(timeout) {
continue
}
// Timeouts are runner failures - always reset to pending and increment runner_failure_count
// This does NOT count against retry_count (which is for actual task failures like Blender crashes)
err = s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(`UPDATE tasks SET status = ? WHERE id = ?`, types.TaskStatusPending, taskID)
if err != nil {
return err
}
_, err = conn.Exec(`UPDATE tasks SET runner_id = NULL WHERE id = ?`, taskID)
if err != nil {
return err
}
_, err = conn.Exec(`UPDATE tasks SET current_step = NULL WHERE id = ?`, taskID)
if err != nil {
return err
}
_, err = conn.Exec(`UPDATE tasks SET started_at = NULL WHERE id = ?`, taskID)
if err != nil {
return err
}
_, err = conn.Exec(`UPDATE tasks SET runner_failure_count = runner_failure_count + 1 WHERE id = ?`, taskID)
if err != nil {
return err
}
// Clear steps and logs for fresh retry
_, err = conn.Exec(`DELETE FROM task_steps WHERE task_id = ?`, taskID)
if err != nil {
return err
}
_, err = conn.Exec(`DELETE FROM task_logs WHERE task_id = ?`, taskID)
return err
})
if err == nil {
// Broadcast task reset to clients (includes steps_cleared and logs_cleared flags)
s.broadcastTaskUpdate(jobID, taskID, "task_reset", map[string]interface{}{
"status": types.TaskStatusPending,
"runner_id": nil,
"current_step": nil,
"started_at": nil,
"steps_cleared": true,
"logs_cleared": true,
})
// Update job status
s.updateJobStatusFromTasks(jobID)
log.Printf("Reset timed out task %d: %v", taskID, err)
} else {
log.Printf("Failed to reset timed out task %d: %v", taskID, err)
}
}
}
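// Illustrative sketch (hypothetical helper, not called above): the reset in
// recoverTaskTimeouts issues several statements back to back; doing the same
// work inside one transaction would make the reset atomic, so a crash between
// statements cannot leave a task half-reset. This assumes the schema and
// s.db.With behave as they do elsewhere in this file.
func (s *Manager) exampleResetTimedOutTaskTx(taskID int64) error {
    return s.db.With(func(conn *sql.DB) error {
        tx, err := conn.Begin()
        if err != nil {
            return err
        }
        defer tx.Rollback() // no-op after a successful Commit
        if _, err := tx.Exec(
            `UPDATE tasks
             SET status = ?, runner_id = NULL, current_step = NULL, started_at = NULL,
                 runner_failure_count = runner_failure_count + 1
             WHERE id = ?`,
            types.TaskStatusPending, taskID,
        ); err != nil {
            return err
        }
        if _, err := tx.Exec(`DELETE FROM task_steps WHERE task_id = ?`, taskID); err != nil {
            return err
        }
        if _, err := tx.Exec(`DELETE FROM task_logs WHERE task_id = ?`, taskID); err != nil {
            return err
        }
        return tx.Commit()
    })
}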
// cleanupOldTempDirectories periodically cleans up old temporary directories
func (s *Manager) cleanupOldTempDirectories() {
// Run cleanup every hour
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
// Run once immediately on startup
s.cleanupOldTempDirectoriesOnce()
for range ticker.C {
s.cleanupOldTempDirectoriesOnce()
}
}
// cleanupOldTempDirectoriesOnce removes temp directories older than 1 hour
func (s *Manager) cleanupOldTempDirectoriesOnce() {
defer func() {
if r := recover(); r != nil {
log.Printf("Panic in cleanupOldTempDirectories: %v", r)
}
}()
tempPath := filepath.Join(s.storage.BasePath(), "temp")
// Check if temp directory exists
if _, err := os.Stat(tempPath); os.IsNotExist(err) {
return
}
// Read all entries in temp directory
entries, err := os.ReadDir(tempPath)
if err != nil {
log.Printf("Failed to read temp directory: %v", err)
return
}
now := time.Now()
cleanedCount := 0
// Check upload sessions to avoid deleting active uploads
s.uploadSessionsMu.RLock()
activeSessions := make(map[string]bool)
for sessionID := range s.uploadSessions {
activeSessions[sessionID] = true
}
s.uploadSessionsMu.RUnlock()
for _, entry := range entries {
if !entry.IsDir() {
continue
}
entryPath := filepath.Join(tempPath, entry.Name())
// Skip if this directory belongs to an active upload session (sessions are keyed
// by session ID, which is expected to match the temp directory name)
if activeSessions[entry.Name()] || activeSessions[entryPath] {
continue
}
// Get directory info to check modification time
info, err := entry.Info()
if err != nil {
continue
}
// Remove directories older than 1 hour (only if no active session)
age := now.Sub(info.ModTime())
if age > 1*time.Hour {
if err := os.RemoveAll(entryPath); err != nil {
log.Printf("Warning: Failed to clean up old temp directory %s: %v", entryPath, err)
} else {
cleanedCount++
log.Printf("Cleaned up old temp directory: %s (age: %v)", entryPath, age)
}
}
}
if cleanedCount > 0 {
log.Printf("Cleaned up %d old temp directories", cleanedCount)
}
}
// cleanupOldUploadSessions periodically cleans up abandoned upload sessions
func (s *Manager) cleanupOldUploadSessions() {
// Run cleanup every 10 minutes
ticker := time.NewTicker(10 * time.Minute)
defer ticker.Stop()
// Run once immediately on startup
s.cleanupOldUploadSessionsOnce()
for range ticker.C {
s.cleanupOldUploadSessionsOnce()
}
}
// cleanupOldUploadSessionsOnce removes upload sessions older than 1 hour
func (s *Manager) cleanupOldUploadSessionsOnce() {
defer func() {
if r := recover(); r != nil {
log.Printf("Panic in cleanupOldUploadSessions: %v", r)
}
}()
s.uploadSessionsMu.Lock()
defer s.uploadSessionsMu.Unlock()
now := time.Now()
cleanedCount := 0
for sessionID, session := range s.uploadSessions {
// Remove sessions older than 1 hour
age := now.Sub(session.CreatedAt)
if age > 1*time.Hour {
delete(s.uploadSessions, sessionID)
cleanedCount++
log.Printf("Cleaned up abandoned upload session: %s (user: %d, status: %s, age: %v)",
sessionID, session.UserID, session.Status, age)
}
}
if cleanedCount > 0 {
log.Printf("Cleaned up %d abandoned upload sessions", cleanedCount)
}
}