Refactor runner and installation scripts for improved functionality

- Removed the `--disable-hiprt` flag from the runner command, simplifying the rendering options for users.
- Updated the `jiggablend-runner` script and README to reflect the removal of the HIPRT control flag, enhancing clarity in usage instructions.
- Enhanced the installation script to provide clearer examples for running the jiggablend manager and runner, improving user experience during setup.
- Implemented a more robust GPU backend detection mechanism, allowing for better compatibility with various hardware configurations.
This commit is contained in:
2026-03-14 21:08:06 -05:00
parent 28cb50492c
commit 16d6a95058
30 changed files with 1041 additions and 782 deletions

View File

@@ -3,7 +3,6 @@ package api
import (
"archive/tar"
"compress/bzip2"
"compress/gzip"
"fmt"
"io"
"log"
@@ -16,6 +15,8 @@ import (
"strings"
"sync"
"time"
"jiggablend/pkg/blendfile"
)
const (
@@ -439,144 +440,16 @@ func (s *Manager) cleanupExtractedBlenderFolders(blenderDir string, version *Ble
}
}
// ParseBlenderVersionFromFile parses the Blender version that a .blend file was saved with
// This reads the file header to determine the version
// ParseBlenderVersionFromFile parses the Blender version that a .blend file was saved with.
// Delegates to the shared pkg/blendfile implementation.
func ParseBlenderVersionFromFile(blendPath string) (major, minor int, err error) {
file, err := os.Open(blendPath)
if err != nil {
return 0, 0, fmt.Errorf("failed to open blend file: %w", err)
}
defer file.Close()
return ParseBlenderVersionFromReader(file)
return blendfile.ParseVersionFromFile(blendPath)
}
// ParseBlenderVersionFromReader parses the Blender version from a reader
// Useful for reading from uploaded files without saving to disk first
// ParseBlenderVersionFromReader parses the Blender version from a reader.
// Delegates to the shared pkg/blendfile implementation.
func ParseBlenderVersionFromReader(r io.ReadSeeker) (major, minor int, err error) {
// Read the first 12 bytes of the blend file header
// Format: BLENDER-v<major><minor><patch> or BLENDER_v<major><minor><patch>
// The header is: "BLENDER" (7 bytes) + pointer size (1 byte: '-' for 64-bit, '_' for 32-bit)
// + endianness (1 byte: 'v' for little-endian, 'V' for big-endian)
// + version (3 bytes: e.g., "402" for 4.02)
header := make([]byte, 12)
n, err := r.Read(header)
if err != nil || n < 12 {
return 0, 0, fmt.Errorf("failed to read blend file header: %w", err)
}
// Check for BLENDER magic
if string(header[:7]) != "BLENDER" {
// Might be compressed - try to decompress
r.Seek(0, 0)
return parseCompressedBlendVersion(r)
}
// Parse version from bytes 9-11 (3 digits)
versionStr := string(header[9:12])
var vMajor, vMinor int
// Version format changed in Blender 3.0
// Pre-3.0: "279" = 2.79, "280" = 2.80
// 3.0+: "300" = 3.0, "402" = 4.02, "410" = 4.10
if len(versionStr) == 3 {
// First digit is major version
fmt.Sscanf(string(versionStr[0]), "%d", &vMajor)
// Next two digits are minor version
fmt.Sscanf(versionStr[1:3], "%d", &vMinor)
}
return vMajor, vMinor, nil
}
// parseCompressedBlendVersion handles gzip and zstd compressed blend files
func parseCompressedBlendVersion(r io.ReadSeeker) (major, minor int, err error) {
// Check for compression magic bytes
magic := make([]byte, 4)
if _, err := r.Read(magic); err != nil {
return 0, 0, err
}
r.Seek(0, 0)
if magic[0] == 0x1f && magic[1] == 0x8b {
// gzip compressed
gzReader, err := gzip.NewReader(r)
if err != nil {
return 0, 0, fmt.Errorf("failed to create gzip reader: %w", err)
}
defer gzReader.Close()
header := make([]byte, 12)
n, err := gzReader.Read(header)
if err != nil || n < 12 {
return 0, 0, fmt.Errorf("failed to read compressed blend header: %w", err)
}
if string(header[:7]) != "BLENDER" {
return 0, 0, fmt.Errorf("invalid blend file format")
}
versionStr := string(header[9:12])
var vMajor, vMinor int
if len(versionStr) == 3 {
fmt.Sscanf(string(versionStr[0]), "%d", &vMajor)
fmt.Sscanf(versionStr[1:3], "%d", &vMinor)
}
return vMajor, vMinor, nil
}
// Check for zstd magic (Blender 3.0+): 0x28 0xB5 0x2F 0xFD
if magic[0] == 0x28 && magic[1] == 0xb5 && magic[2] == 0x2f && magic[3] == 0xfd {
return parseZstdBlendVersion(r)
}
return 0, 0, fmt.Errorf("unknown blend file format")
}
// parseZstdBlendVersion handles zstd-compressed blend files (Blender 3.0+)
// Uses zstd command line tool since Go doesn't have native zstd support
func parseZstdBlendVersion(r io.ReadSeeker) (major, minor int, err error) {
r.Seek(0, 0)
// We need to decompress just enough to read the header
// Use zstd command to decompress from stdin
cmd := exec.Command("zstd", "-d", "-c")
cmd.Stdin = r
stdout, err := cmd.StdoutPipe()
if err != nil {
return 0, 0, fmt.Errorf("failed to create zstd stdout pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return 0, 0, fmt.Errorf("failed to start zstd decompression: %w", err)
}
// Read just the header (12 bytes)
header := make([]byte, 12)
n, readErr := io.ReadFull(stdout, header)
// Kill the process early - we only need the header
cmd.Process.Kill()
cmd.Wait()
if readErr != nil || n < 12 {
return 0, 0, fmt.Errorf("failed to read zstd compressed blend header: %v", readErr)
}
if string(header[:7]) != "BLENDER" {
return 0, 0, fmt.Errorf("invalid blend file format in zstd archive")
}
versionStr := string(header[9:12])
var vMajor, vMinor int
if len(versionStr) == 3 {
fmt.Sscanf(string(versionStr[0]), "%d", &vMajor)
fmt.Sscanf(versionStr[1:3], "%d", &vMinor)
}
return vMajor, vMinor, nil
return blendfile.ParseVersionFromReader(r)
}
// handleGetBlenderVersions returns available Blender versions
@@ -713,7 +586,7 @@ func (s *Manager) handleDownloadBlender(w http.ResponseWriter, r *http.Request)
tarFilename = strings.TrimSuffix(tarFilename, ".bz2")
w.Header().Set("Content-Type", "application/x-tar")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", tarFilename))
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", tarFilename))
w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
w.Header().Set("X-Blender-Version", blenderVersion.Full)

View File

@@ -345,9 +345,9 @@ func (s *Manager) handleCreateJob(w http.ResponseWriter, r *http.Request) {
// Only create render tasks for render jobs
if req.JobType == types.JobTypeRender {
// Determine task timeout based on output format
taskTimeout := RenderTimeout // 1 hour for render jobs
taskTimeout := s.renderTimeout
if *req.OutputFormat == "EXR_264_MP4" || *req.OutputFormat == "EXR_AV1_MP4" || *req.OutputFormat == "EXR_VP9_WEBM" {
taskTimeout = VideoEncodeTimeout // 24 hours for encoding
taskTimeout = s.videoEncodeTimeout
}
// Create tasks for the job (batch INSERT in a single transaction)
@@ -390,7 +390,7 @@ func (s *Manager) handleCreateJob(w http.ResponseWriter, r *http.Request) {
// Create encode task immediately if output format requires it
// The task will have a condition that prevents it from being assigned until all render tasks are completed
if *req.OutputFormat == "EXR_264_MP4" || *req.OutputFormat == "EXR_AV1_MP4" || *req.OutputFormat == "EXR_VP9_WEBM" {
encodeTaskTimeout := VideoEncodeTimeout // 24 hours for encoding
encodeTaskTimeout := s.videoEncodeTimeout
conditionJSON := `{"type": "all_render_tasks_completed"}`
var encodeTaskID int64
err = s.db.With(func(conn *sql.DB) error {
@@ -2592,7 +2592,7 @@ func (s *Manager) handleDownloadJobFile(w http.ResponseWriter, r *http.Request)
}
// Set headers
w.Header().Set("Content-Disposition", fmt.Sprintf("%s; filename=%s", disposition, fileName))
w.Header().Set("Content-Disposition", fmt.Sprintf("%s; filename=%q", disposition, fileName))
w.Header().Set("Content-Type", contentType)
// Stream file
@@ -2710,7 +2710,7 @@ func (s *Manager) handleDownloadEXRZip(w http.ResponseWriter, r *http.Request) {
fileName := fmt.Sprintf("%s-exr.zip", safeJobName)
w.Header().Set("Content-Type", "application/zip")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", fileName))
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", fileName))
zipWriter := zip.NewWriter(w)
defer zipWriter.Close()
@@ -2881,7 +2881,7 @@ func (s *Manager) handlePreviewEXR(w http.ResponseWriter, r *http.Request) {
// Set headers
pngFileName := strings.TrimSuffix(fileName, filepath.Ext(fileName)) + ".png"
w.Header().Set("Content-Disposition", fmt.Sprintf("inline; filename=%s", pngFileName))
w.Header().Set("Content-Disposition", fmt.Sprintf("inline; filename=%q", pngFileName))
w.Header().Set("Content-Type", "image/png")
w.Header().Set("Content-Length", strconv.Itoa(len(pngData)))

View File

@@ -30,27 +30,22 @@ import (
"github.com/gorilla/websocket"
)
// Configuration constants
// Configuration constants (non-configurable infrastructure values)
const (
// WebSocket timeouts
WSReadDeadline = 90 * time.Second
WSPingInterval = 30 * time.Second
WSWriteDeadline = 10 * time.Second
// Task timeouts
RenderTimeout = 60 * 60 // 1 hour for frame rendering
VideoEncodeTimeout = 60 * 60 * 24 // 24 hours for encoding
// Limits
MaxUploadSize = 50 << 30 // 50 GB
// Infrastructure timers
RunnerHeartbeatTimeout = 90 * time.Second
TaskDistributionInterval = 10 * time.Second
ProgressUpdateThrottle = 2 * time.Second
// Cookie settings
SessionCookieMaxAge = 86400 // 24 hours
)
// Operational limits are loaded from database config at Manager initialization.
// Defaults are defined in internal/config/config.go convenience methods.
// Manager represents the manager server
type Manager struct {
db *database.DB
@@ -109,6 +104,12 @@ type Manager struct {
// Server start time for health checks
startTime time.Time
// Configurable operational values loaded from config
renderTimeout int // seconds
videoEncodeTimeout int // seconds
maxUploadSize int64 // bytes
sessionCookieMaxAge int // seconds
}
// ClientConnection represents a client WebSocket connection with subscriptions
@@ -166,6 +167,11 @@ func NewManager(db *database.DB, cfg *config.Config, auth *authpkg.Auth, storage
router: chi.NewRouter(),
ui: ui,
startTime: time.Now(),
renderTimeout: cfg.RenderTimeoutSeconds(),
videoEncodeTimeout: cfg.EncodeTimeoutSeconds(),
maxUploadSize: cfg.MaxUploadBytes(),
sessionCookieMaxAge: cfg.SessionCookieMaxAgeSec(),
wsUpgrader: websocket.Upgrader{
CheckOrigin: checkWebSocketOrigin,
ReadBufferSize: 1024,
@@ -189,6 +195,10 @@ func NewManager(db *database.DB, cfg *config.Config, auth *authpkg.Auth, storage
jobStatusUpdateMu: make(map[int64]*sync.Mutex),
}
// Initialize rate limiters from config
apiRateLimiter = NewRateLimiter(cfg.APIRateLimitPerMinute(), time.Minute)
authRateLimiter = NewRateLimiter(cfg.AuthRateLimitPerMinute(), time.Minute)
// Check for required external tools
if err := s.checkRequiredTools(); err != nil {
return nil, err
@@ -267,6 +277,7 @@ type RateLimiter struct {
mu sync.RWMutex
limit int // max requests
window time.Duration // time window
stopChan chan struct{}
}
// NewRateLimiter creates a new rate limiter
@@ -275,12 +286,17 @@ func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
requests: make(map[string][]time.Time),
limit: limit,
window: window,
stopChan: make(chan struct{}),
}
// Start cleanup goroutine
go rl.cleanup()
return rl
}
// Stop shuts down the cleanup goroutine.
func (rl *RateLimiter) Stop() {
close(rl.stopChan)
}
// Allow checks if a request from the given IP is allowed
func (rl *RateLimiter) Allow(ip string) bool {
rl.mu.Lock()
@@ -313,32 +329,37 @@ func (rl *RateLimiter) Allow(ip string) bool {
// cleanup periodically removes old entries
func (rl *RateLimiter) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
for range ticker.C {
rl.mu.Lock()
cutoff := time.Now().Add(-rl.window)
for ip, reqs := range rl.requests {
validReqs := make([]time.Time, 0, len(reqs))
for _, t := range reqs {
if t.After(cutoff) {
validReqs = append(validReqs, t)
defer ticker.Stop()
for {
select {
case <-ticker.C:
rl.mu.Lock()
cutoff := time.Now().Add(-rl.window)
for ip, reqs := range rl.requests {
validReqs := make([]time.Time, 0, len(reqs))
for _, t := range reqs {
if t.After(cutoff) {
validReqs = append(validReqs, t)
}
}
if len(validReqs) == 0 {
delete(rl.requests, ip)
} else {
rl.requests[ip] = validReqs
}
}
if len(validReqs) == 0 {
delete(rl.requests, ip)
} else {
rl.requests[ip] = validReqs
}
rl.mu.Unlock()
case <-rl.stopChan:
return
}
rl.mu.Unlock()
}
}
// Global rate limiters for different endpoint types
// Rate limiters — initialized per Manager instance in NewManager.
var (
// General API rate limiter: 100 requests per minute per IP
apiRateLimiter = NewRateLimiter(100, time.Minute)
// Auth rate limiter: 10 requests per minute per IP (stricter for login attempts)
authRateLimiter = NewRateLimiter(10, time.Minute)
apiRateLimiter *RateLimiter
authRateLimiter *RateLimiter
)
// rateLimitMiddleware applies rate limiting based on client IP
@@ -610,17 +631,16 @@ func (s *Manager) respondError(w http.ResponseWriter, status int, message string
}
// createSessionCookie creates a secure session cookie with appropriate flags for the environment
func createSessionCookie(sessionID string) *http.Cookie {
func (s *Manager) createSessionCookie(sessionID string) *http.Cookie {
cookie := &http.Cookie{
Name: "session_id",
Value: sessionID,
Path: "/",
MaxAge: SessionCookieMaxAge,
MaxAge: s.sessionCookieMaxAge,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
// In production mode, set Secure flag to require HTTPS
if authpkg.IsProductionMode() {
cookie.Secure = true
}
@@ -712,7 +732,7 @@ func (s *Manager) handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
}
sessionID := s.auth.CreateSession(session)
http.SetCookie(w, createSessionCookie(sessionID))
http.SetCookie(w, s.createSessionCookie(sessionID))
http.Redirect(w, r, "/", http.StatusFound)
}
@@ -745,7 +765,7 @@ func (s *Manager) handleDiscordCallback(w http.ResponseWriter, r *http.Request)
}
sessionID := s.auth.CreateSession(session)
http.SetCookie(w, createSessionCookie(sessionID))
http.SetCookie(w, s.createSessionCookie(sessionID))
http.Redirect(w, r, "/", http.StatusFound)
}
@@ -838,7 +858,7 @@ func (s *Manager) handleLocalRegister(w http.ResponseWriter, r *http.Request) {
}
sessionID := s.auth.CreateSession(session)
http.SetCookie(w, createSessionCookie(sessionID))
http.SetCookie(w, s.createSessionCookie(sessionID))
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
"message": "Registration successful",
@@ -875,7 +895,7 @@ func (s *Manager) handleLocalLogin(w http.ResponseWriter, r *http.Request) {
}
sessionID := s.auth.CreateSession(session)
http.SetCookie(w, createSessionCookie(sessionID))
http.SetCookie(w, s.createSessionCookie(sessionID))
s.respondJSON(w, http.StatusOK, map[string]interface{}{
"message": "Login successful",

View File

@@ -3,6 +3,7 @@ package api
import (
"fmt"
"html/template"
"log"
"net/http"
"strings"
"time"
@@ -92,13 +93,17 @@ func newUIRenderer() (*uiRenderer, error) {
func (r *uiRenderer) render(w http.ResponseWriter, data pageData) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
if err := r.templates.ExecuteTemplate(w, "base", data); err != nil {
log.Printf("Template render error: %v", err)
http.Error(w, "template render error", http.StatusInternalServerError)
return
}
}
func (r *uiRenderer) renderTemplate(w http.ResponseWriter, templateName string, data interface{}) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
if err := r.templates.ExecuteTemplate(w, templateName, data); err != nil {
log.Printf("Template render error for %s: %v", templateName, err)
http.Error(w, "template render error", http.StatusInternalServerError)
return
}
}

View File

@@ -765,7 +765,7 @@ func (s *Manager) handleDownloadJobContext(w http.ResponseWriter, r *http.Reques
// Set appropriate headers for tar file
w.Header().Set("Content-Type", "application/x-tar")
w.Header().Set("Content-Disposition", "attachment; filename=context.tar")
w.Header().Set("Content-Disposition", "attachment; filename=\"context.tar\"")
// Stream the file to the response
io.Copy(w, file)
@@ -821,7 +821,7 @@ func (s *Manager) handleDownloadJobContextWithToken(w http.ResponseWriter, r *ht
// Set appropriate headers for tar file
w.Header().Set("Content-Type", "application/x-tar")
w.Header().Set("Content-Disposition", "attachment; filename=context.tar")
w.Header().Set("Content-Disposition", "attachment; filename=\"context.tar\"")
// Stream the file to the response
io.Copy(w, file)
@@ -836,7 +836,7 @@ func (s *Manager) handleUploadFileFromRunner(w http.ResponseWriter, r *http.Requ
return
}
err = r.ParseMultipartForm(MaxUploadSize) // 50 GB (for large output files)
err = r.ParseMultipartForm(s.maxUploadSize)
if err != nil {
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Failed to parse multipart form: %v", err))
return
@@ -944,7 +944,7 @@ func (s *Manager) handleUploadFileWithToken(w http.ResponseWriter, r *http.Reque
return
}
err = r.ParseMultipartForm(MaxUploadSize) // 50 GB (for large output files)
err = r.ParseMultipartForm(s.maxUploadSize)
if err != nil {
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Failed to parse multipart form: %v", err))
return
@@ -1228,7 +1228,7 @@ func (s *Manager) handleDownloadFileForRunner(w http.ResponseWriter, r *http.Req
// Set headers
w.Header().Set("Content-Type", contentType)
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", decodedFileName))
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", decodedFileName))
// Stream file
io.Copy(w, file)
@@ -1476,70 +1476,49 @@ func (s *Manager) handleRunnerJobWebSocket(w http.ResponseWriter, r *http.Reques
}
}
case "runner_heartbeat":
// Lookup runner ID from job's assigned_runner_id
var assignedRunnerID sql.NullInt64
err := s.db.With(func(db *sql.DB) error {
return db.QueryRow(
"SELECT assigned_runner_id FROM jobs WHERE id = ?",
jobID,
).Scan(&assignedRunnerID)
})
if err != nil {
log.Printf("Failed to lookup runner for job %d heartbeat: %v", jobID, err)
// Send error response
response := map[string]interface{}{
"type": "error",
"message": "Failed to process heartbeat",
}
s.sendWebSocketMessage(conn, response)
continue
}
if !assignedRunnerID.Valid {
log.Printf("Job %d has no assigned runner, skipping heartbeat update", jobID)
// Send acknowledgment but no database update
response := map[string]interface{}{
"type": "heartbeat_ack",
"timestamp": time.Now().Unix(),
"message": "No assigned runner for this job",
}
s.sendWebSocketMessage(conn, response)
continue
}
runnerID := assignedRunnerID.Int64
// Update runner heartbeat
err = s.db.With(func(db *sql.DB) error {
_, err := db.Exec(
"UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?",
time.Now(), types.RunnerStatusOnline, runnerID,
)
return err
})
if err != nil {
log.Printf("Failed to update runner %d heartbeat for job %d: %v", runnerID, jobID, err)
// Send error response
response := map[string]interface{}{
"type": "error",
"message": "Failed to update heartbeat",
}
s.sendWebSocketMessage(conn, response)
continue
}
// Send acknowledgment
response := map[string]interface{}{
"type": "heartbeat_ack",
"timestamp": time.Now().Unix(),
}
s.sendWebSocketMessage(conn, response)
s.handleWSRunnerHeartbeat(conn, jobID)
continue
}
}
}
// handleWSRunnerHeartbeat processes a runner heartbeat received over a job WebSocket.
func (s *Manager) handleWSRunnerHeartbeat(conn *websocket.Conn, jobID int64) {
var assignedRunnerID sql.NullInt64
err := s.db.With(func(db *sql.DB) error {
return db.QueryRow(
"SELECT assigned_runner_id FROM jobs WHERE id = ?", jobID,
).Scan(&assignedRunnerID)
})
if err != nil {
log.Printf("Failed to lookup runner for job %d heartbeat: %v", jobID, err)
s.sendWebSocketMessage(conn, map[string]interface{}{"type": "error", "message": "Failed to process heartbeat"})
return
}
if !assignedRunnerID.Valid {
log.Printf("Job %d has no assigned runner, skipping heartbeat update", jobID)
s.sendWebSocketMessage(conn, map[string]interface{}{"type": "heartbeat_ack", "timestamp": time.Now().Unix(), "message": "No assigned runner for this job"})
return
}
runnerID := assignedRunnerID.Int64
err = s.db.With(func(db *sql.DB) error {
_, err := db.Exec(
"UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?",
time.Now(), types.RunnerStatusOnline, runnerID,
)
return err
})
if err != nil {
log.Printf("Failed to update runner %d heartbeat for job %d: %v", runnerID, jobID, err)
s.sendWebSocketMessage(conn, map[string]interface{}{"type": "error", "message": "Failed to update heartbeat"})
return
}
s.sendWebSocketMessage(conn, map[string]interface{}{"type": "heartbeat_ack", "timestamp": time.Now().Unix()})
}
// handleWebSocketLog handles log entries from WebSocket
func (s *Manager) handleWebSocketLog(runnerID int64, logEntry WSLogEntry) {
// Store log in database
@@ -1948,241 +1927,226 @@ func (s *Manager) cleanupJobStatusUpdateMutex(jobID int64) {
// This function is serialized per jobID to prevent race conditions when multiple tasks
// complete concurrently and trigger status updates simultaneously.
func (s *Manager) updateJobStatusFromTasks(jobID int64) {
// Serialize updates per job to prevent race conditions
mu := s.getJobStatusUpdateMutex(jobID)
mu.Lock()
defer mu.Unlock()
now := time.Now()
// All jobs now use parallel runners (one task per frame), so we always use task-based progress
// Get current job status to detect changes
var currentStatus string
err := s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(`SELECT status FROM jobs WHERE id = ?`, jobID).Scan(&currentStatus)
})
currentStatus, err := s.getJobStatus(jobID)
if err != nil {
log.Printf("Failed to get current job status for job %d: %v", jobID, err)
return
}
// Cancellation is terminal from the user's perspective.
// Do not allow asynchronous task updates to revive cancelled jobs.
if currentStatus == string(types.JobStatusCancelled) {
return
}
// Count total tasks and completed tasks
var totalTasks, completedTasks int
err = s.db.With(func(conn *sql.DB) error {
err := conn.QueryRow(
counts, err := s.getJobTaskCounts(jobID)
if err != nil {
log.Printf("Failed to count tasks for job %d: %v", jobID, err)
return
}
progress := counts.progress()
if counts.pendingOrRunning == 0 && counts.total > 0 {
s.handleAllTasksFinished(jobID, currentStatus, counts, progress)
} else {
s.handleTasksInProgress(jobID, currentStatus, counts, progress)
}
}
// jobTaskCounts holds task state counts for a job.
type jobTaskCounts struct {
total int
completed int
pendingOrRunning int
failed int
running int
}
func (c *jobTaskCounts) progress() float64 {
if c.total == 0 {
return 0.0
}
return float64(c.completed) / float64(c.total) * 100.0
}
func (s *Manager) getJobStatus(jobID int64) (string, error) {
var status string
err := s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(`SELECT status FROM jobs WHERE id = ?`, jobID).Scan(&status)
})
return status, err
}
func (s *Manager) getJobTaskCounts(jobID int64) (*jobTaskCounts, error) {
c := &jobTaskCounts{}
err := s.db.With(func(conn *sql.DB) error {
if err := conn.QueryRow(
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status IN (?, ?, ?, ?)`,
jobID, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted, types.TaskStatusFailed,
).Scan(&totalTasks)
if err != nil {
).Scan(&c.total); err != nil {
return err
}
return conn.QueryRow(
if err := conn.QueryRow(
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
jobID, types.TaskStatusCompleted,
).Scan(&completedTasks)
).Scan(&c.completed); err != nil {
return err
}
if err := conn.QueryRow(
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status IN (?, ?)`,
jobID, types.TaskStatusPending, types.TaskStatusRunning,
).Scan(&c.pendingOrRunning); err != nil {
return err
}
if err := conn.QueryRow(
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
jobID, types.TaskStatusFailed,
).Scan(&c.failed); err != nil {
return err
}
if err := conn.QueryRow(
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
jobID, types.TaskStatusRunning,
).Scan(&c.running); err != nil {
return err
}
return nil
})
if err != nil {
log.Printf("Failed to count completed tasks for job %d: %v", jobID, err)
return
}
// Calculate progress
var progress float64
if totalTasks == 0 {
// All tasks cancelled or no tasks, set progress to 0
progress = 0.0
} else {
// Standard task-based progress
progress = float64(completedTasks) / float64(totalTasks) * 100.0
}
return c, err
}
// handleAllTasksFinished handles the case where no pending/running tasks remain.
func (s *Manager) handleAllTasksFinished(jobID int64, currentStatus string, counts *jobTaskCounts, progress float64) {
now := time.Now()
var jobStatus string
// Check if all non-cancelled tasks are completed
var pendingOrRunningTasks int
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
`SELECT COUNT(*) FROM tasks
WHERE job_id = ? AND status IN (?, ?)`,
jobID, types.TaskStatusPending, types.TaskStatusRunning,
).Scan(&pendingOrRunningTasks)
})
if err != nil {
log.Printf("Failed to count pending/running tasks for job %d: %v", jobID, err)
return
if counts.failed > 0 {
jobStatus = s.handleFailedTasks(jobID, currentStatus, &progress)
if jobStatus == "" {
return // retry handled; early exit
}
} else {
jobStatus = string(types.JobStatusCompleted)
progress = 100.0
}
if pendingOrRunningTasks == 0 && totalTasks > 0 {
// All tasks are either completed or failed/cancelled
// Check if any tasks failed
var failedTasks int
s.db.With(func(conn *sql.DB) error {
conn.QueryRow(
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
jobID, types.TaskStatusFailed,
).Scan(&failedTasks)
return nil
})
s.setJobFinalStatus(jobID, currentStatus, jobStatus, progress, now, counts)
}
if failedTasks > 0 {
// Some tasks failed - check if job has retries left
var retryCount, maxRetries int
err := s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
`SELECT retry_count, max_retries FROM jobs WHERE id = ?`,
jobID,
).Scan(&retryCount, &maxRetries)
})
if err != nil {
log.Printf("Failed to get retry info for job %d: %v", jobID, err)
// Fall back to marking job as failed
jobStatus = string(types.JobStatusFailed)
} else if retryCount < maxRetries {
// Job has retries left - reset failed tasks and redistribute
if err := s.resetFailedTasksAndRedistribute(jobID); err != nil {
log.Printf("Failed to reset failed tasks for job %d: %v", jobID, err)
// If reset fails, mark job as failed
jobStatus = string(types.JobStatusFailed)
} else {
// Tasks reset successfully - job remains in running/pending state
// Don't update job status, just update progress
jobStatus = currentStatus // Keep current status
// Recalculate progress after reset (failed tasks are now pending again)
var newTotalTasks, newCompletedTasks int
s.db.With(func(conn *sql.DB) error {
conn.QueryRow(
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status IN (?, ?, ?, ?)`,
jobID, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted, types.TaskStatusFailed,
).Scan(&newTotalTasks)
conn.QueryRow(
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
jobID, types.TaskStatusCompleted,
).Scan(&newCompletedTasks)
return nil
})
if newTotalTasks > 0 {
progress = float64(newCompletedTasks) / float64(newTotalTasks) * 100.0
}
// Update progress only
err := s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`UPDATE jobs SET progress = ? WHERE id = ?`,
progress, jobID,
)
return err
})
if err != nil {
log.Printf("Failed to update job %d progress: %v", jobID, err)
} else {
// Broadcast job update via WebSocket
s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{
"status": jobStatus,
"progress": progress,
})
}
return // Exit early since we've handled the retry
}
} else {
// No retries left - mark job as failed and cancel active tasks
jobStatus = string(types.JobStatusFailed)
if err := s.cancelActiveTasksForJob(jobID); err != nil {
log.Printf("Failed to cancel active tasks for job %d: %v", jobID, err)
}
}
} else {
// All tasks completed successfully
jobStatus = string(types.JobStatusCompleted)
progress = 100.0 // Ensure progress is 100% when all tasks complete
// handleFailedTasks decides whether to retry or mark the job failed.
// Returns "" if a retry was triggered (caller should return early),
// or the final status string.
func (s *Manager) handleFailedTasks(jobID int64, currentStatus string, progress *float64) string {
var retryCount, maxRetries int
err := s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
`SELECT retry_count, max_retries FROM jobs WHERE id = ?`, jobID,
).Scan(&retryCount, &maxRetries)
})
if err != nil {
log.Printf("Failed to get retry info for job %d: %v", jobID, err)
return string(types.JobStatusFailed)
}
if retryCount < maxRetries {
if err := s.resetFailedTasksAndRedistribute(jobID); err != nil {
log.Printf("Failed to reset failed tasks for job %d: %v", jobID, err)
return string(types.JobStatusFailed)
}
// Update job status (if we didn't return early from retry logic)
if jobStatus != "" {
err := s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`UPDATE jobs SET status = ?, progress = ?, completed_at = ? WHERE id = ?`,
jobStatus, progress, now, jobID,
)
return err
})
if err != nil {
log.Printf("Failed to update job %d status to %s: %v", jobID, jobStatus, err)
} else {
// Only log if status actually changed
if currentStatus != jobStatus {
log.Printf("Updated job %d status from %s to %s (progress: %.1f%%, completed tasks: %d/%d)", jobID, currentStatus, jobStatus, progress, completedTasks, totalTasks)
}
// Broadcast job update via WebSocket
s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{
"status": jobStatus,
"progress": progress,
"completed_at": now,
})
// Clean up mutex for jobs in final states (completed or failed)
// No more status updates will occur for these jobs
if jobStatus == string(types.JobStatusCompleted) || jobStatus == string(types.JobStatusFailed) {
s.cleanupJobStatusUpdateMutex(jobID)
}
}
// Recalculate progress after reset
counts, err := s.getJobTaskCounts(jobID)
if err == nil && counts.total > 0 {
*progress = counts.progress()
}
// Encode tasks are now created immediately when the job is created
// with a condition that prevents assignment until all render tasks are completed.
// No need to create them here anymore.
} else {
// Job has pending or running tasks - determine if it's running or still pending
var runningTasks int
s.db.With(func(conn *sql.DB) error {
conn.QueryRow(
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
jobID, types.TaskStatusRunning,
).Scan(&runningTasks)
return nil
})
if runningTasks > 0 {
// Has running tasks - job is running
jobStatus = string(types.JobStatusRunning)
var startedAt sql.NullTime
s.db.With(func(conn *sql.DB) error {
conn.QueryRow(`SELECT started_at FROM jobs WHERE id = ?`, jobID).Scan(&startedAt)
if !startedAt.Valid {
conn.Exec(`UPDATE jobs SET started_at = ? WHERE id = ?`, now, jobID)
}
return nil
})
} else {
// All tasks are pending - job is pending
jobStatus = string(types.JobStatusPending)
}
err := s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`UPDATE jobs SET status = ?, progress = ? WHERE id = ?`,
jobStatus, progress, jobID,
)
err = s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(`UPDATE jobs SET progress = ? WHERE id = ?`, *progress, jobID)
return err
})
if err != nil {
log.Printf("Failed to update job %d status to %s: %v", jobID, jobStatus, err)
log.Printf("Failed to update job %d progress: %v", jobID, err)
} else {
// Only log if status actually changed
if currentStatus != jobStatus {
log.Printf("Updated job %d status from %s to %s (progress: %.1f%%, completed: %d/%d, pending: %d, running: %d)", jobID, currentStatus, jobStatus, progress, completedTasks, totalTasks, pendingOrRunningTasks-runningTasks, runningTasks)
}
// Broadcast job update during execution (not just on completion)
s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{
"status": jobStatus,
"progress": progress,
"status": currentStatus,
"progress": *progress,
})
}
return "" // retry handled
}
// No retries left
if err := s.cancelActiveTasksForJob(jobID); err != nil {
log.Printf("Failed to cancel active tasks for job %d: %v", jobID, err)
}
return string(types.JobStatusFailed)
}
// setJobFinalStatus persists the terminal job status and broadcasts the update.
func (s *Manager) setJobFinalStatus(jobID int64, currentStatus, jobStatus string, progress float64, now time.Time, counts *jobTaskCounts) {
err := s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`UPDATE jobs SET status = ?, progress = ?, completed_at = ? WHERE id = ?`,
jobStatus, progress, now, jobID,
)
return err
})
if err != nil {
log.Printf("Failed to update job %d status to %s: %v", jobID, jobStatus, err)
return
}
if currentStatus != jobStatus {
log.Printf("Updated job %d status from %s to %s (progress: %.1f%%, completed tasks: %d/%d)", jobID, currentStatus, jobStatus, progress, counts.completed, counts.total)
}
s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{
"status": jobStatus,
"progress": progress,
"completed_at": now,
})
if jobStatus == string(types.JobStatusCompleted) || jobStatus == string(types.JobStatusFailed) {
s.cleanupJobStatusUpdateMutex(jobID)
}
}
// handleTasksInProgress handles the case where tasks are still pending or running.
func (s *Manager) handleTasksInProgress(jobID int64, currentStatus string, counts *jobTaskCounts, progress float64) {
now := time.Now()
var jobStatus string
if counts.running > 0 {
jobStatus = string(types.JobStatusRunning)
s.db.With(func(conn *sql.DB) error {
var startedAt sql.NullTime
conn.QueryRow(`SELECT started_at FROM jobs WHERE id = ?`, jobID).Scan(&startedAt)
if !startedAt.Valid {
conn.Exec(`UPDATE jobs SET started_at = ? WHERE id = ?`, now, jobID)
}
return nil
})
} else {
jobStatus = string(types.JobStatusPending)
}
err := s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`UPDATE jobs SET status = ?, progress = ? WHERE id = ?`,
jobStatus, progress, jobID,
)
return err
})
if err != nil {
log.Printf("Failed to update job %d status to %s: %v", jobID, jobStatus, err)
return
}
if currentStatus != jobStatus {
pending := counts.pendingOrRunning - counts.running
log.Printf("Updated job %d status from %s to %s (progress: %.1f%%, completed: %d/%d, pending: %d, running: %d)", jobID, currentStatus, jobStatus, progress, counts.completed, counts.total, pending, counts.running)
}
s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{
"status": jobStatus,
"progress": progress,
})
}
// broadcastLogToFrontend broadcasts log to connected frontend clients