- Removed redundant error handling in handleListJobTasks.
- Introduced per-job mutexes in Manager to serialize updateJobStatusFromTasks calls, ensuring thread safety during concurrent task completions.
- Added methods to manage job status update mutexes, including creation and cleanup after job completion or failure.
- Improved error handling in handleGetJobStatusForRunner by consolidating error checks.
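The serialization described above rests on the two fields this change adds to Manager (jobStatusUpdateMu, guarded by jobStatusUpdateMuMu). As a rough illustration of the pattern, a lookup-or-create helper might look like the sketch below; the helper name and its exact shape are assumptions for illustration, not the methods added in this commit:

    // Hypothetical sketch (not the actual helper in this commit): fetch or lazily
    // create the per-job mutex so callers can serialize updateJobStatusFromTasks.
    func (s *Manager) getJobStatusUpdateMutex(jobID int64) *sync.Mutex {
        s.jobStatusUpdateMuMu.RLock()
        mu, ok := s.jobStatusUpdateMu[jobID]
        s.jobStatusUpdateMuMu.RUnlock()
        if ok {
            return mu
        }

        s.jobStatusUpdateMuMu.Lock()
        defer s.jobStatusUpdateMuMu.Unlock()
        if mu, ok = s.jobStatusUpdateMu[jobID]; !ok {
            mu = &sync.Mutex{}
            s.jobStatusUpdateMu[jobID] = mu
        }
        return mu
    }

A caller would lock the returned mutex around updateJobStatusFromTasks(jobID) and drop the map entry once the job reaches a terminal state, matching the cleanup described above.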
1308 lines
39 KiB
Go
package api

import (
    "compress/gzip"
    "database/sql"
    "encoding/json"
    "fmt"
    "io"
    "log"
    "net/http"
    "os"
    "os/exec"
    "path/filepath"
    "runtime"
    "strconv"
    "strings"
    "sync"
    "time"

    authpkg "jiggablend/internal/auth"
    "jiggablend/internal/config"
    "jiggablend/internal/database"
    "jiggablend/internal/storage"
    "jiggablend/pkg/types"
    "jiggablend/web"

    "github.com/go-chi/chi/v5"
    "github.com/go-chi/chi/v5/middleware"
    "github.com/go-chi/cors"
    "github.com/gorilla/websocket"
)

// Configuration constants
const (
    // WebSocket timeouts
    WSReadDeadline  = 90 * time.Second
    WSPingInterval  = 30 * time.Second
    WSWriteDeadline = 10 * time.Second

    // Task timeouts
    RenderTimeout      = 60 * 60      // 1 hour for frame rendering
    VideoEncodeTimeout = 60 * 60 * 24 // 24 hours for encoding

    // Limits
    MaxUploadSize            = 50 << 30 // 50 GB
    RunnerHeartbeatTimeout   = 90 * time.Second
    TaskDistributionInterval = 10 * time.Second
    ProgressUpdateThrottle   = 2 * time.Second

    // Cookie settings
    SessionCookieMaxAge = 86400 // 24 hours
)

// Manager represents the manager server
type Manager struct {
    db      *database.DB
    cfg     *config.Config
    auth    *authpkg.Auth
    secrets *authpkg.Secrets
    storage *storage.Storage
    router  *chi.Mux

    // WebSocket connections
    wsUpgrader websocket.Upgrader

    // DEPRECATED: Old frontend WebSocket connection maps (kept for backwards compatibility)
    // These will be removed in a future release. Use clientConns instead.
    frontendConns          map[string]*websocket.Conn // key: "jobId:taskId"
    frontendConnsMu        sync.RWMutex
    frontendConnsWriteMu   map[string]*sync.Mutex
    frontendConnsWriteMuMu sync.RWMutex
    jobListConns           map[int64]*websocket.Conn
    jobListConnsMu         sync.RWMutex
    jobConns               map[string]*websocket.Conn
    jobConnsMu             sync.RWMutex
    jobConnsWriteMu        map[string]*sync.Mutex
    jobConnsWriteMuMu      sync.RWMutex

    // Per-job runner WebSocket connections (polling-based flow)
    // Key is "job-{jobId}-task-{taskId}"
    runnerJobConns          map[string]*websocket.Conn
    runnerJobConnsMu        sync.RWMutex
    runnerJobConnsWriteMu   map[string]*sync.Mutex
    runnerJobConnsWriteMuMu sync.RWMutex

    // Throttling for progress updates (per job)
    progressUpdateTimes   map[int64]time.Time // key: jobID
    progressUpdateTimesMu sync.RWMutex
    // Throttling for task status updates (per task)
    taskUpdateTimes   map[int64]time.Time // key: taskID
    taskUpdateTimesMu sync.RWMutex
    // Per-job mutexes to serialize updateJobStatusFromTasks calls and prevent race conditions
    jobStatusUpdateMu   map[int64]*sync.Mutex // key: jobID
    jobStatusUpdateMuMu sync.RWMutex

    // Client WebSocket connections (new unified WebSocket)
    // Key is "userID:connID" to support multiple tabs per user
    clientConns   map[string]*ClientConnection
    clientConnsMu sync.RWMutex
    connIDCounter uint64 // Atomic counter for generating unique connection IDs

    // Upload session tracking
    uploadSessions   map[string]*UploadSession // sessionId -> session info
    uploadSessionsMu sync.RWMutex

    // Verbose WebSocket logging (set to true to enable detailed WebSocket logs)
    verboseWSLogging bool

    // Server start time for health checks
    startTime time.Time
}

// ClientConnection represents a client WebSocket connection with subscriptions
type ClientConnection struct {
    Conn          *websocket.Conn
    UserID        int64
    ConnID        string // Unique connection ID (userID:connID)
    IsAdmin       bool
    Subscriptions map[string]bool // channel -> subscribed
    SubsMu        sync.RWMutex    // Protects Subscriptions map
    WriteMu       *sync.Mutex
}

// UploadSession tracks upload and processing progress
type UploadSession struct {
    SessionID string
    UserID    int64
    Progress  float64
    Status    string // "uploading", "processing", "extracting_metadata", "creating_context", "completed", "error"
    Message   string
    CreatedAt time.Time
}

// NewManager creates a new manager server
func NewManager(db *database.DB, cfg *config.Config, auth *authpkg.Auth, storage *storage.Storage) (*Manager, error) {
    secrets, err := authpkg.NewSecrets(db, cfg)
    if err != nil {
        return nil, fmt.Errorf("failed to initialize secrets: %w", err)
    }

    s := &Manager{
        db:        db,
        cfg:       cfg,
        auth:      auth,
        secrets:   secrets,
        storage:   storage,
        router:    chi.NewRouter(),
        startTime: time.Now(),
        wsUpgrader: websocket.Upgrader{
            CheckOrigin:     checkWebSocketOrigin,
            ReadBufferSize:  1024,
            WriteBufferSize: 1024,
        },
        // DEPRECATED: Initialize old frontend WebSocket maps for backward compatibility
        frontendConns:        make(map[string]*websocket.Conn),
        frontendConnsWriteMu: make(map[string]*sync.Mutex),
        jobListConns:         make(map[int64]*websocket.Conn),
        jobConns:             make(map[string]*websocket.Conn),
        jobConnsWriteMu:      make(map[string]*sync.Mutex),
        progressUpdateTimes:  make(map[int64]time.Time),
        taskUpdateTimes:      make(map[int64]time.Time),
        clientConns:          make(map[string]*ClientConnection),
        uploadSessions:       make(map[string]*UploadSession),
        // Per-job runner WebSocket connections
        runnerJobConns:        make(map[string]*websocket.Conn),
        runnerJobConnsWriteMu: make(map[string]*sync.Mutex),
        // Per-job mutexes for serializing status updates
        jobStatusUpdateMu: make(map[int64]*sync.Mutex),
    }

    // Check for required external tools
    if err := s.checkRequiredTools(); err != nil {
        return nil, err
    }

    s.setupMiddleware()
    s.setupRoutes()
    s.StartBackgroundTasks()

    // On startup, check for runners that are marked online but not actually connected
    // This handles the case where the manager restarted and lost track of connections
    go s.recoverRunnersOnStartup()

    return s, nil
}

// checkRequiredTools verifies that required external tools are available
func (s *Manager) checkRequiredTools() error {
    // Check for zstd (required for zstd-compressed blend files)
    if err := exec.Command("zstd", "--version").Run(); err != nil {
        return fmt.Errorf("zstd not found - required for compressed blend file support. Install with: apt install zstd")
    }
    log.Printf("Found zstd for compressed blend file support")

    // Check for xz (required for decompressing blender archives)
    if err := exec.Command("xz", "--version").Run(); err != nil {
        return fmt.Errorf("xz not found - required for decompressing blender archives. Install with: apt install xz-utils")
    }
    log.Printf("Found xz for blender archive decompression")

    return nil
}

// checkWebSocketOrigin validates WebSocket connection origins
// In production mode, only allows same-origin connections or configured allowed origins
func checkWebSocketOrigin(r *http.Request) bool {
    origin := r.Header.Get("Origin")
    if origin == "" {
        // No origin header - allow (could be non-browser client like runner)
        return true
    }

    // In development mode, allow all origins
    // Note: This function doesn't have access to the Manager, so we use authpkg.IsProductionMode(),
    // which checks an environment variable. The server setup uses s.cfg.IsProductionMode() for consistency.
    if !authpkg.IsProductionMode() {
        return true
    }

    // In production, check against allowed origins
    allowedOrigins := os.Getenv("ALLOWED_ORIGINS")
    if allowedOrigins == "" {
        // Default to same-origin only
        host := r.Host
        return strings.HasSuffix(origin, "://"+host) || strings.HasSuffix(origin, "://"+strings.Split(host, ":")[0])
    }

    // Check against configured allowed origins
    for _, allowed := range strings.Split(allowedOrigins, ",") {
        allowed = strings.TrimSpace(allowed)
        if allowed == "*" {
            return true
        }
        if origin == allowed {
            return true
        }
    }

    log.Printf("WebSocket origin rejected: %s (allowed: %s)", origin, allowedOrigins)
    return false
}

// RateLimiter provides simple in-memory rate limiting per IP
type RateLimiter struct {
    requests map[string][]time.Time
    mu       sync.RWMutex
    limit    int           // max requests
    window   time.Duration // time window
}

// NewRateLimiter creates a new rate limiter
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
    rl := &RateLimiter{
        requests: make(map[string][]time.Time),
        limit:    limit,
        window:   window,
    }
    // Start cleanup goroutine
    go rl.cleanup()
    return rl
}

// Allow checks if a request from the given IP is allowed
func (rl *RateLimiter) Allow(ip string) bool {
    rl.mu.Lock()
    defer rl.mu.Unlock()

    now := time.Now()
    cutoff := now.Add(-rl.window)

    // Get existing requests and filter old ones
    reqs := rl.requests[ip]
    validReqs := make([]time.Time, 0, len(reqs))
    for _, t := range reqs {
        if t.After(cutoff) {
            validReqs = append(validReqs, t)
        }
    }

    // Check if under limit
    if len(validReqs) >= rl.limit {
        rl.requests[ip] = validReqs
        return false
    }

    // Add this request
    validReqs = append(validReqs, now)
    rl.requests[ip] = validReqs
    return true
}

// cleanup periodically removes old entries
func (rl *RateLimiter) cleanup() {
    ticker := time.NewTicker(5 * time.Minute)
    for range ticker.C {
        rl.mu.Lock()
        cutoff := time.Now().Add(-rl.window)
        for ip, reqs := range rl.requests {
            validReqs := make([]time.Time, 0, len(reqs))
            for _, t := range reqs {
                if t.After(cutoff) {
                    validReqs = append(validReqs, t)
                }
            }
            if len(validReqs) == 0 {
                delete(rl.requests, ip)
            } else {
                rl.requests[ip] = validReqs
            }
        }
        rl.mu.Unlock()
    }
}

// Global rate limiters for different endpoint types
var (
    // General API rate limiter: 100 requests per minute per IP
    apiRateLimiter = NewRateLimiter(100, time.Minute)
    // Auth rate limiter: 10 requests per minute per IP (stricter for login attempts)
    authRateLimiter = NewRateLimiter(10, time.Minute)
)

// rateLimitMiddleware applies rate limiting based on client IP
func rateLimitMiddleware(limiter *RateLimiter) func(http.Handler) http.Handler {
    return func(next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            // Get client IP (handle proxied requests)
            ip := r.RemoteAddr
            if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
                // Take the first IP in the chain
                if idx := strings.Index(forwarded, ","); idx != -1 {
                    ip = strings.TrimSpace(forwarded[:idx])
                } else {
                    ip = strings.TrimSpace(forwarded)
                }
            } else if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
                ip = strings.TrimSpace(realIP)
            }

            if !limiter.Allow(ip) {
                w.Header().Set("Content-Type", "application/json")
                w.Header().Set("Retry-After", "60")
                w.WriteHeader(http.StatusTooManyRequests)
                json.NewEncoder(w).Encode(map[string]string{
                    "error": "Rate limit exceeded. Please try again later.",
                })
                return
            }

            next.ServeHTTP(w, r)
        })
    }
}

// setupMiddleware configures middleware
func (s *Manager) setupMiddleware() {
    s.router.Use(middleware.Logger)
    s.router.Use(middleware.Recoverer)
    // Note: Timeout middleware is NOT applied globally to avoid conflicts with WebSocket connections
    // WebSocket connections are long-lived and should not have HTTP timeouts

    // Check production mode from config
    isProduction := s.cfg.IsProductionMode()

    // Add rate limiting (applied in production mode only, or when explicitly enabled)
    if isProduction || os.Getenv("ENABLE_RATE_LIMITING") == "true" {
        s.router.Use(rateLimitMiddleware(apiRateLimiter))
        log.Printf("Rate limiting enabled: 100 requests/minute per IP")
    }

    // Add gzip compression for JSON responses
    s.router.Use(gzipMiddleware)

    // Configure CORS based on environment
    corsOptions := cors.Options{
        AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"},
        AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type", "Range", "If-None-Match"},
        ExposedHeaders:   []string{"Link", "Content-Range", "Accept-Ranges", "Content-Length", "ETag"},
        AllowCredentials: true,
        MaxAge:           300,
    }

    // In production, restrict CORS origins
    if isProduction {
        allowedOrigins := s.cfg.AllowedOrigins()
        if allowedOrigins != "" {
            corsOptions.AllowedOrigins = strings.Split(allowedOrigins, ",")
            for i := range corsOptions.AllowedOrigins {
                corsOptions.AllowedOrigins[i] = strings.TrimSpace(corsOptions.AllowedOrigins[i])
            }
        } else {
            // Default to no origins in production if not configured
            // This effectively disables cross-origin requests
            corsOptions.AllowedOrigins = []string{}
        }
        log.Printf("Production mode: CORS restricted to origins: %v", corsOptions.AllowedOrigins)
    } else {
        // Development mode: allow all origins
        corsOptions.AllowedOrigins = []string{"*"}
    }

    s.router.Use(cors.Handler(corsOptions))
}

// gzipMiddleware compresses responses with gzip if client supports it
func gzipMiddleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        // Skip compression for WebSocket upgrades
        if strings.ToLower(r.Header.Get("Upgrade")) == "websocket" {
            next.ServeHTTP(w, r)
            return
        }

        // Check if client accepts gzip
        if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
            next.ServeHTTP(w, r)
            return
        }

        // Create gzip writer
        gz := gzip.NewWriter(w)
        defer gz.Close()

        w.Header().Set("Content-Encoding", "gzip")
        w.Header().Set("Vary", "Accept-Encoding")

        // Wrap response writer
        gzw := &gzipResponseWriter{Writer: gz, ResponseWriter: w}
        next.ServeHTTP(gzw, r)
    })
}

// gzipResponseWriter wraps http.ResponseWriter to add gzip compression
type gzipResponseWriter struct {
    io.Writer
    http.ResponseWriter
}

func (w *gzipResponseWriter) Write(b []byte) (int, error) {
    return w.Writer.Write(b)
}

func (w *gzipResponseWriter) WriteHeader(statusCode int) {
    // Don't set Content-Length when using gzip - it will be set automatically
    w.ResponseWriter.WriteHeader(statusCode)
}

// setupRoutes configures routes
func (s *Manager) setupRoutes() {
    // Health check endpoint (unauthenticated)
    s.router.Get("/api/health", s.handleHealthCheck)

    // Public routes (with stricter rate limiting for auth endpoints)
    s.router.Route("/api/auth", func(r chi.Router) {
        // Apply stricter rate limiting to auth endpoints in production
        if s.cfg.IsProductionMode() || os.Getenv("ENABLE_RATE_LIMITING") == "true" {
            r.Use(rateLimitMiddleware(authRateLimiter))
        }
        r.Get("/providers", s.handleGetAuthProviders)
        r.Get("/google/login", s.handleGoogleLogin)
        r.Get("/google/callback", s.handleGoogleCallback)
        r.Get("/discord/login", s.handleDiscordLogin)
        r.Get("/discord/callback", s.handleDiscordCallback)
        r.Get("/local/available", s.handleLocalLoginAvailable)
        r.Post("/local/register", s.handleLocalRegister)
        r.Post("/local/login", s.handleLocalLogin)
        r.Post("/logout", s.handleLogout)
        r.Get("/me", s.handleGetMe)
        r.Post("/change-password", s.handleChangePassword)
    })

    // Protected routes
    s.router.Route("/api/jobs", func(r chi.Router) {
        r.Use(func(next http.Handler) http.Handler {
            return http.HandlerFunc(s.auth.Middleware(next.ServeHTTP))
        })
        r.Post("/", s.handleCreateJob)
        r.Post("/upload", s.handleUploadFileForJobCreation) // Upload before job creation
        r.Get("/", s.handleListJobs)
        r.Get("/summary", s.handleListJobsSummary)
        r.Post("/batch", s.handleBatchGetJobs)
        r.Get("/{id}", s.handleGetJob)
        r.Delete("/{id}", s.handleCancelJob)
        r.Post("/{id}/delete", s.handleDeleteJob)
        r.Post("/{id}/upload", s.handleUploadJobFile)
        r.Get("/{id}/files", s.handleListJobFiles)
        r.Get("/{id}/files/count", s.handleGetJobFilesCount)
        r.Get("/{id}/context", s.handleListContextArchive)
        r.Get("/{id}/files/{fileId}/download", s.handleDownloadJobFile)
        r.Get("/{id}/files/{fileId}/preview-exr", s.handlePreviewEXR)
        r.Get("/{id}/video", s.handleStreamVideo)
        r.Get("/{id}/metadata", s.handleGetJobMetadata)
        r.Get("/{id}/tasks", s.handleListJobTasks)
        r.Get("/{id}/tasks/summary", s.handleListJobTasksSummary)
        r.Post("/{id}/tasks/batch", s.handleBatchGetTasks)
        r.Get("/{id}/tasks/{taskId}/logs", s.handleGetTaskLogs)
        r.Get("/{id}/tasks/{taskId}/steps", s.handleGetTaskSteps)
        r.Post("/{id}/tasks/{taskId}/retry", s.handleRetryTask)
        // WebSocket route for unified client WebSocket
        r.With(func(next http.Handler) http.Handler {
            return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
                // Apply authentication middleware first
                s.auth.Middleware(func(w http.ResponseWriter, r *http.Request) {
                    // No timeout middleware for WebSocket
                    next.ServeHTTP(w, r)
                })(w, r)
            })
        }).Get("/ws", s.handleClientWebSocket)
    })

    // Admin routes
    s.router.Route("/api/admin", func(r chi.Router) {
        r.Use(func(next http.Handler) http.Handler {
            return http.HandlerFunc(s.auth.AdminMiddleware(next.ServeHTTP))
        })
        r.Route("/runners", func(r chi.Router) {
            r.Route("/api-keys", func(r chi.Router) {
                r.Post("/", s.handleGenerateRunnerAPIKey)
                r.Get("/", s.handleListRunnerAPIKeys)
                r.Patch("/{id}/revoke", s.handleRevokeRunnerAPIKey)
                r.Delete("/{id}", s.handleDeleteRunnerAPIKey)
            })
            r.Get("/", s.handleListRunnersAdmin)
            r.Post("/{id}/verify", s.handleVerifyRunner)
            r.Delete("/{id}", s.handleDeleteRunner)
        })
        r.Route("/users", func(r chi.Router) {
            r.Get("/", s.handleListUsers)
            r.Get("/{id}/jobs", s.handleGetUserJobs)
            r.Post("/{id}/admin", s.handleSetUserAdminStatus)
        })
        r.Route("/settings", func(r chi.Router) {
            r.Get("/registration", s.handleGetRegistrationEnabled)
            r.Post("/registration", s.handleSetRegistrationEnabled)
        })
    })

    // Runner API
    s.router.Route("/api/runner", func(r chi.Router) {
        // Registration doesn't require auth (uses token)
        r.With(middleware.Timeout(60*time.Second)).Post("/register", s.handleRegisterRunner)

        // Polling-based endpoints (auth handled in handlers)
        r.Get("/workers/{id}/next-job", s.handleNextJob)

        // Per-job endpoints with job_token auth (no middleware, auth in handler)
        r.Get("/jobs/{jobId}/ws", s.handleRunnerJobWebSocket)
        r.Get("/jobs/{jobId}/context.tar", s.handleDownloadJobContextWithToken)
        r.Post("/jobs/{jobId}/upload", s.handleUploadFileWithToken)

        // Runner API endpoints (uses API key auth)
        r.Group(func(r chi.Router) {
            r.Use(func(next http.Handler) http.Handler {
                return http.HandlerFunc(s.runnerAuthMiddleware(next.ServeHTTP))
            })
            r.Get("/blender/download", s.handleDownloadBlender)
            r.Get("/jobs/{jobId}/files", s.handleGetJobFilesForRunner)
            r.Get("/jobs/{jobId}/metadata", s.handleGetJobMetadataForRunner)
            r.Get("/files/{jobId}/{fileName}", s.handleDownloadFileForRunner)
        })
    })

    // Blender versions API (public, for job submission page)
    s.router.Get("/api/blender/versions", s.handleGetBlenderVersions)

    // Serve static files (embedded React app with SPA fallback)
    s.router.Handle("/*", web.SPAHandler())
}

// ServeHTTP implements http.Handler
func (s *Manager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    s.router.ServeHTTP(w, r)
}

// JSON response helpers
func (s *Manager) respondJSON(w http.ResponseWriter, status int, data interface{}) {
    w.Header().Set("Content-Type", "application/json")
    w.WriteHeader(status)
    if err := json.NewEncoder(w).Encode(data); err != nil {
        log.Printf("Failed to encode JSON response: %v", err)
    }
}

func (s *Manager) respondError(w http.ResponseWriter, status int, message string) {
    s.respondJSON(w, status, map[string]string{"error": message})
}

// createSessionCookie creates a secure session cookie with appropriate flags for the environment
func createSessionCookie(sessionID string) *http.Cookie {
    cookie := &http.Cookie{
        Name:     "session_id",
        Value:    sessionID,
        Path:     "/",
        MaxAge:   SessionCookieMaxAge,
        HttpOnly: true,
        SameSite: http.SameSiteLaxMode,
    }

    // In production mode, set Secure flag to require HTTPS
    if authpkg.IsProductionMode() {
        cookie.Secure = true
    }

    return cookie
}

// handleHealthCheck returns server health status
func (s *Manager) handleHealthCheck(w http.ResponseWriter, r *http.Request) {
    // Check database connectivity
    dbHealthy := true
    if err := s.db.Ping(); err != nil {
        dbHealthy = false
        log.Printf("Health check: database ping failed: %v", err)
    }

    // Count online runners (based on recent heartbeat)
    var runnerCount int
    s.db.With(func(conn *sql.DB) error {
        return conn.QueryRow(
            `SELECT COUNT(*) FROM runners WHERE status = ?`,
            types.RunnerStatusOnline,
        ).Scan(&runnerCount)
    })

    // Count connected clients
    s.clientConnsMu.RLock()
    clientCount := len(s.clientConns)
    s.clientConnsMu.RUnlock()

    // Calculate uptime
    uptime := time.Since(s.startTime)

    // Get memory stats
    var memStats runtime.MemStats
    runtime.ReadMemStats(&memStats)

    status := "healthy"
    statusCode := http.StatusOK
    if !dbHealthy {
        status = "degraded"
        statusCode = http.StatusServiceUnavailable
    }

    response := map[string]interface{}{
        "status":            status,
        "uptime_seconds":    int64(uptime.Seconds()),
        "database":          dbHealthy,
        "connected_runners": runnerCount,
        "connected_clients": clientCount,
        "memory": map[string]interface{}{
            "alloc_mb":       memStats.Alloc / 1024 / 1024,
            "total_alloc_mb": memStats.TotalAlloc / 1024 / 1024,
            "sys_mb":         memStats.Sys / 1024 / 1024,
            "num_gc":         memStats.NumGC,
        },
        "timestamp": time.Now().Unix(),
    }

    s.respondJSON(w, statusCode, response)
}

// Auth handlers
func (s *Manager) handleGoogleLogin(w http.ResponseWriter, r *http.Request) {
    url, err := s.auth.GoogleLoginURL()
    if err != nil {
        s.respondError(w, http.StatusInternalServerError, err.Error())
        return
    }
    http.Redirect(w, r, url, http.StatusFound)
}

func (s *Manager) handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
    code := r.URL.Query().Get("code")
    if code == "" {
        s.respondError(w, http.StatusBadRequest, "Missing code parameter")
        return
    }

    session, err := s.auth.GoogleCallback(r.Context(), code)
    if err != nil {
        // If registration is disabled, redirect back to login with error
        if err.Error() == "registration is disabled" {
            http.Redirect(w, r, "/?error=registration_disabled", http.StatusFound)
            return
        }
        s.respondError(w, http.StatusInternalServerError, err.Error())
        return
    }

    sessionID := s.auth.CreateSession(session)
    http.SetCookie(w, createSessionCookie(sessionID))

    http.Redirect(w, r, "/", http.StatusFound)
}

func (s *Manager) handleDiscordLogin(w http.ResponseWriter, r *http.Request) {
    url, err := s.auth.DiscordLoginURL()
    if err != nil {
        s.respondError(w, http.StatusInternalServerError, err.Error())
        return
    }
    http.Redirect(w, r, url, http.StatusFound)
}

func (s *Manager) handleDiscordCallback(w http.ResponseWriter, r *http.Request) {
    code := r.URL.Query().Get("code")
    if code == "" {
        s.respondError(w, http.StatusBadRequest, "Missing code parameter")
        return
    }

    session, err := s.auth.DiscordCallback(r.Context(), code)
    if err != nil {
        // If registration is disabled, redirect back to login with error
        if err.Error() == "registration is disabled" {
            http.Redirect(w, r, "/?error=registration_disabled", http.StatusFound)
            return
        }
        s.respondError(w, http.StatusInternalServerError, err.Error())
        return
    }

    sessionID := s.auth.CreateSession(session)
    http.SetCookie(w, createSessionCookie(sessionID))

    http.Redirect(w, r, "/", http.StatusFound)
}

func (s *Manager) handleLogout(w http.ResponseWriter, r *http.Request) {
    cookie, err := r.Cookie("session_id")
    if err == nil {
        s.auth.DeleteSession(cookie.Value)
    }
    // Create an expired cookie to clear the session
    expiredCookie := &http.Cookie{
        Name:     "session_id",
        Value:    "",
        Path:     "/",
        MaxAge:   -1,
        HttpOnly: true,
        SameSite: http.SameSiteLaxMode,
    }
    // Use s.cfg.IsProductionMode() for consistency with other server methods
    if s.cfg.IsProductionMode() {
        expiredCookie.Secure = true
    }
    http.SetCookie(w, expiredCookie)
    s.respondJSON(w, http.StatusOK, map[string]string{"message": "Logged out"})
}

func (s *Manager) handleGetMe(w http.ResponseWriter, r *http.Request) {
    cookie, err := r.Cookie("session_id")
    if err != nil {
        log.Printf("Authentication failed: missing session cookie in /auth/me")
        s.respondError(w, http.StatusUnauthorized, "Not authenticated")
        return
    }

    session, ok := s.auth.GetSession(cookie.Value)
    if !ok {
        log.Printf("Authentication failed: invalid session cookie in /auth/me")
        s.respondError(w, http.StatusUnauthorized, "Invalid session")
        return
    }

    s.respondJSON(w, http.StatusOK, map[string]interface{}{
        "id":       session.UserID,
        "email":    session.Email,
        "name":     session.Name,
        "is_admin": session.IsAdmin,
    })
}

func (s *Manager) handleGetAuthProviders(w http.ResponseWriter, r *http.Request) {
    s.respondJSON(w, http.StatusOK, map[string]bool{
        "google":  s.auth.IsGoogleOAuthConfigured(),
        "discord": s.auth.IsDiscordOAuthConfigured(),
        "local":   s.auth.IsLocalLoginEnabled(),
    })
}

func (s *Manager) handleLocalLoginAvailable(w http.ResponseWriter, r *http.Request) {
    s.respondJSON(w, http.StatusOK, map[string]bool{
        "available": s.auth.IsLocalLoginEnabled(),
    })
}

func (s *Manager) handleLocalRegister(w http.ResponseWriter, r *http.Request) {
    var req struct {
        Email    string `json:"email"`
        Name     string `json:"name"`
        Password string `json:"password"`
    }

    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
        s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
        return
    }

    if req.Email == "" || req.Name == "" || req.Password == "" {
        s.respondError(w, http.StatusBadRequest, "Email, name, and password are required")
        return
    }

    if len(req.Password) < 8 {
        s.respondError(w, http.StatusBadRequest, "Password must be at least 8 characters long")
        return
    }

    session, err := s.auth.RegisterLocalUser(req.Email, req.Name, req.Password)
    if err != nil {
        s.respondError(w, http.StatusBadRequest, err.Error())
        return
    }

    sessionID := s.auth.CreateSession(session)
    http.SetCookie(w, createSessionCookie(sessionID))

    s.respondJSON(w, http.StatusCreated, map[string]interface{}{
        "message": "Registration successful",
        "user": map[string]interface{}{
            "id":       session.UserID,
            "email":    session.Email,
            "name":     session.Name,
            "is_admin": session.IsAdmin,
        },
    })
}

func (s *Manager) handleLocalLogin(w http.ResponseWriter, r *http.Request) {
    var req struct {
        Username string `json:"username"`
        Password string `json:"password"`
    }

    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
        s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
        return
    }

    if req.Username == "" || req.Password == "" {
        s.respondError(w, http.StatusBadRequest, "Username and password are required")
        return
    }

    session, err := s.auth.LocalLogin(req.Username, req.Password)
    if err != nil {
        log.Printf("Authentication failed: invalid credentials for username '%s'", req.Username)
        s.respondError(w, http.StatusUnauthorized, "Invalid credentials")
        return
    }

    sessionID := s.auth.CreateSession(session)
    http.SetCookie(w, createSessionCookie(sessionID))

    s.respondJSON(w, http.StatusOK, map[string]interface{}{
        "message": "Login successful",
        "user": map[string]interface{}{
            "id":       session.UserID,
            "email":    session.Email,
            "name":     session.Name,
            "is_admin": session.IsAdmin,
        },
    })
}

func (s *Manager) handleChangePassword(w http.ResponseWriter, r *http.Request) {
    userID, err := getUserID(r)
    if err != nil {
        s.respondError(w, http.StatusUnauthorized, err.Error())
        return
    }

    var req struct {
        OldPassword  string `json:"old_password"`
        NewPassword  string `json:"new_password"`
        TargetUserID *int64 `json:"target_user_id,omitempty"` // For admin to change other users' passwords
    }

    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
        s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
        return
    }

    if req.NewPassword == "" {
        s.respondError(w, http.StatusBadRequest, "New password is required")
        return
    }

    if len(req.NewPassword) < 8 {
        s.respondError(w, http.StatusBadRequest, "Password must be at least 8 characters long")
        return
    }

    isAdmin := authpkg.IsAdmin(r.Context())

    // If target_user_id is provided and user is admin, allow changing other user's password
    if req.TargetUserID != nil && isAdmin {
        if err := s.auth.AdminChangePassword(*req.TargetUserID, req.NewPassword); err != nil {
            s.respondError(w, http.StatusBadRequest, err.Error())
            return
        }
        s.respondJSON(w, http.StatusOK, map[string]string{"message": "Password changed successfully"})
        return
    }

    // Otherwise, user is changing their own password (requires old password)
    if req.OldPassword == "" {
        s.respondError(w, http.StatusBadRequest, "Old password is required")
        return
    }

    if err := s.auth.ChangePassword(userID, req.OldPassword, req.NewPassword); err != nil {
        s.respondError(w, http.StatusBadRequest, err.Error())
        return
    }

    s.respondJSON(w, http.StatusOK, map[string]string{"message": "Password changed successfully"})
}

// Helper to get user ID from context
func getUserID(r *http.Request) (int64, error) {
    userID, ok := authpkg.GetUserID(r.Context())
    if !ok {
        return 0, fmt.Errorf("user ID not found in context")
    }
    return userID, nil
}

// Helper to parse ID from URL
func parseID(r *http.Request, param string) (int64, error) {
    idStr := chi.URLParam(r, param)
    id, err := strconv.ParseInt(idStr, 10, 64)
    if err != nil {
        return 0, fmt.Errorf("invalid ID: %s", idStr)
    }
    return id, nil
}

// StartBackgroundTasks starts background goroutines for error recovery
func (s *Manager) StartBackgroundTasks() {
    go s.recoverStuckTasks()
    go s.cleanupOldRenderJobs()
    go s.cleanupOldTempDirectories()
    go s.cleanupOldOfflineRunners()
    go s.cleanupOldUploadSessions()
}

// recoverRunnersOnStartup marks runners as offline on startup
// In the polling model, runners will update their status when they poll for jobs
func (s *Manager) recoverRunnersOnStartup() {
    log.Printf("Recovering runners on startup: marking all as offline...")

    // Mark all runners as offline - they'll be marked online when they poll
    var runnersAffected int64
    err := s.db.With(func(conn *sql.DB) error {
        result, err := conn.Exec(
            `UPDATE runners SET status = ? WHERE status = ?`,
            types.RunnerStatusOffline, types.RunnerStatusOnline,
        )
        if err != nil {
            return err
        }
        runnersAffected, _ = result.RowsAffected()
        return nil
    })
    if err != nil {
        log.Printf("Failed to mark runners as offline on startup: %v", err)
        return
    }

    if runnersAffected > 0 {
        log.Printf("Marked %d runners as offline on startup", runnersAffected)
    }

    // Reset any running tasks that were assigned to runners
    // They will be picked up by runners when they poll
    var tasksAffected int64
    err = s.db.With(func(conn *sql.DB) error {
        result, err := conn.Exec(
            `UPDATE tasks SET runner_id = NULL, status = ?, started_at = NULL
             WHERE status = ?`,
            types.TaskStatusPending, types.TaskStatusRunning,
        )
        if err != nil {
            return err
        }
        tasksAffected, _ = result.RowsAffected()
        return nil
    })
    if err != nil {
        log.Printf("Failed to reset running tasks on startup: %v", err)
        return
    }

    if tasksAffected > 0 {
        log.Printf("Reset %d running tasks to pending on startup", tasksAffected)
    }
}

// recoverStuckTasks periodically checks for dead runners and stuck tasks
func (s *Manager) recoverStuckTasks() {
    ticker := time.NewTicker(TaskDistributionInterval)
    defer ticker.Stop()

    for range ticker.C {
        func() {
            defer func() {
                if r := recover(); r != nil {
                    log.Printf("Panic in recoverStuckTasks: %v", r)
                }
            }()

            // Find dead runners (no heartbeat for configured timeout)
            // In polling model, heartbeat is updated when runner polls for jobs
            var deadRunnerIDs []int64
            cutoffTime := time.Now().Add(-RunnerHeartbeatTimeout)
            err := s.db.With(func(conn *sql.DB) error {
                rows, err := conn.Query(
                    `SELECT id FROM runners
                     WHERE last_heartbeat < ?
                     AND status = ?`,
                    cutoffTime, types.RunnerStatusOnline,
                )
                if err != nil {
                    return err
                }
                defer rows.Close()

                for rows.Next() {
                    var runnerID int64
                    if err := rows.Scan(&runnerID); err == nil {
                        deadRunnerIDs = append(deadRunnerIDs, runnerID)
                    }
                }
                return nil
            })
            if err != nil {
                log.Printf("Failed to query dead runners: %v", err)
                return
            }

            // Reset tasks assigned to dead runners
            for _, runnerID := range deadRunnerIDs {
                s.resetRunnerTasks(runnerID)

                // Mark runner as offline
                s.db.With(func(conn *sql.DB) error {
                    _, _ = conn.Exec(
                        `UPDATE runners SET status = ? WHERE id = ?`,
                        types.RunnerStatusOffline, runnerID,
                    )
                    return nil
                })
            }

            // Check for task timeouts
            s.recoverTaskTimeouts()
        }()
    }
}

// recoverTaskTimeouts handles tasks that have exceeded their timeout
// Timeouts are treated as runner failures (not task failures) and retry indefinitely
func (s *Manager) recoverTaskTimeouts() {
    // Find tasks running longer than their timeout
    var tasks []struct {
        taskID         int64
        jobID          int64
        runnerID       sql.NullInt64
        timeoutSeconds sql.NullInt64
        startedAt      time.Time
    }

    err := s.db.With(func(conn *sql.DB) error {
        rows, err := conn.Query(
            `SELECT t.id, t.job_id, t.runner_id, t.timeout_seconds, t.started_at
             FROM tasks t
             WHERE t.status = ?
             AND t.started_at IS NOT NULL
             AND (t.completed_at IS NULL OR t.completed_at < datetime('now', '-30 seconds'))
             AND (t.timeout_seconds IS NULL OR
                  (julianday('now') - julianday(t.started_at)) * 86400 > t.timeout_seconds)`,
            types.TaskStatusRunning,
        )
        if err != nil {
            return err
        }
        defer rows.Close()

        for rows.Next() {
            var task struct {
                taskID         int64
                jobID          int64
                runnerID       sql.NullInt64
                timeoutSeconds sql.NullInt64
                startedAt      time.Time
            }
            err := rows.Scan(&task.taskID, &task.jobID, &task.runnerID, &task.timeoutSeconds, &task.startedAt)
            if err != nil {
                log.Printf("Failed to scan task row in recoverTaskTimeouts: %v", err)
                continue
            }
            tasks = append(tasks, task)
        }
        return nil
    })
    if err != nil {
        log.Printf("Failed to query timed out tasks: %v", err)
        return
    }

    for _, task := range tasks {
        taskID := task.taskID
        jobID := task.jobID
        timeoutSeconds := task.timeoutSeconds
        startedAt := task.startedAt

        // Use default timeout if not set (5 minutes for frame tasks, 24 hours for FFmpeg)
        timeout := 300 // 5 minutes default
        if timeoutSeconds.Valid {
            timeout = int(timeoutSeconds.Int64)
        }

        // Check if actually timed out
        if time.Since(startedAt).Seconds() < float64(timeout) {
            continue
        }

        // Timeouts are runner failures - always reset to pending and increment runner_failure_count
        // This does NOT count against retry_count (which is for actual task failures like Blender crashes)
        err = s.db.With(func(conn *sql.DB) error {
            _, err := conn.Exec(`UPDATE tasks SET status = ? WHERE id = ?`, types.TaskStatusPending, taskID)
            if err != nil {
                return err
            }
            _, err = conn.Exec(`UPDATE tasks SET runner_id = NULL WHERE id = ?`, taskID)
            if err != nil {
                return err
            }
            _, err = conn.Exec(`UPDATE tasks SET current_step = NULL WHERE id = ?`, taskID)
            if err != nil {
                return err
            }
            _, err = conn.Exec(`UPDATE tasks SET started_at = NULL WHERE id = ?`, taskID)
            if err != nil {
                return err
            }
            _, err = conn.Exec(`UPDATE tasks SET runner_failure_count = runner_failure_count + 1 WHERE id = ?`, taskID)
            if err != nil {
                return err
            }
            // Clear steps and logs for fresh retry
            _, err = conn.Exec(`DELETE FROM task_steps WHERE task_id = ?`, taskID)
            if err != nil {
                return err
            }
            _, err = conn.Exec(`DELETE FROM task_logs WHERE task_id = ?`, taskID)
            return err
        })
        if err == nil {
            // Broadcast task reset to clients (includes steps_cleared and logs_cleared flags)
            s.broadcastTaskUpdate(jobID, taskID, "task_reset", map[string]interface{}{
                "status":        types.TaskStatusPending,
                "runner_id":     nil,
                "current_step":  nil,
                "started_at":    nil,
                "steps_cleared": true,
                "logs_cleared":  true,
            })

            // Update job status
            s.updateJobStatusFromTasks(jobID)

            log.Printf("Reset timed out task %d to pending (runner failure)", taskID)
        } else {
            log.Printf("Failed to reset timed out task %d: %v", taskID, err)
        }
    }
}

// cleanupOldTempDirectories periodically cleans up old temporary directories
func (s *Manager) cleanupOldTempDirectories() {
    // Run cleanup every hour
    ticker := time.NewTicker(1 * time.Hour)
    defer ticker.Stop()

    // Run once immediately on startup
    s.cleanupOldTempDirectoriesOnce()

    for range ticker.C {
        s.cleanupOldTempDirectoriesOnce()
    }
}

// cleanupOldTempDirectoriesOnce removes temp directories older than 1 hour
func (s *Manager) cleanupOldTempDirectoriesOnce() {
    defer func() {
        if r := recover(); r != nil {
            log.Printf("Panic in cleanupOldTempDirectories: %v", r)
        }
    }()

    tempPath := filepath.Join(s.storage.BasePath(), "temp")

    // Check if temp directory exists
    if _, err := os.Stat(tempPath); os.IsNotExist(err) {
        return
    }

    // Read all entries in temp directory
    entries, err := os.ReadDir(tempPath)
    if err != nil {
        log.Printf("Failed to read temp directory: %v", err)
        return
    }

    now := time.Now()
    cleanedCount := 0

    // Check upload sessions to avoid deleting active uploads
    s.uploadSessionsMu.RLock()
    activeSessions := make(map[string]bool)
    for sessionID := range s.uploadSessions {
        activeSessions[sessionID] = true
    }
    s.uploadSessionsMu.RUnlock()

    for _, entry := range entries {
        if !entry.IsDir() {
            continue
        }

        entryPath := filepath.Join(tempPath, entry.Name())

        // Skip if this directory has an active upload session
        if activeSessions[entryPath] {
            continue
        }

        // Get directory info to check modification time
        info, err := entry.Info()
        if err != nil {
            continue
        }

        // Remove directories older than 1 hour (only if no active session)
        age := now.Sub(info.ModTime())
        if age > 1*time.Hour {
            if err := os.RemoveAll(entryPath); err != nil {
                log.Printf("Warning: Failed to clean up old temp directory %s: %v", entryPath, err)
            } else {
                cleanedCount++
                log.Printf("Cleaned up old temp directory: %s (age: %v)", entryPath, age)
            }
        }
    }

    if cleanedCount > 0 {
        log.Printf("Cleaned up %d old temp directories", cleanedCount)
    }
}

// cleanupOldUploadSessions periodically cleans up abandoned upload sessions
func (s *Manager) cleanupOldUploadSessions() {
    // Run cleanup every 10 minutes
    ticker := time.NewTicker(10 * time.Minute)
    defer ticker.Stop()

    // Run once immediately on startup
    s.cleanupOldUploadSessionsOnce()

    for range ticker.C {
        s.cleanupOldUploadSessionsOnce()
    }
}

// cleanupOldUploadSessionsOnce removes upload sessions older than 1 hour
func (s *Manager) cleanupOldUploadSessionsOnce() {
    defer func() {
        if r := recover(); r != nil {
            log.Printf("Panic in cleanupOldUploadSessions: %v", r)
        }
    }()

    s.uploadSessionsMu.Lock()
    defer s.uploadSessionsMu.Unlock()

    now := time.Now()
    cleanedCount := 0

    for sessionID, session := range s.uploadSessions {
        // Remove sessions older than 1 hour
        age := now.Sub(session.CreatedAt)
        if age > 1*time.Hour {
            delete(s.uploadSessions, sessionID)
            cleanedCount++
            log.Printf("Cleaned up abandoned upload session: %s (user: %d, status: %s, age: %v)",
                sessionID, session.UserID, session.Status, age)
        }
    }

    if cleanedCount > 0 {
        log.Printf("Cleaned up %d abandoned upload sessions", cleanedCount)
    }
}