package api import ( "compress/gzip" "database/sql" "encoding/json" "fmt" "io" "log" "net/http" "os" "os/exec" "path/filepath" "runtime" "strconv" "strings" "sync" "time" authpkg "jiggablend/internal/auth" "jiggablend/internal/config" "jiggablend/internal/database" "jiggablend/internal/storage" "jiggablend/pkg/types" "jiggablend/web" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/cors" "github.com/gorilla/websocket" ) // Configuration constants const ( // WebSocket timeouts WSReadDeadline = 90 * time.Second WSPingInterval = 30 * time.Second WSWriteDeadline = 10 * time.Second // Task timeouts RenderTimeout = 60 * 60 // 1 hour for frame rendering VideoEncodeTimeout = 60 * 60 * 24 // 24 hours for encoding // Limits MaxUploadSize = 50 << 30 // 50 GB RunnerHeartbeatTimeout = 90 * time.Second TaskDistributionInterval = 10 * time.Second ProgressUpdateThrottle = 2 * time.Second // Cookie settings SessionCookieMaxAge = 86400 // 24 hours ) // Manager represents the manager server type Manager struct { db *database.DB cfg *config.Config auth *authpkg.Auth secrets *authpkg.Secrets storage *storage.Storage router *chi.Mux // WebSocket connections wsUpgrader websocket.Upgrader // DEPRECATED: Old frontend WebSocket connection maps (kept for backwards compatibility) // These will be removed in a future release. Use clientConns instead. frontendConns map[string]*websocket.Conn // key: "jobId:taskId" frontendConnsMu sync.RWMutex frontendConnsWriteMu map[string]*sync.Mutex frontendConnsWriteMuMu sync.RWMutex jobListConns map[int64]*websocket.Conn jobListConnsMu sync.RWMutex jobConns map[string]*websocket.Conn jobConnsMu sync.RWMutex jobConnsWriteMu map[string]*sync.Mutex jobConnsWriteMuMu sync.RWMutex // Per-job runner WebSocket connections (polling-based flow) // Key is "job-{jobId}-task-{taskId}" runnerJobConns map[string]*websocket.Conn runnerJobConnsMu sync.RWMutex runnerJobConnsWriteMu map[string]*sync.Mutex runnerJobConnsWriteMuMu sync.RWMutex // Throttling for progress updates (per job) progressUpdateTimes map[int64]time.Time // key: jobID progressUpdateTimesMu sync.RWMutex // Throttling for task status updates (per task) taskUpdateTimes map[int64]time.Time // key: taskID taskUpdateTimesMu sync.RWMutex // Client WebSocket connections (new unified WebSocket) // Key is "userID:connID" to support multiple tabs per user clientConns map[string]*ClientConnection clientConnsMu sync.RWMutex connIDCounter uint64 // Atomic counter for generating unique connection IDs // Upload session tracking uploadSessions map[string]*UploadSession // sessionId -> session info uploadSessionsMu sync.RWMutex // Verbose WebSocket logging (set to true to enable detailed WebSocket logs) verboseWSLogging bool // Server start time for health checks startTime time.Time } // ClientConnection represents a client WebSocket connection with subscriptions type ClientConnection struct { Conn *websocket.Conn UserID int64 ConnID string // Unique connection ID (userID:connID) IsAdmin bool Subscriptions map[string]bool // channel -> subscribed SubsMu sync.RWMutex // Protects Subscriptions map WriteMu *sync.Mutex } // UploadSession tracks upload and processing progress type UploadSession struct { SessionID string UserID int64 Progress float64 Status string // "uploading", "processing", "extracting_metadata", "creating_context", "completed", "error" Message string CreatedAt time.Time } // NewManager creates a new manager server func NewManager(db *database.DB, cfg *config.Config, auth *authpkg.Auth, storage *storage.Storage) (*Manager, error) { secrets, err := authpkg.NewSecrets(db, cfg) if err != nil { return nil, fmt.Errorf("failed to initialize secrets: %w", err) } s := &Manager{ db: db, cfg: cfg, auth: auth, secrets: secrets, storage: storage, router: chi.NewRouter(), startTime: time.Now(), wsUpgrader: websocket.Upgrader{ CheckOrigin: checkWebSocketOrigin, ReadBufferSize: 1024, WriteBufferSize: 1024, }, // DEPRECATED: Initialize old frontend WebSocket maps for backward compatibility frontendConns: make(map[string]*websocket.Conn), frontendConnsWriteMu: make(map[string]*sync.Mutex), jobListConns: make(map[int64]*websocket.Conn), jobConns: make(map[string]*websocket.Conn), jobConnsWriteMu: make(map[string]*sync.Mutex), progressUpdateTimes: make(map[int64]time.Time), taskUpdateTimes: make(map[int64]time.Time), clientConns: make(map[string]*ClientConnection), uploadSessions: make(map[string]*UploadSession), // Per-job runner WebSocket connections runnerJobConns: make(map[string]*websocket.Conn), runnerJobConnsWriteMu: make(map[string]*sync.Mutex), runnerJobConnsWriteMuMu: sync.RWMutex{}, // Initialize the new field } // Check for required external tools if err := s.checkRequiredTools(); err != nil { return nil, err } s.setupMiddleware() s.setupRoutes() s.StartBackgroundTasks() // On startup, check for runners that are marked online but not actually connected // This handles the case where the manager restarted and lost track of connections go s.recoverRunnersOnStartup() return s, nil } // checkRequiredTools verifies that required external tools are available func (s *Manager) checkRequiredTools() error { // Check for zstd (required for zstd-compressed blend files) if err := exec.Command("zstd", "--version").Run(); err != nil { return fmt.Errorf("zstd not found - required for compressed blend file support. Install with: apt install zstd") } log.Printf("Found zstd for compressed blend file support") // Check for xz (required for decompressing blender archives) if err := exec.Command("xz", "--version").Run(); err != nil { return fmt.Errorf("xz not found - required for decompressing blender archives. Install with: apt install xz-utils") } log.Printf("Found xz for blender archive decompression") return nil } // checkWebSocketOrigin validates WebSocket connection origins // In production mode, only allows same-origin connections or configured allowed origins func checkWebSocketOrigin(r *http.Request) bool { origin := r.Header.Get("Origin") if origin == "" { // No origin header - allow (could be non-browser client like runner) return true } // In development mode, allow all origins // Note: This function doesn't have access to Server, so we use authpkg.IsProductionMode() // which checks environment variable. The server setup uses s.cfg.IsProductionMode() for consistency. if !authpkg.IsProductionMode() { return true } // In production, check against allowed origins allowedOrigins := os.Getenv("ALLOWED_ORIGINS") if allowedOrigins == "" { // Default to same-origin only host := r.Host return strings.HasSuffix(origin, "://"+host) || strings.HasSuffix(origin, "://"+strings.Split(host, ":")[0]) } // Check against configured allowed origins for _, allowed := range strings.Split(allowedOrigins, ",") { allowed = strings.TrimSpace(allowed) if allowed == "*" { return true } if origin == allowed { return true } } log.Printf("WebSocket origin rejected: %s (allowed: %s)", origin, allowedOrigins) return false } // RateLimiter provides simple in-memory rate limiting per IP type RateLimiter struct { requests map[string][]time.Time mu sync.RWMutex limit int // max requests window time.Duration // time window } // NewRateLimiter creates a new rate limiter func NewRateLimiter(limit int, window time.Duration) *RateLimiter { rl := &RateLimiter{ requests: make(map[string][]time.Time), limit: limit, window: window, } // Start cleanup goroutine go rl.cleanup() return rl } // Allow checks if a request from the given IP is allowed func (rl *RateLimiter) Allow(ip string) bool { rl.mu.Lock() defer rl.mu.Unlock() now := time.Now() cutoff := now.Add(-rl.window) // Get existing requests and filter old ones reqs := rl.requests[ip] validReqs := make([]time.Time, 0, len(reqs)) for _, t := range reqs { if t.After(cutoff) { validReqs = append(validReqs, t) } } // Check if under limit if len(validReqs) >= rl.limit { rl.requests[ip] = validReqs return false } // Add this request validReqs = append(validReqs, now) rl.requests[ip] = validReqs return true } // cleanup periodically removes old entries func (rl *RateLimiter) cleanup() { ticker := time.NewTicker(5 * time.Minute) for range ticker.C { rl.mu.Lock() cutoff := time.Now().Add(-rl.window) for ip, reqs := range rl.requests { validReqs := make([]time.Time, 0, len(reqs)) for _, t := range reqs { if t.After(cutoff) { validReqs = append(validReqs, t) } } if len(validReqs) == 0 { delete(rl.requests, ip) } else { rl.requests[ip] = validReqs } } rl.mu.Unlock() } } // Global rate limiters for different endpoint types var ( // General API rate limiter: 100 requests per minute per IP apiRateLimiter = NewRateLimiter(100, time.Minute) // Auth rate limiter: 10 requests per minute per IP (stricter for login attempts) authRateLimiter = NewRateLimiter(10, time.Minute) ) // rateLimitMiddleware applies rate limiting based on client IP func rateLimitMiddleware(limiter *RateLimiter) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Get client IP (handle proxied requests) ip := r.RemoteAddr if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" { // Take the first IP in the chain if idx := strings.Index(forwarded, ","); idx != -1 { ip = strings.TrimSpace(forwarded[:idx]) } else { ip = strings.TrimSpace(forwarded) } } else if realIP := r.Header.Get("X-Real-IP"); realIP != "" { ip = strings.TrimSpace(realIP) } if !limiter.Allow(ip) { w.Header().Set("Content-Type", "application/json") w.Header().Set("Retry-After", "60") w.WriteHeader(http.StatusTooManyRequests) json.NewEncoder(w).Encode(map[string]string{ "error": "Rate limit exceeded. Please try again later.", }) return } next.ServeHTTP(w, r) }) } } // setupMiddleware configures middleware func (s *Manager) setupMiddleware() { s.router.Use(middleware.Logger) s.router.Use(middleware.Recoverer) // Note: Timeout middleware is NOT applied globally to avoid conflicts with WebSocket connections // WebSocket connections are long-lived and should not have HTTP timeouts // Check production mode from config isProduction := s.cfg.IsProductionMode() // Add rate limiting (applied in production mode only, or when explicitly enabled) if isProduction || os.Getenv("ENABLE_RATE_LIMITING") == "true" { s.router.Use(rateLimitMiddleware(apiRateLimiter)) log.Printf("Rate limiting enabled: 100 requests/minute per IP") } // Add gzip compression for JSON responses s.router.Use(gzipMiddleware) // Configure CORS based on environment corsOptions := cors.Options{ AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"}, AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "Range", "If-None-Match"}, ExposedHeaders: []string{"Link", "Content-Range", "Accept-Ranges", "Content-Length", "ETag"}, AllowCredentials: true, MaxAge: 300, } // In production, restrict CORS origins if isProduction { allowedOrigins := s.cfg.AllowedOrigins() if allowedOrigins != "" { corsOptions.AllowedOrigins = strings.Split(allowedOrigins, ",") for i := range corsOptions.AllowedOrigins { corsOptions.AllowedOrigins[i] = strings.TrimSpace(corsOptions.AllowedOrigins[i]) } } else { // Default to no origins in production if not configured // This effectively disables cross-origin requests corsOptions.AllowedOrigins = []string{} } log.Printf("Production mode: CORS restricted to origins: %v", corsOptions.AllowedOrigins) } else { // Development mode: allow all origins corsOptions.AllowedOrigins = []string{"*"} } s.router.Use(cors.Handler(corsOptions)) } // gzipMiddleware compresses responses with gzip if client supports it func gzipMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Skip compression for WebSocket upgrades if strings.ToLower(r.Header.Get("Upgrade")) == "websocket" { next.ServeHTTP(w, r) return } // Check if client accepts gzip if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { next.ServeHTTP(w, r) return } // Create gzip writer gz := gzip.NewWriter(w) defer gz.Close() w.Header().Set("Content-Encoding", "gzip") w.Header().Set("Vary", "Accept-Encoding") // Wrap response writer gzw := &gzipResponseWriter{Writer: gz, ResponseWriter: w} next.ServeHTTP(gzw, r) }) } // gzipResponseWriter wraps http.ResponseWriter to add gzip compression type gzipResponseWriter struct { io.Writer http.ResponseWriter } func (w *gzipResponseWriter) Write(b []byte) (int, error) { return w.Writer.Write(b) } func (w *gzipResponseWriter) WriteHeader(statusCode int) { // Don't set Content-Length when using gzip - it will be set automatically w.ResponseWriter.WriteHeader(statusCode) } // setupRoutes configures routes func (s *Manager) setupRoutes() { // Health check endpoint (unauthenticated) s.router.Get("/api/health", s.handleHealthCheck) // Public routes (with stricter rate limiting for auth endpoints) s.router.Route("/api/auth", func(r chi.Router) { // Apply stricter rate limiting to auth endpoints in production if s.cfg.IsProductionMode() || os.Getenv("ENABLE_RATE_LIMITING") == "true" { r.Use(rateLimitMiddleware(authRateLimiter)) } r.Get("/providers", s.handleGetAuthProviders) r.Get("/google/login", s.handleGoogleLogin) r.Get("/google/callback", s.handleGoogleCallback) r.Get("/discord/login", s.handleDiscordLogin) r.Get("/discord/callback", s.handleDiscordCallback) r.Get("/local/available", s.handleLocalLoginAvailable) r.Post("/local/register", s.handleLocalRegister) r.Post("/local/login", s.handleLocalLogin) r.Post("/logout", s.handleLogout) r.Get("/me", s.handleGetMe) r.Post("/change-password", s.handleChangePassword) }) // Protected routes s.router.Route("/api/jobs", func(r chi.Router) { r.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(s.auth.Middleware(next.ServeHTTP)) }) r.Post("/", s.handleCreateJob) r.Post("/upload", s.handleUploadFileForJobCreation) // Upload before job creation r.Get("/", s.handleListJobs) r.Get("/summary", s.handleListJobsSummary) r.Post("/batch", s.handleBatchGetJobs) r.Get("/{id}", s.handleGetJob) r.Delete("/{id}", s.handleCancelJob) r.Post("/{id}/delete", s.handleDeleteJob) r.Post("/{id}/upload", s.handleUploadJobFile) r.Get("/{id}/files", s.handleListJobFiles) r.Get("/{id}/files/count", s.handleGetJobFilesCount) r.Get("/{id}/context", s.handleListContextArchive) r.Get("/{id}/files/{fileId}/download", s.handleDownloadJobFile) r.Get("/{id}/files/{fileId}/preview-exr", s.handlePreviewEXR) r.Get("/{id}/video", s.handleStreamVideo) r.Get("/{id}/metadata", s.handleGetJobMetadata) r.Get("/{id}/tasks", s.handleListJobTasks) r.Get("/{id}/tasks/summary", s.handleListJobTasksSummary) r.Post("/{id}/tasks/batch", s.handleBatchGetTasks) r.Get("/{id}/tasks/{taskId}/logs", s.handleGetTaskLogs) r.Get("/{id}/tasks/{taskId}/steps", s.handleGetTaskSteps) r.Post("/{id}/tasks/{taskId}/retry", s.handleRetryTask) // WebSocket route for unified client WebSocket r.With(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Apply authentication middleware first s.auth.Middleware(func(w http.ResponseWriter, r *http.Request) { // Remove timeout middleware for WebSocket next.ServeHTTP(w, r) })(w, r) }) }).Get("/ws", s.handleClientWebSocket) }) // Admin routes s.router.Route("/api/admin", func(r chi.Router) { r.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(s.auth.AdminMiddleware(next.ServeHTTP)) }) r.Route("/runners", func(r chi.Router) { r.Route("/api-keys", func(r chi.Router) { r.Post("/", s.handleGenerateRunnerAPIKey) r.Get("/", s.handleListRunnerAPIKeys) r.Patch("/{id}/revoke", s.handleRevokeRunnerAPIKey) r.Delete("/{id}", s.handleDeleteRunnerAPIKey) }) r.Get("/", s.handleListRunnersAdmin) r.Post("/{id}/verify", s.handleVerifyRunner) r.Delete("/{id}", s.handleDeleteRunner) }) r.Route("/users", func(r chi.Router) { r.Get("/", s.handleListUsers) r.Get("/{id}/jobs", s.handleGetUserJobs) r.Post("/{id}/admin", s.handleSetUserAdminStatus) }) r.Route("/settings", func(r chi.Router) { r.Get("/registration", s.handleGetRegistrationEnabled) r.Post("/registration", s.handleSetRegistrationEnabled) }) }) // Runner API s.router.Route("/api/runner", func(r chi.Router) { // Registration doesn't require auth (uses token) r.With(middleware.Timeout(60*time.Second)).Post("/register", s.handleRegisterRunner) // Polling-based endpoints (auth handled in handlers) r.Get("/workers/{id}/next-job", s.handleNextJob) // Per-job endpoints with job_token auth (no middleware, auth in handler) r.Get("/jobs/{jobId}/ws", s.handleRunnerJobWebSocket) r.Get("/jobs/{jobId}/context.tar", s.handleDownloadJobContextWithToken) r.Post("/jobs/{jobId}/upload", s.handleUploadFileWithToken) // Runner API endpoints (uses API key auth) r.Group(func(r chi.Router) { r.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(s.runnerAuthMiddleware(next.ServeHTTP)) }) r.Get("/blender/download", s.handleDownloadBlender) r.Get("/jobs/{jobId}/files", s.handleGetJobFilesForRunner) r.Get("/jobs/{jobId}/metadata", s.handleGetJobMetadataForRunner) r.Get("/files/{jobId}/{fileName}", s.handleDownloadFileForRunner) }) }) // Blender versions API (public, for job submission page) s.router.Get("/api/blender/versions", s.handleGetBlenderVersions) // Serve static files (embedded React app with SPA fallback) s.router.Handle("/*", web.SPAHandler()) } // ServeHTTP implements http.Handler func (s *Manager) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.router.ServeHTTP(w, r) } // JSON response helpers func (s *Manager) respondJSON(w http.ResponseWriter, status int, data interface{}) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) if err := json.NewEncoder(w).Encode(data); err != nil { log.Printf("Failed to encode JSON response: %v", err) } } func (s *Manager) respondError(w http.ResponseWriter, status int, message string) { s.respondJSON(w, status, map[string]string{"error": message}) } // createSessionCookie creates a secure session cookie with appropriate flags for the environment func createSessionCookie(sessionID string) *http.Cookie { cookie := &http.Cookie{ Name: "session_id", Value: sessionID, Path: "/", MaxAge: SessionCookieMaxAge, HttpOnly: true, SameSite: http.SameSiteLaxMode, } // In production mode, set Secure flag to require HTTPS if authpkg.IsProductionMode() { cookie.Secure = true } return cookie } // handleHealthCheck returns server health status func (s *Manager) handleHealthCheck(w http.ResponseWriter, r *http.Request) { // Check database connectivity dbHealthy := true if err := s.db.Ping(); err != nil { dbHealthy = false log.Printf("Health check: database ping failed: %v", err) } // Count online runners (based on recent heartbeat) var runnerCount int s.db.With(func(conn *sql.DB) error { return conn.QueryRow( `SELECT COUNT(*) FROM runners WHERE status = ?`, types.RunnerStatusOnline, ).Scan(&runnerCount) }) // Count connected clients s.clientConnsMu.RLock() clientCount := len(s.clientConns) s.clientConnsMu.RUnlock() // Calculate uptime uptime := time.Since(s.startTime) // Get memory stats var memStats runtime.MemStats runtime.ReadMemStats(&memStats) status := "healthy" statusCode := http.StatusOK if !dbHealthy { status = "degraded" statusCode = http.StatusServiceUnavailable } response := map[string]interface{}{ "status": status, "uptime_seconds": int64(uptime.Seconds()), "database": dbHealthy, "connected_runners": runnerCount, "connected_clients": clientCount, "memory": map[string]interface{}{ "alloc_mb": memStats.Alloc / 1024 / 1024, "total_alloc_mb": memStats.TotalAlloc / 1024 / 1024, "sys_mb": memStats.Sys / 1024 / 1024, "num_gc": memStats.NumGC, }, "timestamp": time.Now().Unix(), } s.respondJSON(w, statusCode, response) } // Auth handlers func (s *Manager) handleGoogleLogin(w http.ResponseWriter, r *http.Request) { url, err := s.auth.GoogleLoginURL() if err != nil { s.respondError(w, http.StatusInternalServerError, err.Error()) return } http.Redirect(w, r, url, http.StatusFound) } func (s *Manager) handleGoogleCallback(w http.ResponseWriter, r *http.Request) { code := r.URL.Query().Get("code") if code == "" { s.respondError(w, http.StatusBadRequest, "Missing code parameter") return } session, err := s.auth.GoogleCallback(r.Context(), code) if err != nil { // If registration is disabled, redirect back to login with error if err.Error() == "registration is disabled" { http.Redirect(w, r, "/?error=registration_disabled", http.StatusFound) return } s.respondError(w, http.StatusInternalServerError, err.Error()) return } sessionID := s.auth.CreateSession(session) http.SetCookie(w, createSessionCookie(sessionID)) http.Redirect(w, r, "/", http.StatusFound) } func (s *Manager) handleDiscordLogin(w http.ResponseWriter, r *http.Request) { url, err := s.auth.DiscordLoginURL() if err != nil { s.respondError(w, http.StatusInternalServerError, err.Error()) return } http.Redirect(w, r, url, http.StatusFound) } func (s *Manager) handleDiscordCallback(w http.ResponseWriter, r *http.Request) { code := r.URL.Query().Get("code") if code == "" { s.respondError(w, http.StatusBadRequest, "Missing code parameter") return } session, err := s.auth.DiscordCallback(r.Context(), code) if err != nil { // If registration is disabled, redirect back to login with error if err.Error() == "registration is disabled" { http.Redirect(w, r, "/?error=registration_disabled", http.StatusFound) return } s.respondError(w, http.StatusInternalServerError, err.Error()) return } sessionID := s.auth.CreateSession(session) http.SetCookie(w, createSessionCookie(sessionID)) http.Redirect(w, r, "/", http.StatusFound) } func (s *Manager) handleLogout(w http.ResponseWriter, r *http.Request) { cookie, err := r.Cookie("session_id") if err == nil { s.auth.DeleteSession(cookie.Value) } // Create an expired cookie to clear the session expiredCookie := &http.Cookie{ Name: "session_id", Value: "", Path: "/", MaxAge: -1, HttpOnly: true, SameSite: http.SameSiteLaxMode, } // Use s.cfg.IsProductionMode() for consistency with other server methods if s.cfg.IsProductionMode() { expiredCookie.Secure = true } http.SetCookie(w, expiredCookie) s.respondJSON(w, http.StatusOK, map[string]string{"message": "Logged out"}) } func (s *Manager) handleGetMe(w http.ResponseWriter, r *http.Request) { cookie, err := r.Cookie("session_id") if err != nil { log.Printf("Authentication failed: missing session cookie in /auth/me") s.respondError(w, http.StatusUnauthorized, "Not authenticated") return } session, ok := s.auth.GetSession(cookie.Value) if !ok { log.Printf("Authentication failed: invalid session cookie in /auth/me") s.respondError(w, http.StatusUnauthorized, "Invalid session") return } s.respondJSON(w, http.StatusOK, map[string]interface{}{ "id": session.UserID, "email": session.Email, "name": session.Name, "is_admin": session.IsAdmin, }) } func (s *Manager) handleGetAuthProviders(w http.ResponseWriter, r *http.Request) { s.respondJSON(w, http.StatusOK, map[string]bool{ "google": s.auth.IsGoogleOAuthConfigured(), "discord": s.auth.IsDiscordOAuthConfigured(), "local": s.auth.IsLocalLoginEnabled(), }) } func (s *Manager) handleLocalLoginAvailable(w http.ResponseWriter, r *http.Request) { s.respondJSON(w, http.StatusOK, map[string]bool{ "available": s.auth.IsLocalLoginEnabled(), }) } func (s *Manager) handleLocalRegister(w http.ResponseWriter, r *http.Request) { var req struct { Email string `json:"email"` Name string `json:"name"` Password string `json:"password"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err)) return } if req.Email == "" || req.Name == "" || req.Password == "" { s.respondError(w, http.StatusBadRequest, "Email, name, and password are required") return } if len(req.Password) < 8 { s.respondError(w, http.StatusBadRequest, "Password must be at least 8 characters long") return } session, err := s.auth.RegisterLocalUser(req.Email, req.Name, req.Password) if err != nil { s.respondError(w, http.StatusBadRequest, err.Error()) return } sessionID := s.auth.CreateSession(session) http.SetCookie(w, createSessionCookie(sessionID)) s.respondJSON(w, http.StatusCreated, map[string]interface{}{ "message": "Registration successful", "user": map[string]interface{}{ "id": session.UserID, "email": session.Email, "name": session.Name, "is_admin": session.IsAdmin, }, }) } func (s *Manager) handleLocalLogin(w http.ResponseWriter, r *http.Request) { var req struct { Username string `json:"username"` Password string `json:"password"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err)) return } if req.Username == "" || req.Password == "" { s.respondError(w, http.StatusBadRequest, "Username and password are required") return } session, err := s.auth.LocalLogin(req.Username, req.Password) if err != nil { log.Printf("Authentication failed: invalid credentials for username '%s'", req.Username) s.respondError(w, http.StatusUnauthorized, "Invalid credentials") return } sessionID := s.auth.CreateSession(session) http.SetCookie(w, createSessionCookie(sessionID)) s.respondJSON(w, http.StatusOK, map[string]interface{}{ "message": "Login successful", "user": map[string]interface{}{ "id": session.UserID, "email": session.Email, "name": session.Name, "is_admin": session.IsAdmin, }, }) } func (s *Manager) handleChangePassword(w http.ResponseWriter, r *http.Request) { userID, err := getUserID(r) if err != nil { s.respondError(w, http.StatusUnauthorized, err.Error()) return } var req struct { OldPassword string `json:"old_password"` NewPassword string `json:"new_password"` TargetUserID *int64 `json:"target_user_id,omitempty"` // For admin to change other users' passwords } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err)) return } if req.NewPassword == "" { s.respondError(w, http.StatusBadRequest, "New password is required") return } if len(req.NewPassword) < 8 { s.respondError(w, http.StatusBadRequest, "Password must be at least 8 characters long") return } isAdmin := authpkg.IsAdmin(r.Context()) // If target_user_id is provided and user is admin, allow changing other user's password if req.TargetUserID != nil && isAdmin { if err := s.auth.AdminChangePassword(*req.TargetUserID, req.NewPassword); err != nil { s.respondError(w, http.StatusBadRequest, err.Error()) return } s.respondJSON(w, http.StatusOK, map[string]string{"message": "Password changed successfully"}) return } // Otherwise, user is changing their own password (requires old password) if req.OldPassword == "" { s.respondError(w, http.StatusBadRequest, "Old password is required") return } if err := s.auth.ChangePassword(userID, req.OldPassword, req.NewPassword); err != nil { s.respondError(w, http.StatusBadRequest, err.Error()) return } s.respondJSON(w, http.StatusOK, map[string]string{"message": "Password changed successfully"}) } // Helper to get user ID from context func getUserID(r *http.Request) (int64, error) { userID, ok := authpkg.GetUserID(r.Context()) if !ok { return 0, fmt.Errorf("user ID not found in context") } return userID, nil } // Helper to parse ID from URL func parseID(r *http.Request, param string) (int64, error) { idStr := chi.URLParam(r, param) id, err := strconv.ParseInt(idStr, 10, 64) if err != nil { return 0, fmt.Errorf("invalid ID: %s", idStr) } return id, nil } // StartBackgroundTasks starts background goroutines for error recovery func (s *Manager) StartBackgroundTasks() { go s.recoverStuckTasks() go s.cleanupOldRenderJobs() go s.cleanupOldTempDirectories() go s.cleanupOldOfflineRunners() go s.cleanupOldUploadSessions() } // recoverRunnersOnStartup marks runners as offline on startup // In the polling model, runners will update their status when they poll for jobs func (s *Manager) recoverRunnersOnStartup() { log.Printf("Recovering runners on startup: marking all as offline...") // Mark all runners as offline - they'll be marked online when they poll var runnersAffected int64 err := s.db.With(func(conn *sql.DB) error { result, err := conn.Exec( `UPDATE runners SET status = ? WHERE status = ?`, types.RunnerStatusOffline, types.RunnerStatusOnline, ) if err != nil { return err } runnersAffected, _ = result.RowsAffected() return nil }) if err != nil { log.Printf("Failed to mark runners as offline on startup: %v", err) return } if runnersAffected > 0 { log.Printf("Marked %d runners as offline on startup", runnersAffected) } // Reset any running tasks that were assigned to runners // They will be picked up by runners when they poll var tasksAffected int64 err = s.db.With(func(conn *sql.DB) error { result, err := conn.Exec( `UPDATE tasks SET runner_id = NULL, status = ?, started_at = NULL WHERE status = ?`, types.TaskStatusPending, types.TaskStatusRunning, ) if err != nil { return err } tasksAffected, _ = result.RowsAffected() return nil }) if err != nil { log.Printf("Failed to reset running tasks on startup: %v", err) return } if tasksAffected > 0 { log.Printf("Reset %d running tasks to pending on startup", tasksAffected) } } // recoverStuckTasks periodically checks for dead runners and stuck tasks func (s *Manager) recoverStuckTasks() { ticker := time.NewTicker(TaskDistributionInterval) defer ticker.Stop() for range ticker.C { func() { defer func() { if r := recover(); r != nil { log.Printf("Panic in recoverStuckTasks: %v", r) } }() // Find dead runners (no heartbeat for configured timeout) // In polling model, heartbeat is updated when runner polls for jobs var deadRunnerIDs []int64 cutoffTime := time.Now().Add(-RunnerHeartbeatTimeout) err := s.db.With(func(conn *sql.DB) error { rows, err := conn.Query( `SELECT id FROM runners WHERE last_heartbeat < ? AND status = ?`, cutoffTime, types.RunnerStatusOnline, ) if err != nil { return err } defer rows.Close() for rows.Next() { var runnerID int64 if err := rows.Scan(&runnerID); err == nil { deadRunnerIDs = append(deadRunnerIDs, runnerID) } } return nil }) if err != nil { log.Printf("Failed to query dead runners: %v", err) return } // Reset tasks assigned to dead runners for _, runnerID := range deadRunnerIDs { s.resetRunnerTasks(runnerID) // Mark runner as offline s.db.With(func(conn *sql.DB) error { _, _ = conn.Exec( `UPDATE runners SET status = ? WHERE id = ?`, types.RunnerStatusOffline, runnerID, ) return nil }) } // Check for task timeouts s.recoverTaskTimeouts() }() } } // recoverTaskTimeouts handles tasks that have exceeded their timeout // Timeouts are treated as runner failures (not task failures) and retry indefinitely func (s *Manager) recoverTaskTimeouts() { // Find tasks running longer than their timeout var tasks []struct { taskID int64 jobID int64 runnerID sql.NullInt64 timeoutSeconds sql.NullInt64 startedAt time.Time } err := s.db.With(func(conn *sql.DB) error { rows, err := conn.Query( `SELECT t.id, t.job_id, t.runner_id, t.timeout_seconds, t.started_at FROM tasks t WHERE t.status = ? AND t.started_at IS NOT NULL AND (t.completed_at IS NULL OR t.completed_at < datetime('now', '-30 seconds')) AND (t.timeout_seconds IS NULL OR (julianday('now') - julianday(t.started_at)) * 86400 > t.timeout_seconds)`, types.TaskStatusRunning, ) if err != nil { return err } defer rows.Close() for rows.Next() { var task struct { taskID int64 jobID int64 runnerID sql.NullInt64 timeoutSeconds sql.NullInt64 startedAt time.Time } err := rows.Scan(&task.taskID, &task.jobID, &task.runnerID, &task.timeoutSeconds, &task.startedAt) if err != nil { log.Printf("Failed to scan task row in recoverTaskTimeouts: %v", err) continue } tasks = append(tasks, task) } return nil }) if err != nil { log.Printf("Failed to query timed out tasks: %v", err) return } for _, task := range tasks { taskID := task.taskID jobID := task.jobID timeoutSeconds := task.timeoutSeconds startedAt := task.startedAt // Use default timeout if not set (5 minutes for frame tasks, 24 hours for FFmpeg) timeout := 300 // 5 minutes default if timeoutSeconds.Valid { timeout = int(timeoutSeconds.Int64) } // Check if actually timed out if time.Since(startedAt).Seconds() < float64(timeout) { continue } // Timeouts are runner failures - always reset to pending and increment runner_failure_count // This does NOT count against retry_count (which is for actual task failures like Blender crashes) err = s.db.With(func(conn *sql.DB) error { _, err := conn.Exec(`UPDATE tasks SET status = ? WHERE id = ?`, types.TaskStatusPending, taskID) if err != nil { return err } _, err = conn.Exec(`UPDATE tasks SET runner_id = NULL WHERE id = ?`, taskID) if err != nil { return err } _, err = conn.Exec(`UPDATE tasks SET current_step = NULL WHERE id = ?`, taskID) if err != nil { return err } _, err = conn.Exec(`UPDATE tasks SET started_at = NULL WHERE id = ?`, taskID) if err != nil { return err } _, err = conn.Exec(`UPDATE tasks SET runner_failure_count = runner_failure_count + 1 WHERE id = ?`, taskID) if err != nil { return err } // Clear steps and logs for fresh retry _, err = conn.Exec(`DELETE FROM task_steps WHERE task_id = ?`, taskID) if err != nil { return err } _, err = conn.Exec(`DELETE FROM task_logs WHERE task_id = ?`, taskID) return err }) if err == nil { // Broadcast task reset to clients (includes steps_cleared and logs_cleared flags) s.broadcastTaskUpdate(jobID, taskID, "task_reset", map[string]interface{}{ "status": types.TaskStatusPending, "runner_id": nil, "current_step": nil, "started_at": nil, "steps_cleared": true, "logs_cleared": true, }) // Update job status s.updateJobStatusFromTasks(jobID) log.Printf("Reset timed out task %d: %v", taskID, err) } else { log.Printf("Failed to reset timed out task %d: %v", taskID, err) } } } // cleanupOldTempDirectories periodically cleans up old temporary directories func (s *Manager) cleanupOldTempDirectories() { // Run cleanup every hour ticker := time.NewTicker(1 * time.Hour) defer ticker.Stop() // Run once immediately on startup s.cleanupOldTempDirectoriesOnce() for range ticker.C { s.cleanupOldTempDirectoriesOnce() } } // cleanupOldTempDirectoriesOnce removes temp directories older than 1 hour func (s *Manager) cleanupOldTempDirectoriesOnce() { defer func() { if r := recover(); r != nil { log.Printf("Panic in cleanupOldTempDirectories: %v", r) } }() tempPath := filepath.Join(s.storage.BasePath(), "temp") // Check if temp directory exists if _, err := os.Stat(tempPath); os.IsNotExist(err) { return } // Read all entries in temp directory entries, err := os.ReadDir(tempPath) if err != nil { log.Printf("Failed to read temp directory: %v", err) return } now := time.Now() cleanedCount := 0 // Check upload sessions to avoid deleting active uploads s.uploadSessionsMu.RLock() activeSessions := make(map[string]bool) for sessionID := range s.uploadSessions { activeSessions[sessionID] = true } s.uploadSessionsMu.RUnlock() for _, entry := range entries { if !entry.IsDir() { continue } entryPath := filepath.Join(tempPath, entry.Name()) // Skip if this directory has an active upload session if activeSessions[entryPath] { continue } // Get directory info to check modification time info, err := entry.Info() if err != nil { continue } // Remove directories older than 1 hour (only if no active session) age := now.Sub(info.ModTime()) if age > 1*time.Hour { if err := os.RemoveAll(entryPath); err != nil { log.Printf("Warning: Failed to clean up old temp directory %s: %v", entryPath, err) } else { cleanedCount++ log.Printf("Cleaned up old temp directory: %s (age: %v)", entryPath, age) } } } if cleanedCount > 0 { log.Printf("Cleaned up %d old temp directories", cleanedCount) } } // cleanupOldUploadSessions periodically cleans up abandoned upload sessions func (s *Manager) cleanupOldUploadSessions() { // Run cleanup every 10 minutes ticker := time.NewTicker(10 * time.Minute) defer ticker.Stop() // Run once immediately on startup s.cleanupOldUploadSessionsOnce() for range ticker.C { s.cleanupOldUploadSessionsOnce() } } // cleanupOldUploadSessionsOnce removes upload sessions older than 1 hour func (s *Manager) cleanupOldUploadSessionsOnce() { defer func() { if r := recover(); r != nil { log.Printf("Panic in cleanupOldUploadSessions: %v", r) } }() s.uploadSessionsMu.Lock() defer s.uploadSessionsMu.Unlock() now := time.Now() cleanedCount := 0 for sessionID, session := range s.uploadSessions { // Remove sessions older than 1 hour age := now.Sub(session.CreatedAt) if age > 1*time.Hour { delete(s.uploadSessions, sessionID) cleanedCount++ log.Printf("Cleaned up abandoned upload session: %s (user: %d, status: %s, age: %v)", sessionID, session.UserID, session.Status, age) } } if cleanedCount > 0 { log.Printf("Cleaned up %d abandoned upload sessions", cleanedCount) } }