// jiggablend/internal/api/runners.go

package api
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"math/rand"
"net/http"
"net/url"
"path/filepath"
"sort"
"strconv"
"strings"
"sync"
"time"
"jiggablend/pkg/types"
"github.com/go-chi/chi/v5"
"github.com/gorilla/websocket"
)
type contextKey string
const runnerIDContextKey contextKey = "runner_id"
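// Handlers read the authenticated runner back out of the request context with:
//   runnerID, ok := r.Context().Value(runnerIDContextKey).(int64)
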
// runnerAuthMiddleware verifies runner requests using API key
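// The key is read from the Authorization header (with an optional "Bearer " prefix)
// or from X-API-Key. The runner is resolved either from the runner_id query parameter
// (verified to belong to the key) or, if absent, by looking up the single runner
// registered for that key. The resolved ID is stored in the request context under
// runnerIDContextKey.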
func (s *Server) runnerAuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Get API key from header
apiKey := r.Header.Get("Authorization")
if apiKey == "" {
// Try alternative header
apiKey = r.Header.Get("X-API-Key")
}
if apiKey == "" {
s.respondError(w, http.StatusUnauthorized, "API key required")
return
}
// Remove "Bearer " prefix if present
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
// Validate API key and get its ID
apiKeyID, _, err := s.secrets.ValidateRunnerAPIKey(apiKey)
if err != nil {
log.Printf("API key validation failed: %v", err)
s.respondError(w, http.StatusUnauthorized, "invalid API key")
return
}
// Get runner ID from query string or find runner by API key
runnerIDStr := r.URL.Query().Get("runner_id")
var runnerID int64
if runnerIDStr != "" {
// Runner ID provided - verify it belongs to this API key
runnerID, err = strconv.ParseInt(runnerIDStr, 10, 64)
if err != nil {
s.respondError(w, http.StatusBadRequest, "invalid runner_id")
return
}
// For fixed API keys, skip database verification
if apiKeyID != -1 {
// Verify runner exists and uses this API key
var dbAPIKeyID sql.NullInt64
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT api_key_id FROM runners WHERE id = ?", runnerID).Scan(&dbAPIKeyID)
})
if err == sql.ErrNoRows {
s.respondError(w, http.StatusNotFound, "runner not found")
return
}
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to query runner API key: %v", err))
return
}
if !dbAPIKeyID.Valid || dbAPIKeyID.Int64 != apiKeyID {
s.respondError(w, http.StatusForbidden, "runner does not belong to this API key")
return
}
}
} else {
// No runner ID provided - find the runner for this API key
// For simplicity, assume each API key has one runner
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT id FROM runners WHERE api_key_id = ?", apiKeyID).Scan(&runnerID)
})
if err == sql.ErrNoRows {
s.respondError(w, http.StatusNotFound, "no runner found for this API key")
return
}
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to query runner by API key: %v", err))
return
}
}
// Add runner ID to context
ctx := r.Context()
ctx = context.WithValue(ctx, runnerIDContextKey, runnerID)
next(w, r.WithContext(ctx))
}
}
// handleRegisterRunner registers a new runner using an API key
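// If a non-fixed API key re-registers with a fingerprint that already exists for the
// same key, the existing runner row is updated and returned with "reused": true; the
// same fingerprint under a different key is rejected with 409 Conflict. Fixed API keys
// (keyID == -1) are never stored in api_key_id and always receive a generated,
// unique fingerprint.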
func (s *Server) handleRegisterRunner(w http.ResponseWriter, r *http.Request) {
var req struct {
types.RegisterRunnerRequest
APIKey string `json:"api_key"`
Fingerprint string `json:"fingerprint,omitempty"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
return
}
// Lock to prevent concurrent registrations that could create duplicate runners
s.secrets.RegistrationMu.Lock()
defer s.secrets.RegistrationMu.Unlock()
// Validate runner name
if req.Name == "" {
s.respondError(w, http.StatusBadRequest, "Runner name is required")
return
}
if len(req.Name) > 255 {
s.respondError(w, http.StatusBadRequest, "Runner name must be 255 characters or less")
return
}
// Validate hostname
if req.Hostname != "" {
// Basic hostname validation (allow IP addresses and domain names)
if len(req.Hostname) > 253 {
s.respondError(w, http.StatusBadRequest, "Hostname must be 253 characters or less")
return
}
}
// Validate capabilities JSON if provided
if req.Capabilities != "" {
var testCapabilities map[string]interface{}
if err := json.Unmarshal([]byte(req.Capabilities), &testCapabilities); err != nil {
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid capabilities JSON: %v", err))
return
}
}
if req.APIKey == "" {
s.respondError(w, http.StatusBadRequest, "API key is required")
return
}
// Validate API key
apiKeyID, apiKeyScope, err := s.secrets.ValidateRunnerAPIKey(req.APIKey)
if err != nil {
s.respondError(w, http.StatusUnauthorized, fmt.Sprintf("Invalid API key: %v", err))
return
}
// For fixed API keys (keyID = -1), skip fingerprint checking
// Set default priority if not provided
priority := 100
if req.Priority != nil {
priority = *req.Priority
}
// Register runner
var runnerID int64
// For fixed API keys, don't store api_key_id in database
var dbAPIKeyID interface{}
if apiKeyID == -1 {
dbAPIKeyID = nil // NULL for fixed API keys
} else {
dbAPIKeyID = apiKeyID
}
// Determine fingerprint value
fingerprint := req.Fingerprint
if apiKeyID == -1 || fingerprint == "" {
// For fixed API keys or when no fingerprint provided, generate a unique fingerprint
// to avoid conflicts while still maintaining some uniqueness
fingerprint = fmt.Sprintf("fixed-%s-%d", req.Name, time.Now().UnixNano())
}
// Check fingerprint uniqueness only for non-fixed API keys
if apiKeyID != -1 && req.Fingerprint != "" {
var existingRunnerID int64
var existingAPIKeyID sql.NullInt64
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
"SELECT id, api_key_id FROM runners WHERE fingerprint = ?",
req.Fingerprint,
).Scan(&existingRunnerID, &existingAPIKeyID)
})
if err == nil {
// Runner already exists with this fingerprint
if existingAPIKeyID.Valid && existingAPIKeyID.Int64 == apiKeyID {
// Same API key - update and return existing runner
log.Printf("Runner with fingerprint %s already exists (ID: %d), updating info", req.Fingerprint, existingRunnerID)
err = s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`UPDATE runners SET name = ?, hostname = ?, capabilities = ?, status = ?, last_heartbeat = ? WHERE id = ?`,
req.Name, req.Hostname, req.Capabilities, types.RunnerStatusOnline, time.Now(), existingRunnerID,
)
return err
})
if err != nil {
log.Printf("Warning: Failed to update existing runner info: %v", err)
}
s.respondJSON(w, http.StatusOK, map[string]interface{}{
"id": existingRunnerID,
"name": req.Name,
"hostname": req.Hostname,
"status": types.RunnerStatusOnline,
"reused": true, // Indicates this was a re-registration
})
return
} else {
// Different API key - reject registration
s.respondError(w, http.StatusConflict, "Runner with this fingerprint already registered with different API key")
return
}
}
// A non-nil err (normally sql.ErrNoRows) means no existing runner with this fingerprint was found - proceed with a new registration
}
// Insert runner
err = s.db.With(func(conn *sql.DB) error {
result, err := conn.Exec(
`INSERT INTO runners (name, hostname, ip_address, status, last_heartbeat, capabilities,
api_key_id, api_key_scope, priority, fingerprint)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
req.Name, req.Hostname, "", types.RunnerStatusOnline, time.Now(), req.Capabilities,
dbAPIKeyID, apiKeyScope, priority, fingerprint,
)
if err != nil {
return err
}
runnerID, err = result.LastInsertId()
return err
})
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to register runner: %v", err))
return
}
log.Printf("Registered new runner %s (ID: %d) with API key ID: %d", req.Name, runnerID, apiKeyID)
// Return runner info
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
"id": runnerID,
"name": req.Name,
"hostname": req.Hostname,
"status": types.RunnerStatusOnline,
})
}
// handleRunnerPing allows runners to validate their secrets and connection
func (s *Server) handleRunnerPing(w http.ResponseWriter, r *http.Request) {
// This endpoint uses runnerAuthMiddleware, so if we get here, secrets are valid
// Get runner ID from context (set by runnerAuthMiddleware)
runnerID, ok := r.Context().Value(runnerIDContextKey).(int64)
if !ok {
s.respondError(w, http.StatusUnauthorized, "runner_id not found in context")
return
}
// Update last heartbeat
err := s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`,
time.Now(), types.RunnerStatusOnline, runnerID,
)
return err
})
if err != nil {
log.Printf("Warning: Failed to update runner heartbeat: %v", err)
}
s.respondJSON(w, http.StatusOK, map[string]interface{}{
"status": "ok",
"runner_id": runnerID,
"timestamp": time.Now().Unix(),
})
}
// handleUpdateTaskProgress updates task progress
func (s *Server) handleUpdateTaskProgress(w http.ResponseWriter, r *http.Request) {
_, err := parseID(r, "id")
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
var req struct {
Progress float64 `json:"progress"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
return
}
// This is mainly for logging/debugging, actual progress is calculated from completed tasks
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Progress updated"})
}
// handleUpdateTaskStep handles step start/complete events from runners
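// The (task_id, step_name) pair is upserted: the first report inserts a task_steps row,
// later reports update it. started_at is set when the step first reaches running,
// completed, or failed; completed_at is set on completed or failed. Every update is
// broadcast to the frontend as a "step_update" event.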
func (s *Server) handleUpdateTaskStep(w http.ResponseWriter, r *http.Request) {
// Get runner ID from context (set by runnerAuthMiddleware)
runnerID, ok := r.Context().Value(runnerIDContextKey).(int64)
if !ok {
s.respondError(w, http.StatusUnauthorized, "runner_id not found in context")
return
}
taskID, err := parseID(r, "id")
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
var req struct {
StepName string `json:"step_name"`
Status string `json:"status"` // "pending", "running", "completed", "failed", "skipped"
DurationMs *int `json:"duration_ms,omitempty"`
ErrorMessage string `json:"error_message,omitempty"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
return
}
// Verify task belongs to runner
var taskRunnerID sql.NullInt64
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", taskID).Scan(&taskRunnerID)
})
if err == sql.ErrNoRows {
s.respondError(w, http.StatusNotFound, "Task not found")
return
}
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify task: %v", err))
return
}
if !taskRunnerID.Valid || taskRunnerID.Int64 != runnerID {
s.respondError(w, http.StatusForbidden, "Task does not belong to this runner")
return
}
now := time.Now()
var stepID int64
// Check if step already exists
var existingStepID sql.NullInt64
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
`SELECT id FROM task_steps WHERE task_id = ? AND step_name = ?`,
taskID, req.StepName,
).Scan(&existingStepID)
})
if err == sql.ErrNoRows || !existingStepID.Valid {
// Create new step
var startedAt *time.Time
var completedAt *time.Time
if req.Status == string(types.StepStatusRunning) || req.Status == string(types.StepStatusCompleted) || req.Status == string(types.StepStatusFailed) {
startedAt = &now
}
if req.Status == string(types.StepStatusCompleted) || req.Status == string(types.StepStatusFailed) {
completedAt = &now
}
err = s.db.With(func(conn *sql.DB) error {
result, err := conn.Exec(
`INSERT INTO task_steps (task_id, step_name, status, started_at, completed_at, duration_ms, error_message)
VALUES (?, ?, ?, ?, ?, ?, ?)`,
taskID, req.StepName, req.Status, startedAt, completedAt, req.DurationMs, req.ErrorMessage,
)
if err != nil {
return err
}
stepID, err = result.LastInsertId()
return err
})
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create step: %v", err))
return
}
} else {
// Update existing step
stepID = existingStepID.Int64
var startedAt *time.Time
var completedAt *time.Time
// Get existing started_at if status is running/completed/failed
if req.Status == string(types.StepStatusRunning) || req.Status == string(types.StepStatusCompleted) || req.Status == string(types.StepStatusFailed) {
var existingStartedAt sql.NullTime
s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(`SELECT started_at FROM task_steps WHERE id = ?`, stepID).Scan(&existingStartedAt)
})
if existingStartedAt.Valid {
startedAt = &existingStartedAt.Time
} else {
startedAt = &now
}
}
if req.Status == string(types.StepStatusCompleted) || req.Status == string(types.StepStatusFailed) {
completedAt = &now
}
err = s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`UPDATE task_steps SET status = ?, started_at = ?, completed_at = ?, duration_ms = ?, error_message = ?
WHERE id = ?`,
req.Status, startedAt, completedAt, req.DurationMs, req.ErrorMessage, stepID,
)
return err
})
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to update step: %v", err))
return
}
}
// Get job ID for broadcasting
var jobID int64
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT job_id FROM tasks WHERE id = ?", taskID).Scan(&jobID)
})
if err == nil {
// Broadcast step update to frontend
s.broadcastTaskUpdate(jobID, taskID, "step_update", map[string]interface{}{
"step_id": stepID,
"step_name": req.StepName,
"status": req.Status,
"duration_ms": req.DurationMs,
"error_message": req.ErrorMessage,
})
}
s.respondJSON(w, http.StatusOK, map[string]interface{}{
"step_id": stepID,
"message": "Step updated successfully",
})
}
// handleDownloadJobContext allows runners to download the job context tar
func (s *Server) handleDownloadJobContext(w http.ResponseWriter, r *http.Request) {
jobID, err := parseID(r, "jobId")
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
// Construct the context file path
contextPath := filepath.Join(s.storage.JobPath(jobID), "context.tar")
// Check if context file exists
if !s.storage.FileExists(contextPath) {
log.Printf("Context archive not found for job %d", jobID)
s.respondError(w, http.StatusNotFound, "Context archive not found. The file may not have been uploaded successfully.")
return
}
// Open and serve file
file, err := s.storage.GetFile(contextPath)
if err != nil {
s.respondError(w, http.StatusNotFound, "Context file not found on disk")
return
}
defer file.Close()
// Set appropriate headers for tar file
w.Header().Set("Content-Type", "application/x-tar")
w.Header().Set("Content-Disposition", "attachment; filename=context.tar")
// Stream the file to the response
io.Copy(w, file)
}
// handleUploadFileFromRunner allows runners to upload output files
func (s *Server) handleUploadFileFromRunner(w http.ResponseWriter, r *http.Request) {
jobID, err := parseID(r, "jobId")
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
err = r.ParseMultipartForm(50 << 30) // 50 GiB, to allow large output files
if err != nil {
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Failed to parse multipart form: %v", err))
return
}
file, header, err := r.FormFile("file")
if err != nil {
s.respondError(w, http.StatusBadRequest, "No file provided")
return
}
defer file.Close()
// Save file
filePath, err := s.storage.SaveOutput(jobID, header.Filename, file)
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to save file: %v", err))
return
}
// Record in database
var fileID int64
err = s.db.With(func(conn *sql.DB) error {
result, err := conn.Exec(
`INSERT INTO job_files (job_id, file_type, file_path, file_name, file_size)
VALUES (?, ?, ?, ?, ?)`,
jobID, types.JobFileTypeOutput, filePath, header.Filename, header.Size,
)
if err != nil {
return err
}
fileID, err = result.LastInsertId()
return err
})
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to record file: %v", err))
return
}
// Broadcast file addition
s.broadcastJobUpdate(jobID, "file_added", map[string]interface{}{
"file_id": fileID,
"file_type": types.JobFileTypeOutput,
"file_name": header.Filename,
"file_size": header.Size,
})
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
"file_path": filePath,
"file_name": header.Filename,
})
}
// handleGetJobStatusForRunner allows runners to check job status
func (s *Server) handleGetJobStatusForRunner(w http.ResponseWriter, r *http.Request) {
jobID, err := parseID(r, "jobId")
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
var job types.Job
var startedAt, completedAt sql.NullTime
var errorMessage sql.NullString
var jobType string
var frameStart, frameEnd sql.NullInt64
var outputFormat sql.NullString
var allowParallelRunners sql.NullBool
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
`SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format,
allow_parallel_runners, created_at, started_at, completed_at, error_message
FROM jobs WHERE id = ?`,
jobID,
).Scan(
&job.ID, &job.UserID, &jobType, &job.Name, &job.Status, &job.Progress,
&frameStart, &frameEnd, &outputFormat, &allowParallelRunners,
&job.CreatedAt, &startedAt, &completedAt, &errorMessage,
)
})
if err == sql.ErrNoRows {
s.respondError(w, http.StatusNotFound, "Job not found")
return
}
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query job: %v", err))
return
}
job.JobType = types.JobType(jobType)
if frameStart.Valid {
fs := int(frameStart.Int64)
job.FrameStart = &fs
}
if frameEnd.Valid {
fe := int(frameEnd.Int64)
job.FrameEnd = &fe
}
if outputFormat.Valid {
job.OutputFormat = &outputFormat.String
}
if allowParallelRunners.Valid {
job.AllowParallelRunners = &allowParallelRunners.Bool
}
if startedAt.Valid {
job.StartedAt = &startedAt.Time
}
if completedAt.Valid {
job.CompletedAt = &completedAt.Time
}
if errorMessage.Valid {
job.ErrorMessage = errorMessage.String
}
s.respondJSON(w, http.StatusOK, job)
}
// handleGetJobFilesForRunner allows runners to get job files
func (s *Server) handleGetJobFilesForRunner(w http.ResponseWriter, r *http.Request) {
jobID, err := parseID(r, "jobId")
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
var rows *sql.Rows
err = s.db.With(func(conn *sql.DB) error {
var err error
rows, err = conn.Query(
`SELECT id, job_id, file_type, file_path, file_name, file_size, created_at
FROM job_files WHERE job_id = ? ORDER BY file_name`,
jobID,
)
return err
})
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query files: %v", err))
return
}
defer rows.Close()
files := []types.JobFile{}
for rows.Next() {
var file types.JobFile
err := rows.Scan(
&file.ID, &file.JobID, &file.FileType, &file.FilePath,
&file.FileName, &file.FileSize, &file.CreatedAt,
)
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan file: %v", err))
return
}
files = append(files, file)
}
s.respondJSON(w, http.StatusOK, files)
}
// handleGetJobMetadataForRunner allows runners to get job metadata
func (s *Server) handleGetJobMetadataForRunner(w http.ResponseWriter, r *http.Request) {
jobID, err := parseID(r, "jobId")
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
var blendMetadataJSON sql.NullString
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
`SELECT blend_metadata FROM jobs WHERE id = ?`,
jobID,
).Scan(&blendMetadataJSON)
})
if err == sql.ErrNoRows {
s.respondError(w, http.StatusNotFound, "Job not found")
return
}
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query job: %v", err))
return
}
if !blendMetadataJSON.Valid || blendMetadataJSON.String == "" {
s.respondJSON(w, http.StatusOK, nil)
return
}
var metadata types.BlendMetadata
if err := json.Unmarshal([]byte(blendMetadataJSON.String), &metadata); err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to parse metadata JSON: %v", err))
return
}
s.respondJSON(w, http.StatusOK, metadata)
}
// handleDownloadFileForRunner allows runners to download a file by fileName
func (s *Server) handleDownloadFileForRunner(w http.ResponseWriter, r *http.Request) {
jobID, err := parseID(r, "jobId")
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
// Get fileName from URL path (may need URL decoding)
fileName := chi.URLParam(r, "fileName")
if fileName == "" {
s.respondError(w, http.StatusBadRequest, "fileName is required")
return
}
// URL decode the fileName in case it contains encoded characters
decodedFileName, err := url.QueryUnescape(fileName)
if err != nil {
// If decoding fails, use original fileName
decodedFileName = fileName
}
// Get file info from database
var filePath string
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
`SELECT file_path FROM job_files WHERE job_id = ? AND file_name = ?`,
jobID, decodedFileName,
).Scan(&filePath)
})
if err == sql.ErrNoRows {
s.respondError(w, http.StatusNotFound, "File not found")
return
}
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query file: %v", err))
return
}
// Open file
file, err := s.storage.GetFile(filePath)
if err != nil {
s.respondError(w, http.StatusNotFound, "File not found on disk")
return
}
defer file.Close()
// Determine content type based on file extension
contentType := "application/octet-stream"
fileNameLower := strings.ToLower(decodedFileName)
switch {
case strings.HasSuffix(fileNameLower, ".png"):
contentType = "image/png"
case strings.HasSuffix(fileNameLower, ".jpg") || strings.HasSuffix(fileNameLower, ".jpeg"):
contentType = "image/jpeg"
case strings.HasSuffix(fileNameLower, ".gif"):
contentType = "image/gif"
case strings.HasSuffix(fileNameLower, ".webp"):
contentType = "image/webp"
case strings.HasSuffix(fileNameLower, ".exr") || strings.HasSuffix(fileNameLower, ".EXR"):
contentType = "image/x-exr"
case strings.HasSuffix(fileNameLower, ".mp4"):
contentType = "video/mp4"
case strings.HasSuffix(fileNameLower, ".webm"):
contentType = "video/webm"
}
// Set headers
w.Header().Set("Content-Type", contentType)
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", decodedFileName))
// Stream file
io.Copy(w, file)
}
// WebSocket message types
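// WSMessage is the envelope for all runner WebSocket messages; Type selects the
// payload format carried in Data.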
type WSMessage struct {
Type string `json:"type"`
Data json.RawMessage `json:"data"`
Timestamp int64 `json:"timestamp"`
}
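// WSTaskAssignment describes a task pushed to a runner for execution.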
type WSTaskAssignment struct {
TaskID int64 `json:"task_id"`
JobID int64 `json:"job_id"`
JobName string `json:"job_name"`
OutputFormat string `json:"output_format"`
FrameStart int `json:"frame_start"`
FrameEnd int `json:"frame_end"`
TaskType string `json:"task_type"`
InputFiles []string `json:"input_files"`
}
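// WSLogEntry carries a single log line emitted by a runner while executing a task.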
type WSLogEntry struct {
TaskID int64 `json:"task_id"`
LogLevel string `json:"log_level"`
Message string `json:"message"`
StepName string `json:"step_name,omitempty"`
}
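// WSTaskUpdate reports a task's status change or completion result from a runner.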
type WSTaskUpdate struct {
TaskID int64 `json:"task_id"`
Status string `json:"status"`
OutputPath string `json:"output_path,omitempty"`
Success bool `json:"success"`
Error string `json:"error,omitempty"`
}
// handleRunnerWebSocket handles WebSocket connections from runners
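// Lifecycle: authenticate the API key (and optional runner_id), upgrade the connection,
// replace any previous connection for the same runner, mark the runner online and
// redistribute tasks it was still marked as running, then read messages until the
// socket closes. A 30s ping / 90s read-deadline loop keeps the connection and the
// runner's heartbeat fresh; cleanup on disconnect marks the runner offline and
// redistributes its tasks unless a newer connection has already taken over.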
func (s *Server) handleRunnerWebSocket(w http.ResponseWriter, r *http.Request) {
// Get API key from query params or headers
apiKey := r.URL.Query().Get("api_key")
if apiKey == "" {
apiKey = r.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
}
if apiKey == "" {
s.respondError(w, http.StatusBadRequest, "API key required")
return
}
// Validate API key
apiKeyID, _, err := s.secrets.ValidateRunnerAPIKey(apiKey)
if err != nil {
s.respondError(w, http.StatusUnauthorized, fmt.Sprintf("Invalid API key: %v", err))
return
}
// Get runner ID from query params or find by API key
runnerIDStr := r.URL.Query().Get("runner_id")
var runnerID int64
if runnerIDStr != "" {
// Runner ID provided - verify it belongs to this API key
runnerID, err = strconv.ParseInt(runnerIDStr, 10, 64)
if err != nil {
s.respondError(w, http.StatusBadRequest, "invalid runner_id")
return
}
// For fixed API keys, skip database verification
if apiKeyID != -1 {
var dbAPIKeyID sql.NullInt64
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT api_key_id FROM runners WHERE id = ?", runnerID).Scan(&dbAPIKeyID)
})
if err == sql.ErrNoRows {
s.respondError(w, http.StatusNotFound, "runner not found")
return
}
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to query runner API key: %v", err))
return
}
if !dbAPIKeyID.Valid || dbAPIKeyID.Int64 != apiKeyID {
s.respondError(w, http.StatusForbidden, "runner does not belong to this API key")
return
}
}
} else {
// No runner ID provided - find the runner for this API key
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT id FROM runners WHERE api_key_id = ?", apiKeyID).Scan(&runnerID)
})
if err == sql.ErrNoRows {
s.respondError(w, http.StatusNotFound, "no runner found for this API key")
return
}
if err != nil {
s.respondError(w, http.StatusInternalServerError, "database error")
return
}
}
// Upgrade to WebSocket
conn, err := s.wsUpgrader.Upgrade(w, r, nil)
if err != nil {
log.Printf("Failed to upgrade WebSocket: %v", err)
return
}
defer conn.Close()
// Register connection (must be done before any distribution checks)
// Fix race condition: Close old connection and create write mutex BEFORE registering new connection
var oldConn *websocket.Conn
s.runnerConnsMu.Lock()
if existingConn, exists := s.runnerConns[runnerID]; exists {
oldConn = existingConn
}
s.runnerConnsMu.Unlock()
// Close old connection BEFORE registering new one to prevent race conditions
if oldConn != nil {
log.Printf("Runner %d: closing existing WebSocket connection (reconnection)", runnerID)
oldConn.Close()
}
// Create write mutex BEFORE registering connection to prevent race condition
s.runnerConnsWriteMuMu.Lock()
s.runnerConnsWriteMu[runnerID] = &sync.Mutex{}
s.runnerConnsWriteMuMu.Unlock()
// Now register the new connection
s.runnerConnsMu.Lock()
s.runnerConns[runnerID] = conn
s.runnerConnsMu.Unlock()
log.Printf("Runner %d: WebSocket connection established successfully", runnerID)
// Check if runner was offline and had tasks assigned - redistribute them
// This handles the case where the manager restarted and marked the runner offline
// but tasks were still assigned to it
s.db.With(func(conn *sql.DB) error {
var count int
err := conn.QueryRow(
`SELECT COUNT(*) FROM tasks WHERE runner_id = ? AND status = ?`,
runnerID, types.TaskStatusRunning,
).Scan(&count)
if err == nil && count > 0 {
log.Printf("Runner %d reconnected with %d running tasks assigned - redistributing them", runnerID, count)
s.redistributeRunnerTasks(runnerID)
}
return nil
})
// Update runner status to online
s.db.With(func(conn *sql.DB) error {
_, _ = conn.Exec(
`UPDATE runners SET status = ?, last_heartbeat = ? WHERE id = ?`,
types.RunnerStatusOnline, time.Now(), runnerID,
)
return nil
})
// Immediately try to distribute pending tasks to this newly connected runner
log.Printf("Runner %d connected, distributing pending tasks", runnerID)
s.triggerTaskDistribution()
// Note: We don't log to task logs here because we don't know which tasks will be assigned yet
// Task assignment logging happens in distributeTasksToRunners
// Cleanup on disconnect
// Fix race condition: Only cleanup if this is still the current connection (not replaced by reconnection)
defer func() {
log.Printf("Runner %d: WebSocket connection cleanup started", runnerID)
// Check if this is still the current connection before cleanup
s.runnerConnsMu.Lock()
currentConn, stillCurrent := s.runnerConns[runnerID]
if !stillCurrent || currentConn != conn {
// Connection was replaced by a newer one, don't cleanup
s.runnerConnsMu.Unlock()
log.Printf("Runner %d: Skipping cleanup - connection was replaced by newer connection", runnerID)
return
}
// Remove connection from map
delete(s.runnerConns, runnerID)
s.runnerConnsMu.Unlock()
// Update database status
err := s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`UPDATE runners SET status = ?, last_heartbeat = ? WHERE id = ?`,
types.RunnerStatusOffline, time.Now(), runnerID,
)
return err
})
if err != nil {
log.Printf("Warning: Failed to update runner %d status to offline: %v", runnerID, err)
}
// Clean up write mutex
s.runnerConnsWriteMuMu.Lock()
delete(s.runnerConnsWriteMu, runnerID)
s.runnerConnsWriteMuMu.Unlock()
// Immediately redistribute tasks that were assigned to this runner
log.Printf("Runner %d: WebSocket disconnected, redistributing tasks", runnerID)
s.redistributeRunnerTasks(runnerID)
log.Printf("Runner %d: WebSocket connection cleanup completed", runnerID)
}()
// Set pong handler to update heartbeat when we receive pong responses from runner
// Also reset read deadline to keep connection alive
conn.SetPongHandler(func(string) error {
conn.SetReadDeadline(time.Now().Add(90 * time.Second))
s.db.With(func(conn *sql.DB) error {
_, _ = conn.Exec(
`UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`,
time.Now(), types.RunnerStatusOnline, runnerID,
)
return nil
})
return nil
})
// Set read deadline to ensure we process control frames (like pong)
conn.SetReadDeadline(time.Now().Add(90 * time.Second))
// Send ping every 30 seconds to trigger pong responses
go func() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for range ticker.C {
s.runnerConnsMu.RLock()
currentConn, exists := s.runnerConns[runnerID]
s.runnerConnsMu.RUnlock()
if !exists || currentConn != conn {
// Connection was replaced or removed
return
}
// Get write mutex for this connection
s.runnerConnsWriteMuMu.RLock()
writeMu, hasMu := s.runnerConnsWriteMu[runnerID]
s.runnerConnsWriteMuMu.RUnlock()
if !hasMu || writeMu == nil {
return
}
// Send ping - runner should respond with pong automatically
// Reset read deadline before sending ping to ensure we can receive pong
conn.SetReadDeadline(time.Now().Add(90 * time.Second))
writeMu.Lock()
err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second))
writeMu.Unlock()
if err != nil {
// Write failed - connection is likely dead, read loop will detect and cleanup
log.Printf("Failed to send ping to runner %d: %v", runnerID, err)
return
}
}
}()
// Handle incoming messages
for {
// Reset read deadline for each message - this is critical to keep connection alive
conn.SetReadDeadline(time.Now().Add(90 * time.Second))
var msg WSMessage
err := conn.ReadJSON(&msg)
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
log.Printf("WebSocket error for runner %d: %v", runnerID, err)
}
break
}
// Reset read deadline after successfully reading a message
// This ensures the connection stays alive as long as we're receiving messages
conn.SetReadDeadline(time.Now().Add(90 * time.Second))
switch msg.Type {
case "heartbeat":
// Heartbeat messages are handled by pong handler (manager-side)
// Reset read deadline to keep connection alive
conn.SetReadDeadline(time.Now().Add(90 * time.Second))
// Note: Heartbeat updates are consolidated to pong handler to avoid race conditions
// The pong handler is the single source of truth for heartbeat updates
case "log_entry":
var logEntry WSLogEntry
if err := json.Unmarshal(msg.Data, &logEntry); err == nil {
s.handleWebSocketLog(runnerID, logEntry)
}
case "task_update":
var taskUpdate WSTaskUpdate
if err := json.Unmarshal(msg.Data, &taskUpdate); err == nil {
s.handleWebSocketTaskUpdate(runnerID, taskUpdate)
}
case "task_complete":
var taskUpdate WSTaskUpdate
if err := json.Unmarshal(msg.Data, &taskUpdate); err == nil {
s.handleWebSocketTaskComplete(runnerID, taskUpdate)
}
}
}
}
// handleWebSocketLog handles log entries from WebSocket
func (s *Server) handleWebSocketLog(runnerID int64, logEntry WSLogEntry) {
// Store log in database
err := s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`INSERT INTO task_logs (task_id, runner_id, log_level, message, step_name, created_at)
VALUES (?, ?, ?, ?, ?, ?)`,
logEntry.TaskID, runnerID, logEntry.LogLevel, logEntry.Message, logEntry.StepName, time.Now(),
)
return err
})
if err != nil {
log.Printf("Failed to store log: %v", err)
return
}
// Broadcast to frontend clients
s.broadcastLogToFrontend(logEntry.TaskID, logEntry)
// If this log contains a frame number (Fra:), update progress for single-runner render jobs
if strings.Contains(logEntry.Message, "Fra:") {
// Get job ID from task
var jobID int64
err := s.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT job_id FROM tasks WHERE id = ?", logEntry.TaskID).Scan(&jobID)
})
if err == nil {
// Throttle progress updates (max once per 2 seconds per job)
s.progressUpdateTimesMu.RLock()
lastUpdate, exists := s.progressUpdateTimes[jobID]
s.progressUpdateTimesMu.RUnlock()
shouldUpdate := !exists || time.Since(lastUpdate) >= 2*time.Second
if shouldUpdate {
s.progressUpdateTimesMu.Lock()
s.progressUpdateTimes[jobID] = time.Now()
s.progressUpdateTimesMu.Unlock()
// Update progress in background to avoid blocking log processing
go s.updateJobStatusFromTasks(jobID)
}
}
}
}
// handleWebSocketTaskUpdate handles task status updates from WebSocket
func (s *Server) handleWebSocketTaskUpdate(runnerID int64, taskUpdate WSTaskUpdate) {
// This can be used for progress updates
// For now, we'll just log it
log.Printf("Task %d update from runner %d: %s", taskUpdate.TaskID, runnerID, taskUpdate.Status)
}
// handleWebSocketTaskComplete handles task completion from WebSocket
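// It verifies the task belongs to the reporting runner, writes the final status,
// output path, completion time, and error inside one transaction, broadcasts the
// task update, and then recomputes the parent job's status and progress.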
func (s *Server) handleWebSocketTaskComplete(runnerID int64, taskUpdate WSTaskUpdate) {
// Verify task belongs to runner
var taskRunnerID sql.NullInt64
err := s.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", taskUpdate.TaskID).Scan(&taskRunnerID)
})
if err != nil || !taskRunnerID.Valid || taskRunnerID.Int64 != runnerID {
log.Printf("Task %d does not belong to runner %d", taskUpdate.TaskID, runnerID)
return
}
status := types.TaskStatusCompleted
if !taskUpdate.Success {
status = types.TaskStatusFailed
}
// Get job ID first for atomic update
var jobID int64
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
`SELECT job_id FROM tasks WHERE id = ?`,
taskUpdate.TaskID,
).Scan(&jobID)
})
if err != nil {
log.Printf("Failed to get job ID for task %d: %v", taskUpdate.TaskID, err)
return
}
// Use a transaction so the task's status, output path, completion time, and error are updated atomically
now := time.Now()
err = s.db.WithTx(func(tx *sql.Tx) error {
// Update columns individually
_, err := tx.Exec(`UPDATE tasks SET status = ? WHERE id = ?`, status, taskUpdate.TaskID)
if err != nil {
log.Printf("Failed to update task status: %v", err)
return err
}
if taskUpdate.OutputPath != "" {
_, err = tx.Exec(`UPDATE tasks SET output_path = ? WHERE id = ?`, taskUpdate.OutputPath, taskUpdate.TaskID)
if err != nil {
log.Printf("Failed to update task output_path: %v", err)
return err
}
}
_, err = tx.Exec(`UPDATE tasks SET completed_at = ? WHERE id = ?`, now, taskUpdate.TaskID)
if err != nil {
log.Printf("Failed to update task completed_at: %v", err)
return err
}
if taskUpdate.Error != "" {
_, err = tx.Exec(`UPDATE tasks SET error_message = ? WHERE id = ?`, taskUpdate.Error, taskUpdate.TaskID)
if err != nil {
log.Printf("Failed to update task error_message: %v", err)
return err
}
}
return nil // Commit on nil return
})
if err != nil {
log.Printf("Failed to update task %d: %v", taskUpdate.TaskID, err)
return
}
// Broadcast task update
s.broadcastTaskUpdate(jobID, taskUpdate.TaskID, "task_update", map[string]interface{}{
"status": status,
"output_path": taskUpdate.OutputPath,
"completed_at": now,
"error": taskUpdate.Error,
})
// Update job status and progress (this will query tasks and update job accordingly)
s.updateJobStatusFromTasks(jobID)
}
// parseBlenderFrame extracts the current frame number from Blender log messages
// Looks for patterns like "Fra:2470" in log messages
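// For example, parseBlenderFrame("Fra:2470 Mem:512M") returns (2470, true), while
// parseBlenderFrame("Saved: output.png") returns (0, false).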
func parseBlenderFrame(logMessage string) (int, bool) {
// Look for "Fra:" followed by digits
// Pattern: "Fra:2470" or "Fra: 2470" or similar variations
fraIndex := strings.Index(logMessage, "Fra:")
if fraIndex == -1 {
return 0, false
}
// Find the number after "Fra:"
start := fraIndex + 4 // Skip "Fra:"
// Skip whitespace
for start < len(logMessage) && (logMessage[start] == ' ' || logMessage[start] == '\t') {
start++
}
// Extract digits
end := start
for end < len(logMessage) && logMessage[end] >= '0' && logMessage[end] <= '9' {
end++
}
if end > start {
frame, err := strconv.Atoi(logMessage[start:end])
if err == nil {
return frame, true
}
}
return 0, false
}
// getCurrentFrameFromLogs gets the highest frame number found in logs for a job's render tasks
func (s *Server) getCurrentFrameFromLogs(jobID int64) (int, bool) {
// Get all render tasks for this job
var rows *sql.Rows
err := s.db.With(func(conn *sql.DB) error {
var err error
rows, err = conn.Query(
`SELECT id FROM tasks WHERE job_id = ? AND task_type = ? AND status = ?`,
jobID, types.TaskTypeRender, types.TaskStatusRunning,
)
return err
})
if err != nil {
return 0, false
}
defer rows.Close()
maxFrame := 0
found := false
for rows.Next() {
var taskID int64
if err := rows.Scan(&taskID); err != nil {
log.Printf("Failed to scan task ID in getCurrentFrameFromLogs: %v", err)
continue
}
// Get the most recent log entries for this task (last 100 to avoid scanning all logs)
var logRows *sql.Rows
err := s.db.With(func(conn *sql.DB) error {
var err error
logRows, err = conn.Query(
`SELECT message FROM task_logs
WHERE task_id = ? AND message LIKE '%Fra:%'
ORDER BY id DESC LIMIT 100`,
taskID,
)
return err
})
if err != nil {
continue
}
for logRows.Next() {
var message string
if err := logRows.Scan(&message); err != nil {
continue
}
if frame, ok := parseBlenderFrame(message); ok {
if frame > maxFrame {
maxFrame = frame
found = true
}
}
}
logRows.Close()
}
return maxFrame, found
}
// resetFailedTasksAndRedistribute resets all failed tasks for a job to pending and redistributes them
func (s *Server) resetFailedTasksAndRedistribute(jobID int64) error {
// Reset all failed tasks to pending and clear their retry_count
err := s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`UPDATE tasks SET status = ?, retry_count = 0, runner_id = NULL, started_at = NULL, completed_at = NULL, error_message = NULL
WHERE job_id = ? AND status = ?`,
types.TaskStatusPending, jobID, types.TaskStatusFailed,
)
if err != nil {
return fmt.Errorf("failed to reset failed tasks: %v", err)
}
// Increment job retry_count
_, err = conn.Exec(
`UPDATE jobs SET retry_count = retry_count + 1 WHERE id = ?`,
jobID,
)
if err != nil {
return fmt.Errorf("failed to increment job retry_count: %v", err)
}
return nil
})
if err != nil {
return err
}
log.Printf("Reset failed tasks for job %d and incremented retry_count", jobID)
// Trigger task distribution to redistribute the reset tasks
s.triggerTaskDistribution()
return nil
}
// cancelActiveTasksForJob cancels all active (pending or running) tasks for a job
func (s *Server) cancelActiveTasksForJob(jobID int64) error {
// Tasks don't have a cancelled status - mark them as failed instead
err := s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`UPDATE tasks SET status = ?, error_message = ? WHERE job_id = ? AND status IN (?, ?)`,
types.TaskStatusFailed, "Job cancelled", jobID, types.TaskStatusPending, types.TaskStatusRunning,
)
if err != nil {
return fmt.Errorf("failed to cancel active tasks: %v", err)
}
return nil
})
if err != nil {
return err
}
log.Printf("Cancelled all active tasks for job %d", jobID)
return nil
}
// updateJobStatusFromTasks updates job status and progress based on task states
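// Progress is frame-based (parsed from "Fra:" log lines) for single-runner render jobs
// and task-count based otherwise. When no tasks are pending or running: failed tasks
// trigger a retry (reset to pending) while retries remain, otherwise the job is marked
// failed; if everything completed, the job is marked completed and, for EXR_264_MP4 or
// EXR_AV1_MP4 output, a video generation task is created. Otherwise the job stays
// pending or running depending on whether any task is currently running.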
func (s *Server) updateJobStatusFromTasks(jobID int64) {
now := time.Now()
// Get job info to check if it's a render job without parallel runners
var jobType string
var frameStart, frameEnd sql.NullInt64
var allowParallelRunners sql.NullBool
err := s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
`SELECT job_type, frame_start, frame_end, allow_parallel_runners FROM jobs WHERE id = ?`,
jobID,
).Scan(&jobType, &frameStart, &frameEnd, &allowParallelRunners)
})
if err != nil {
log.Printf("Failed to get job info for job %d: %v", jobID, err)
return
}
// Check if we should use frame-based progress (render job, single runner)
useFrameProgress := jobType == string(types.JobTypeRender) &&
allowParallelRunners.Valid && !allowParallelRunners.Bool &&
frameStart.Valid && frameEnd.Valid
// Get current job status to detect changes
var currentStatus string
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(`SELECT status FROM jobs WHERE id = ?`, jobID).Scan(&currentStatus)
})
if err != nil {
log.Printf("Failed to get current job status for job %d: %v", jobID, err)
return
}
// Count total tasks and completed tasks
var totalTasks, completedTasks int
err = s.db.With(func(conn *sql.DB) error {
err := conn.QueryRow(
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status IN (?, ?, ?, ?)`,
jobID, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted, types.TaskStatusFailed,
).Scan(&totalTasks)
if err != nil {
return err
}
return conn.QueryRow(
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
jobID, types.TaskStatusCompleted,
).Scan(&completedTasks)
})
if err != nil {
log.Printf("Failed to count completed tasks for job %d: %v", jobID, err)
return
}
// Calculate progress
var progress float64
if totalTasks == 0 {
// All tasks cancelled or no tasks, set progress to 0
progress = 0.0
} else if useFrameProgress {
// For single-runner render jobs, use frame-based progress from logs
currentFrame, frameFound := s.getCurrentFrameFromLogs(jobID)
frameStartVal := int(frameStart.Int64)
frameEndVal := int(frameEnd.Int64)
totalFrames := frameEndVal - frameStartVal + 1
// Count non-render tasks (like video generation) separately
var nonRenderTasks, nonRenderCompleted int
s.db.With(func(conn *sql.DB) error {
conn.QueryRow(
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND task_type != ? AND status IN (?, ?, ?, ?)`,
jobID, types.TaskTypeRender, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted, types.TaskStatusFailed,
).Scan(&nonRenderTasks)
conn.QueryRow(
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND task_type != ? AND status = ?`,
jobID, types.TaskTypeRender, types.TaskStatusCompleted,
).Scan(&nonRenderCompleted)
return nil
})
// Calculate render task progress from frames
var renderProgress float64
if frameFound && totalFrames > 0 {
// Calculate how many frames have been rendered (current - start + 1)
// But cap at frame_end to handle cases where logs show frames beyond end
renderedFrames := currentFrame - frameStartVal + 1
if currentFrame > frameEndVal {
renderedFrames = totalFrames
} else if renderedFrames < 0 {
renderedFrames = 0
}
if renderedFrames > totalFrames {
renderedFrames = totalFrames
}
renderProgress = float64(renderedFrames) / float64(totalFrames) * 100.0
} else {
// Fall back to task-based progress for render tasks
var renderTasks, renderCompleted int
s.db.With(func(conn *sql.DB) error {
conn.QueryRow(
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND task_type = ? AND status IN (?, ?, ?, ?)`,
jobID, types.TaskTypeRender, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted, types.TaskStatusFailed,
).Scan(&renderTasks)
conn.QueryRow(
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND task_type = ? AND status = ?`,
jobID, types.TaskTypeRender, types.TaskStatusCompleted,
).Scan(&renderCompleted)
return nil
})
if renderTasks > 0 {
renderProgress = float64(renderCompleted) / float64(renderTasks) * 100.0
}
}
// Combine render progress with non-render task progress
// Weight: render tasks contribute 90%, other tasks contribute 10% (adjust as needed)
var nonRenderProgress float64
if nonRenderTasks > 0 {
nonRenderProgress = float64(nonRenderCompleted) / float64(nonRenderTasks) * 100.0
}
// Weighted average: render progress dominates. When there are no non-render
// tasks, use render progress directly so the job can reach 100%.
if nonRenderTasks > 0 {
renderWeight := 0.9
nonRenderWeight := 0.1
progress = renderProgress*renderWeight + nonRenderProgress*nonRenderWeight
} else {
progress = renderProgress
}
} else {
// Standard task-based progress
progress = float64(completedTasks) / float64(totalTasks) * 100.0
}
var jobStatus string
var outputFormat sql.NullString
s.db.With(func(conn *sql.DB) error {
conn.QueryRow(`SELECT output_format FROM jobs WHERE id = ?`, jobID).Scan(&outputFormat)
return nil
})
outputFormatStr := ""
if outputFormat.Valid {
outputFormatStr = outputFormat.String
}
// Check if all non-cancelled tasks are completed
var pendingOrRunningTasks int
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
`SELECT COUNT(*) FROM tasks
WHERE job_id = ? AND status IN (?, ?)`,
jobID, types.TaskStatusPending, types.TaskStatusRunning,
).Scan(&pendingOrRunningTasks)
})
if err != nil {
log.Printf("Failed to count pending/running tasks for job %d: %v", jobID, err)
return
}
if pendingOrRunningTasks == 0 && totalTasks > 0 {
// All tasks are either completed or failed/cancelled
// Check if any tasks failed
var failedTasks int
s.db.With(func(conn *sql.DB) error {
conn.QueryRow(
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
jobID, types.TaskStatusFailed,
).Scan(&failedTasks)
return nil
})
if failedTasks > 0 {
// Some tasks failed - check if job has retries left
var retryCount, maxRetries int
err := s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
`SELECT retry_count, max_retries FROM jobs WHERE id = ?`,
jobID,
).Scan(&retryCount, &maxRetries)
})
if err != nil {
log.Printf("Failed to get retry info for job %d: %v", jobID, err)
// Fall back to marking job as failed
jobStatus = string(types.JobStatusFailed)
} else if retryCount < maxRetries {
// Job has retries left - reset failed tasks and redistribute
if err := s.resetFailedTasksAndRedistribute(jobID); err != nil {
log.Printf("Failed to reset failed tasks for job %d: %v", jobID, err)
// If reset fails, mark job as failed
jobStatus = string(types.JobStatusFailed)
} else {
// Tasks reset successfully - job remains in running/pending state
// Don't update job status, just update progress
jobStatus = currentStatus // Keep current status
// Recalculate progress after reset (failed tasks are now pending again)
var newTotalTasks, newCompletedTasks int
s.db.With(func(conn *sql.DB) error {
conn.QueryRow(
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status IN (?, ?, ?, ?)`,
jobID, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted, types.TaskStatusFailed,
).Scan(&newTotalTasks)
conn.QueryRow(
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
jobID, types.TaskStatusCompleted,
).Scan(&newCompletedTasks)
return nil
})
if newTotalTasks > 0 {
progress = float64(newCompletedTasks) / float64(newTotalTasks) * 100.0
}
// Update progress only
err := s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`UPDATE jobs SET progress = ? WHERE id = ?`,
progress, jobID,
)
return err
})
if err != nil {
log.Printf("Failed to update job %d progress: %v", jobID, err)
} else {
// Broadcast job update via WebSocket
s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{
"status": jobStatus,
"progress": progress,
})
}
return // Exit early since we've handled the retry
}
} else {
// No retries left - mark job as failed and cancel active tasks
jobStatus = string(types.JobStatusFailed)
if err := s.cancelActiveTasksForJob(jobID); err != nil {
log.Printf("Failed to cancel active tasks for job %d: %v", jobID, err)
}
}
} else {
// All tasks completed successfully
jobStatus = string(types.JobStatusCompleted)
progress = 100.0 // Ensure progress is 100% when all tasks complete
}
// Update job status (if we didn't return early from retry logic)
if jobStatus != "" {
err := s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`UPDATE jobs SET status = ?, progress = ?, completed_at = ? WHERE id = ?`,
jobStatus, progress, now, jobID,
)
return err
})
if err != nil {
log.Printf("Failed to update job %d status to %s: %v", jobID, jobStatus, err)
} else {
// Only log if status actually changed
if currentStatus != jobStatus {
log.Printf("Updated job %d status from %s to %s (progress: %.1f%%, completed tasks: %d/%d)", jobID, currentStatus, jobStatus, progress, completedTasks, totalTasks)
}
// Broadcast job update via WebSocket
s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{
"status": jobStatus,
"progress": progress,
"completed_at": now,
})
}
}
if outputFormatStr == "EXR_264_MP4" || outputFormatStr == "EXR_AV1_MP4" {
// Check if a video generation task already exists for this job (any status)
var existingVideoTask int
s.db.With(func(conn *sql.DB) error {
conn.QueryRow(
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND task_type = ?`,
jobID, types.TaskTypeVideoGeneration,
).Scan(&existingVideoTask)
return nil
})
if existingVideoTask == 0 {
// Create a video generation task instead of calling generateMP4Video directly
// This prevents race conditions when multiple runners complete frames simultaneously
videoTaskTimeout := 86400 // 24 hours for video generation
var videoTaskID int64
err := s.db.With(func(conn *sql.DB) error {
result, err := conn.Exec(
`INSERT INTO tasks (job_id, frame_start, frame_end, task_type, status, timeout_seconds, max_retries)
VALUES (?, ?, ?, ?, ?, ?, ?)`,
jobID, 0, 0, types.TaskTypeVideoGeneration, types.TaskStatusPending, videoTaskTimeout, 1,
)
if err != nil {
return err
}
videoTaskID, err = result.LastInsertId()
return err
})
if err != nil {
log.Printf("Failed to create video generation task for job %d: %v", jobID, err)
} else {
// Broadcast that a new task was added
if s.verboseWSLogging {
log.Printf("Broadcasting task_added for job %d: video generation task %d", jobID, videoTaskID)
}
s.broadcastTaskUpdate(jobID, videoTaskID, "task_added", map[string]interface{}{
"task_id": videoTaskID,
"task_type": types.TaskTypeVideoGeneration,
})
// Update job status to ensure it's marked as running (has pending video task)
s.updateJobStatusFromTasks(jobID)
// Try to distribute the task immediately
s.triggerTaskDistribution()
}
} else {
log.Printf("Skipping video generation task creation for job %d (video task already exists)", jobID)
}
}
} else {
// Job has pending or running tasks - determine if it's running or still pending
var runningTasks int
s.db.With(func(conn *sql.DB) error {
conn.QueryRow(
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
jobID, types.TaskStatusRunning,
).Scan(&runningTasks)
return nil
})
if runningTasks > 0 {
// Has running tasks - job is running
jobStatus = string(types.JobStatusRunning)
var startedAt sql.NullTime
s.db.With(func(conn *sql.DB) error {
conn.QueryRow(`SELECT started_at FROM jobs WHERE id = ?`, jobID).Scan(&startedAt)
if !startedAt.Valid {
conn.Exec(`UPDATE jobs SET started_at = ? WHERE id = ?`, now, jobID)
}
return nil
})
} else {
// All tasks are pending - job is pending
jobStatus = string(types.JobStatusPending)
}
err := s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`UPDATE jobs SET status = ?, progress = ? WHERE id = ?`,
jobStatus, progress, jobID,
)
return err
})
if err != nil {
log.Printf("Failed to update job %d status to %s: %v", jobID, jobStatus, err)
} else {
// Only log if status actually changed
if currentStatus != jobStatus {
log.Printf("Updated job %d status from %s to %s (progress: %.1f%%, completed: %d/%d, pending: %d, running: %d)", jobID, currentStatus, jobStatus, progress, completedTasks, totalTasks, pendingOrRunningTasks-runningTasks, runningTasks)
}
// Broadcast job update during execution (not just on completion)
s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{
"status": jobStatus,
"progress": progress,
})
}
}
}
// broadcastLogToFrontend broadcasts log to connected frontend clients
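// It resolves the owning job and user, re-reads the stored task_logs row so the
// frontend receives the database ID and timestamp, and publishes the entry on the
// "logs:{jobId}:{taskId}" channel (plus the legacy per-task WebSocket, if connected).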
func (s *Server) broadcastLogToFrontend(taskID int64, logEntry WSLogEntry) {
// Get job_id, user_id, and task status from task
var jobID, userID int64
var taskStatus string
var taskRunnerID sql.NullInt64
var taskStartedAt sql.NullTime
err := s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
`SELECT t.job_id, j.user_id, t.status, t.runner_id, t.started_at
FROM tasks t
JOIN jobs j ON t.job_id = j.id
WHERE t.id = ?`,
taskID,
).Scan(&jobID, &userID, &taskStatus, &taskRunnerID, &taskStartedAt)
})
if err != nil {
return
}
// Get full log entry from database for consistency
// Use a more reliable query that gets the most recent log with matching message
// This avoids race conditions with concurrent inserts
var taskLog types.TaskLog
var runnerID sql.NullInt64
var stepName sql.NullString
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
`SELECT id, task_id, runner_id, log_level, message, step_name, created_at
FROM task_logs WHERE task_id = ? AND message = ? ORDER BY id DESC LIMIT 1`,
taskID, logEntry.Message,
).Scan(&taskLog.ID, &taskLog.TaskID, &runnerID, &taskLog.LogLevel, &taskLog.Message, &stepName, &taskLog.CreatedAt)
})
if err != nil {
return
}
if runnerID.Valid {
taskLog.RunnerID = &runnerID.Int64
}
if stepName.Valid {
taskLog.StepName = stepName.String
}
msg := map[string]interface{}{
"type": "log",
"task_id": taskID,
"job_id": jobID,
"data": taskLog,
"timestamp": time.Now().Unix(),
}
// Broadcast to client WebSocket if subscribed to logs:{jobId}:{taskId}
channel := fmt.Sprintf("logs:%d:%d", jobID, taskID)
if s.verboseWSLogging {
runnerIDStr := "none"
if taskRunnerID.Valid {
runnerIDStr = fmt.Sprintf("%d", taskRunnerID.Int64)
}
log.Printf("broadcastLogToFrontend: Broadcasting log for task %d (job %d, user %d) on channel %s, log_id=%d, task_status=%s, runner_id=%s", taskID, jobID, userID, channel, taskLog.ID, taskStatus, runnerIDStr)
}
s.broadcastToClient(userID, channel, msg)
// If task status is pending but logs are coming in, log a warning
// This indicates the initial assignment broadcast may have been missed or the database update failed
if taskStatus == string(types.TaskStatusPending) {
log.Printf("broadcastLogToFrontend: ERROR - Task %d has logs but status is 'pending'. This indicates the initial task assignment failed or the task_update broadcast was missed.", taskID)
}
// Also broadcast to old WebSocket connection (for backwards compatibility during migration)
key := fmt.Sprintf("%d:%d", jobID, taskID)
s.frontendConnsMu.RLock()
conn, exists := s.frontendConns[key]
s.frontendConnsMu.RUnlock()
if exists && conn != nil {
// Serialize writes to prevent concurrent write panics
s.frontendConnsWriteMuMu.RLock()
writeMu, hasMu := s.frontendConnsWriteMu[key]
s.frontendConnsWriteMuMu.RUnlock()
if hasMu && writeMu != nil {
writeMu.Lock()
conn.WriteJSON(msg)
writeMu.Unlock()
} else {
// Fallback if mutex doesn't exist yet (shouldn't happen, but be safe)
conn.WriteJSON(msg)
}
}
}
// triggerTaskDistribution triggers task distribution in a serialized manner
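// Distribution runs asynchronously and is serialized by taskDistMu; if a distribution
// pass is already in progress the trigger is skipped rather than queued.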
func (s *Server) triggerTaskDistribution() {
go func() {
// Try to acquire lock - if already running, log and skip
if !s.taskDistMu.TryLock() {
// Log when distribution is skipped to help with debugging
log.Printf("Task distribution already in progress, skipping trigger")
return // Distribution already in progress
}
defer s.taskDistMu.Unlock()
s.distributeTasksToRunners()
}()
}
// distributeTasksToRunners pushes available tasks to connected runners
// This function should only be called while holding taskDistMu lock
func (s *Server) distributeTasksToRunners() {
// Quick check: if there are no pending tasks, skip the expensive query
var pendingCount int
err := s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
`SELECT COUNT(*) FROM tasks t
JOIN jobs j ON t.job_id = j.id
WHERE t.status = ? AND j.status != ?`,
types.TaskStatusPending, types.JobStatusCancelled,
).Scan(&pendingCount)
})
if err != nil {
log.Printf("Failed to check pending tasks count: %v", err)
return
}
if pendingCount == 0 {
// No pending tasks, nothing to distribute
return
}
// Get all pending tasks
var rows *sql.Rows
err = s.db.With(func(conn *sql.DB) error {
var err error
rows, err = conn.Query(
`SELECT t.id, t.job_id, t.frame_start, t.frame_end, t.task_type, j.allow_parallel_runners, j.status as job_status, j.name as job_name, j.user_id
FROM tasks t
JOIN jobs j ON t.job_id = j.id
WHERE t.status = ? AND j.status != ?
ORDER BY t.created_at ASC`,
types.TaskStatusPending, types.JobStatusCancelled,
)
return err
})
if err != nil {
log.Printf("Failed to query pending tasks: %v", err)
return
}
defer rows.Close()
var pendingTasks []struct {
TaskID int64
JobID int64
FrameStart int
FrameEnd int
TaskType string
AllowParallelRunners bool
JobName string
JobStatus string
JobUserID int64
}
for rows.Next() {
var t struct {
TaskID int64
JobID int64
FrameStart int
FrameEnd int
TaskType string
AllowParallelRunners bool
JobName string
JobStatus string
JobUserID int64
}
var allowParallel sql.NullBool
err := rows.Scan(&t.TaskID, &t.JobID, &t.FrameStart, &t.FrameEnd, &t.TaskType, &allowParallel, &t.JobStatus, &t.JobName, &t.JobUserID)
if err != nil {
log.Printf("Failed to scan pending task: %v", err)
continue
}
// Default to true if NULL (for metadata jobs or legacy data)
if allowParallel.Valid {
t.AllowParallelRunners = allowParallel.Bool
} else {
t.AllowParallelRunners = true
}
pendingTasks = append(pendingTasks, t)
}
if len(pendingTasks) == 0 {
log.Printf("No pending tasks found for distribution")
return
}
log.Printf("Found %d pending tasks for distribution", len(pendingTasks))
// Get connected runners (WebSocket connection is source of truth)
// Use a read lock to safely read the map
s.runnerConnsMu.RLock()
connectedRunners := make([]int64, 0, len(s.runnerConns))
for runnerID := range s.runnerConns {
// Verify connection is still valid (not closed)
conn := s.runnerConns[runnerID]
if conn != nil {
connectedRunners = append(connectedRunners, runnerID)
}
}
s.runnerConnsMu.RUnlock()
// Get runner priorities, capabilities, and API key scopes for all connected runners
runnerPriorities := make(map[int64]int)
runnerCapabilities := make(map[int64]map[string]interface{})
runnerScopes := make(map[int64]string)
for _, runnerID := range connectedRunners {
var priority int
var capabilitiesJSON sql.NullString
var scope string
err := s.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT priority, capabilities, api_key_scope FROM runners WHERE id = ?", runnerID).Scan(&priority, &capabilitiesJSON, &scope)
})
if err != nil {
// Runner row missing or query failed - fall back to defaults:
// priority 100, empty capabilities, and an empty scope (treated as manager-scoped below)
priority = 100
capabilitiesJSON = sql.NullString{String: "{}", Valid: true}
}
runnerPriorities[runnerID] = priority
runnerScopes[runnerID] = scope
// Parse capabilities JSON (can contain both bools and numbers)
capabilitiesStr := "{}"
if capabilitiesJSON.Valid {
capabilitiesStr = capabilitiesJSON.String
}
var capabilities map[string]interface{}
if err := json.Unmarshal([]byte(capabilitiesStr), &capabilities); err != nil {
// If parsing fails, try old format (map[string]bool) for backward compatibility
var oldCapabilities map[string]bool
if err2 := json.Unmarshal([]byte(capabilitiesStr), &oldCapabilities); err2 == nil {
// Convert old format to new format
capabilities = make(map[string]interface{})
for k, v := range oldCapabilities {
capabilities[k] = v
}
} else {
// Both formats failed, assume no capabilities
capabilities = make(map[string]interface{})
}
}
runnerCapabilities[runnerID] = capabilities
}
// Update database status for all connected runners (outside the lock to avoid holding it too long)
for _, runnerID := range connectedRunners {
// Ensure database status matches WebSocket connection
// Update status to online if it's not already
s.db.With(func(conn *sql.DB) error {
_, _ = conn.Exec(
`UPDATE runners SET status = ?, last_heartbeat = ? WHERE id = ? AND status != ?`,
types.RunnerStatusOnline, time.Now(), runnerID, types.RunnerStatusOnline,
)
return nil
})
}
if len(connectedRunners) == 0 {
log.Printf("No connected runners available for task distribution (checked WebSocket connections)")
// Log to task logs that no runners are available (only for metadata tasks)
for _, task := range pendingTasks {
if task.TaskType == string(types.TaskTypeMetadata) {
s.logTaskEvent(task.TaskID, nil, types.LogLevelWarn, "No connected runners available for task assignment", "")
}
}
return
}
// Log task types being distributed
taskTypes := make(map[string]int)
for _, task := range pendingTasks {
taskTypes[task.TaskType]++
}
log.Printf("Distributing %d pending tasks (%v) to %d connected runners: %v", len(pendingTasks), taskTypes, len(connectedRunners), connectedRunners)
// Distribute tasks to runners
// Sort tasks to prioritize metadata tasks; a stable sort preserves the created_at ordering within each type
sort.SliceStable(pendingTasks, func(i, j int) bool {
// Metadata tasks first
if pendingTasks[i].TaskType == string(types.TaskTypeMetadata) && pendingTasks[j].TaskType != string(types.TaskTypeMetadata) {
return true
}
if pendingTasks[i].TaskType != string(types.TaskTypeMetadata) && pendingTasks[j].TaskType == string(types.TaskTypeMetadata) {
return false
}
return false // Keep original order for same type
})
// Track how many tasks each runner has been assigned in this distribution cycle
runnerTaskCounts := make(map[int64]int)
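// For each pending task: find the best eligible runner (API key scope, capability, one-task-at-a-time and
// non-parallel-job rules, then priority/load/random tie-breaking), atomically claim the task in the database,
// and push it to the runner over WebSocket.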
for _, task := range pendingTasks {
// Determine required capability for this task
var requiredCapability string
switch task.TaskType {
case string(types.TaskTypeRender), string(types.TaskTypeMetadata):
requiredCapability = "blender"
case string(types.TaskTypeVideoGeneration):
requiredCapability = "ffmpeg"
default:
requiredCapability = "" // Unknown task type - no specific capability required, any runner qualifies
}
// Find available runner
var selectedRunnerID int64
var bestRunnerID int64
var bestPriority int = -1
var bestTaskCount int = -1
var bestRandom float64 = -1 // Random tie-breaker
// Try to find the best runner for this task
for _, runnerID := range connectedRunners {
// Check if runner's API key scope allows working on this job.
// Manager-scoped runners can work on any job; user-scoped runners can only
// work on jobs owned by the user who created their API key.
runnerScope := runnerScopes[runnerID]
if runnerScope == "user" && task.JobUserID != 0 {
// Get the API key this runner registered with
var apiKeyID sql.NullInt64
err := s.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT api_key_id FROM runners WHERE id = ?", runnerID).Scan(&apiKeyID)
})
if err != nil || !apiKeyID.Valid {
continue // Skip this runner if we can't determine its API key
}
// Get the user who created the API key
var apiKeyCreatedBy int64
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT created_by FROM runner_api_keys WHERE id = ?", apiKeyID.Int64).Scan(&apiKeyCreatedBy)
})
if err != nil {
continue // Skip this runner if we can't determine API key ownership
}
// Only allow if the job owner matches the API key creator
if apiKeyCreatedBy != task.JobUserID {
continue // This user-scoped runner cannot work on this job
}
}
// Check if runner has required capability
capabilities := runnerCapabilities[runnerID]
hasRequired := false
if reqVal, ok := capabilities[requiredCapability]; ok {
if reqBool, ok := reqVal.(bool); ok {
hasRequired = reqBool
} else if reqFloat, ok := reqVal.(float64); ok {
hasRequired = reqFloat > 0
} else if reqInt, ok := reqVal.(int); ok {
hasRequired = reqInt > 0
}
}
if !hasRequired && requiredCapability != "" {
continue // Runner doesn't have required capability
}
// Check if runner has ANY tasks (pending or running) - one task at a time only
// This prevents any runner from doing more than one task at a time
var activeTaskCount int
err := s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
`SELECT COUNT(*) FROM tasks WHERE runner_id = ? AND status IN (?, ?)`,
runnerID, types.TaskStatusPending, types.TaskStatusRunning,
).Scan(&activeTaskCount)
})
if err != nil {
log.Printf("Failed to check active tasks for runner %d: %v", runnerID, err)
continue
}
if activeTaskCount > 0 {
continue // Runner is busy with another task, cannot run any other tasks
}
// For non-parallel jobs, check if runner already has tasks from this job
if !task.AllowParallelRunners {
var jobTaskCount int
err := s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
`SELECT COUNT(*) FROM tasks
WHERE job_id = ? AND runner_id = ? AND status IN (?, ?)`,
task.JobID, runnerID, types.TaskStatusPending, types.TaskStatusRunning,
).Scan(&jobTaskCount)
})
if err != nil {
log.Printf("Failed to check job tasks for runner %d: %v", runnerID, err)
continue
}
if jobTaskCount > 0 {
continue // Another runner is working on this job
}
}
// Get runner priority and task count
priority := runnerPriorities[runnerID]
currentTaskCount := runnerTaskCounts[runnerID]
// Generate a small random value for absolute tie-breaking
randomValue := rand.Float64()
// Selection priority:
// 1. Priority (higher is better)
// 2. Task count (fewer is better)
// 3. Random value (absolute tie-breaker)
isBetter := false
if bestRunnerID == 0 {
isBetter = true
} else if priority > bestPriority {
// Higher priority
isBetter = true
} else if priority == bestPriority {
if currentTaskCount < bestTaskCount {
// Same priority, but fewer tasks assigned in this cycle
isBetter = true
} else if currentTaskCount == bestTaskCount {
// Absolute tie - use random value as tie-breaker
if randomValue > bestRandom {
isBetter = true
}
}
}
if isBetter {
bestRunnerID = runnerID
bestPriority = priority
bestTaskCount = currentTaskCount
bestRandom = randomValue
}
}
// Use the best runner we found (prioritized by priority, then load balanced)
if bestRunnerID != 0 {
selectedRunnerID = bestRunnerID
}
if selectedRunnerID == 0 {
if task.TaskType == string(types.TaskTypeMetadata) {
log.Printf("Warning: No available runner for metadata task %d (job %d)", task.TaskID, task.JobID)
// Log that no runner is available
s.logTaskEvent(task.TaskID, nil, types.LogLevelWarn, "No available runner for task assignment", "")
}
continue // No available runner - task stays in queue
}
// Track assignment for load balancing
runnerTaskCounts[selectedRunnerID]++
// Atomically assign task to runner using UPDATE with WHERE runner_id IS NULL
// This prevents race conditions when multiple goroutines try to assign the same task
// Use a transaction to ensure atomicity
now := time.Now()
var rowsAffected int64
var verifyStatus string
var verifyRunnerID sql.NullInt64
var verifyStartedAt sql.NullTime
err := s.db.WithTx(func(tx *sql.Tx) error {
result, err := tx.Exec(
`UPDATE tasks SET runner_id = ?, status = ?, started_at = ?
WHERE id = ? AND runner_id IS NULL AND status = ?`,
selectedRunnerID, types.TaskStatusRunning, now, task.TaskID, types.TaskStatusPending,
)
if err != nil {
return err
}
// Check if the update actually affected a row (task was successfully assigned)
rowsAffected, err = result.RowsAffected()
if err != nil {
return err
}
if rowsAffected == 0 {
return sql.ErrNoRows // Task was already assigned
}
// Verify the update within the transaction before committing
// This ensures we catch any issues before the transaction is committed
err = tx.QueryRow(
`SELECT status, runner_id, started_at FROM tasks WHERE id = ?`,
task.TaskID,
).Scan(&verifyStatus, &verifyRunnerID, &verifyStartedAt)
if err != nil {
return err
}
if verifyStatus != string(types.TaskStatusRunning) {
return fmt.Errorf("task status is %s after assignment, expected running", verifyStatus)
}
if !verifyRunnerID.Valid || verifyRunnerID.Int64 != selectedRunnerID {
return fmt.Errorf("task runner_id is %v after assignment, expected %d", verifyRunnerID, selectedRunnerID)
}
return nil // Commit on nil return
})
if err == sql.ErrNoRows {
// Task was already assigned by another goroutine, skip
continue
}
if err != nil {
log.Printf("Failed to atomically assign task %d: %v", task.TaskID, err)
continue
}
log.Printf("Verified and committed task %d assignment: status=%s, runner_id=%d, started_at=%v", task.TaskID, verifyStatus, verifyRunnerID.Int64, verifyStartedAt)
// Broadcast task assignment - include all fields to ensure frontend has complete info
updateData := map[string]interface{}{
"status": types.TaskStatusRunning,
"runner_id": selectedRunnerID,
"started_at": verifyStartedAt.Time,
}
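// Fall back to the assignment timestamp if the verification query returned a NULL started_at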
if !verifyStartedAt.Valid {
updateData["started_at"] = now
}
if s.verboseWSLogging {
log.Printf("Broadcasting task_update for task %d (job %d, user %d): status=%s, runner_id=%d, started_at=%v", task.TaskID, task.JobID, task.JobUserID, types.TaskStatusRunning, selectedRunnerID, now)
}
s.broadcastTaskUpdate(task.JobID, task.TaskID, "task_update", updateData)
// Task was successfully assigned in database, now send via WebSocket
log.Printf("Assigned task %d (type: %s, job: %d) to runner %d", task.TaskID, task.TaskType, task.JobID, selectedRunnerID)
// Log runner assignment to task logs
s.logTaskEvent(task.TaskID, nil, types.LogLevelInfo, fmt.Sprintf("Task assigned to runner %d", selectedRunnerID), "")
// Attempt to send task to runner via WebSocket
if err := s.assignTaskToRunner(selectedRunnerID, task.TaskID); err != nil {
log.Printf("Failed to send task %d to runner %d: %v", task.TaskID, selectedRunnerID, err)
// Log assignment failure
s.logTaskEvent(task.TaskID, nil, types.LogLevelError, fmt.Sprintf("Failed to send task to runner %d: %v", selectedRunnerID, err), "")
// Rollback the assignment if WebSocket send fails with retry mechanism
rollbackSuccess := false
for retry := 0; retry < 3; retry++ {
rollbackErr := s.db.WithTx(func(tx *sql.Tx) error {
_, err := tx.Exec(
`UPDATE tasks SET runner_id = NULL, status = ?, started_at = NULL
WHERE id = ? AND runner_id = ?`,
types.TaskStatusPending, task.TaskID, selectedRunnerID,
)
return err
})
if rollbackErr != nil {
log.Printf("Failed to rollback task %d assignment (attempt %d/3): %v", task.TaskID, retry+1, rollbackErr)
if retry < 2 {
time.Sleep(time.Duration(retry+1) * 100 * time.Millisecond) // Linear backoff: 100ms, then 200ms
continue
}
// Final attempt failed
log.Printf("CRITICAL: Failed to rollback task %d after 3 attempts - task may be in inconsistent state", task.TaskID)
s.triggerTaskDistribution()
break
}
// Rollback succeeded
rollbackSuccess = true
s.logTaskEvent(task.TaskID, nil, types.LogLevelWarn, fmt.Sprintf("Task assignment rolled back - runner %d connection failed", selectedRunnerID), "")
s.updateJobStatusFromTasks(task.JobID)
s.triggerTaskDistribution()
break
}
if !rollbackSuccess {
// Schedule background cleanup for inconsistent state
go func() {
time.Sleep(5 * time.Second)
// Retry rollback one more time in background
err := s.db.WithTx(func(tx *sql.Tx) error {
_, err := tx.Exec(
`UPDATE tasks SET runner_id = NULL, status = ?, started_at = NULL
WHERE id = ? AND runner_id = ? AND status = ?`,
types.TaskStatusPending, task.TaskID, selectedRunnerID, types.TaskStatusRunning,
)
return err
})
if err == nil {
log.Printf("Background cleanup: Successfully rolled back task %d", task.TaskID)
s.updateJobStatusFromTasks(task.JobID)
s.triggerTaskDistribution()
}
}()
}
} else {
// WebSocket send succeeded, update job status
s.updateJobStatusFromTasks(task.JobID)
}
}
}
// assignTaskToRunner sends a task to a runner via WebSocket
func (s *Server) assignTaskToRunner(runnerID int64, taskID int64) error {
// Hold read lock during entire operation to prevent connection from being replaced
s.runnerConnsMu.RLock()
conn, exists := s.runnerConns[runnerID]
if !exists {
s.runnerConnsMu.RUnlock()
return fmt.Errorf("runner %d not connected", runnerID)
}
// Keep lock held to prevent connection replacement during operation
defer s.runnerConnsMu.RUnlock()
// Get task details
var task WSTaskAssignment
var jobName string
var outputFormat sql.NullString
var taskType string
err := s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
`SELECT t.job_id, t.frame_start, t.frame_end, t.task_type, j.name, j.output_format
FROM tasks t JOIN jobs j ON t.job_id = j.id WHERE t.id = ?`,
taskID,
).Scan(&task.JobID, &task.FrameStart, &task.FrameEnd, &taskType, &jobName, &outputFormat)
})
if err != nil {
return err
}
task.TaskID = taskID
task.JobName = jobName
if outputFormat.Valid {
task.OutputFormat = outputFormat.String
log.Printf("Task %d assigned with output_format: '%s' (from job %d)", taskID, outputFormat.String, task.JobID)
} else {
log.Printf("Task %d assigned with no output_format (job %d)", taskID, task.JobID)
}
task.TaskType = taskType
// Get input files
var rows *sql.Rows
err = s.db.With(func(conn *sql.DB) error {
var err error
rows, err = conn.Query(
`SELECT file_path FROM job_files WHERE job_id = ? AND file_type = ?`,
task.JobID, types.JobFileTypeInput,
)
return err
})
if err == nil {
defer rows.Close()
for rows.Next() {
var filePath string
if err := rows.Scan(&filePath); err == nil {
task.InputFiles = append(task.InputFiles, filePath)
} else {
log.Printf("Failed to scan input file path for task %d: %v", taskID, err)
}
}
} else {
log.Printf("Warning: Failed to query input files for task %d (job %d): %v", taskID, task.JobID, err)
}
if len(task.InputFiles) == 0 {
errMsg := fmt.Sprintf("No input files found for task %d (job %d). Cannot assign task without input files.", taskID, task.JobID)
log.Printf("ERROR: %s", errMsg)
// Don't send the task - it will fail anyway
// Rollback the assignment
s.db.With(func(conn *sql.DB) error {
_, _ = conn.Exec(
`UPDATE tasks SET runner_id = NULL, status = ?, started_at = NULL
WHERE id = ?`,
types.TaskStatusPending, taskID,
)
return nil
})
s.logTaskEvent(taskID, nil, types.LogLevelError, errMsg, "")
return errors.New(errMsg)
}
// Note: Task is already assigned in database by the atomic update in distributeTasksToRunners
// We just need to verify it's still assigned to this runner
var assignedRunnerID sql.NullInt64
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", taskID).Scan(&assignedRunnerID)
})
if err != nil {
return fmt.Errorf("task not found: %w", err)
}
if !assignedRunnerID.Valid || assignedRunnerID.Int64 != runnerID {
return fmt.Errorf("task %d is not assigned to runner %d", taskID, runnerID)
}
// Send task via WebSocket with write mutex protection
msg := WSMessage{
Type: "task_assignment",
Timestamp: time.Now().Unix(),
}
data, err := json.Marshal(task)
if err != nil {
return fmt.Errorf("failed to marshal task assignment: %w", err)
}
msg.Data = data
// Get write mutex for this connection
s.runnerConnsWriteMuMu.RLock()
writeMu, hasMu := s.runnerConnsWriteMu[runnerID]
s.runnerConnsWriteMuMu.RUnlock()
if !hasMu || writeMu == nil {
return fmt.Errorf("runner %d write mutex not found", runnerID)
}
// Connection is still valid (we're holding the read lock)
// Write to connection with mutex protection
writeMu.Lock()
err = conn.WriteJSON(msg)
writeMu.Unlock()
return err
}
// redistributeRunnerTasks resets tasks assigned to a disconnected/dead runner and redistributes them
func (s *Server) redistributeRunnerTasks(runnerID int64) {
log.Printf("Starting task redistribution for disconnected runner %d", runnerID)
// Get tasks assigned to this runner that are still running
var taskRows *sql.Rows
err := s.db.With(func(conn *sql.DB) error {
var err error
taskRows, err = conn.Query(
`SELECT id, retry_count, max_retries, job_id FROM tasks
WHERE runner_id = ? AND status = ?`,
runnerID, types.TaskStatusRunning,
)
return err
})
if err != nil {
log.Printf("Failed to query tasks for runner %d: %v", runnerID, err)
return
}
defer taskRows.Close()
var tasksToReset []struct {
ID int64
RetryCount int
MaxRetries int
JobID int64
}
for taskRows.Next() {
var t struct {
ID int64
RetryCount int
MaxRetries int
JobID int64
}
if err := taskRows.Scan(&t.ID, &t.RetryCount, &t.MaxRetries, &t.JobID); err != nil {
log.Printf("Failed to scan task for runner %d: %v", runnerID, err)
continue
}
tasksToReset = append(tasksToReset, t)
}
if len(tasksToReset) == 0 {
log.Printf("No running tasks found for runner %d to redistribute", runnerID)
return
}
log.Printf("Redistributing %d running tasks from disconnected runner %d", len(tasksToReset), runnerID)
// Reset or fail tasks: tasks that have exhausted max_retries are marked failed; the rest go back to pending with retry_count incremented
resetCount := 0
failedCount := 0
for _, task := range tasksToReset {
if task.RetryCount >= task.MaxRetries {
// Mark as failed - single atomic UPDATE so the task cannot be left in a partially-updated state
err = s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL, completed_at = ?
WHERE id = ? AND runner_id = ?`,
types.TaskStatusFailed, "Runner disconnected, max retries exceeded", time.Now(), task.ID, runnerID,
)
return err
})
if err != nil {
log.Printf("Failed to mark task %d as failed: %v", task.ID, err)
} else {
failedCount++
// Log task failure
s.logTaskEvent(task.ID, &runnerID, types.LogLevelError,
fmt.Sprintf("Task failed - runner %d disconnected, max retries (%d) exceeded", runnerID, task.MaxRetries), "")
}
} else {
// Reset to pending so it can be redistributed
err = s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL,
retry_count = retry_count + 1, started_at = NULL WHERE id = ? AND runner_id = ?`,
types.TaskStatusPending, task.ID, runnerID,
)
return err
})
if err != nil {
log.Printf("Failed to reset task %d: %v", task.ID, err)
} else {
resetCount++
// Log task reset for redistribution
s.logTaskEvent(task.ID, &runnerID, types.LogLevelWarn,
fmt.Sprintf("Runner %d disconnected, task reset for redistribution (retry %d/%d)", runnerID, task.RetryCount+1, task.MaxRetries), "")
}
}
}
log.Printf("Task redistribution complete for runner %d: %d tasks reset, %d tasks failed", runnerID, resetCount, failedCount)
// Update job statuses for affected jobs
jobIDs := make(map[int64]bool)
for _, task := range tasksToReset {
jobIDs[task.JobID] = true
}
for jobID := range jobIDs {
// Update job status based on remaining tasks
go s.updateJobStatusFromTasks(jobID)
}
// Immediately redistribute the reset tasks
if resetCount > 0 {
log.Printf("Triggering task distribution for %d reset tasks from runner %d", resetCount, runnerID)
s.triggerTaskDistribution()
}
}
// logTaskEvent logs an event to a task's log (manager-side logging)
func (s *Server) logTaskEvent(taskID int64, runnerID *int64, logLevel types.LogLevel, message, stepName string) {
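// A nil runnerID is stored as NULL in the runner_id column (manager-side events have no associated runner)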
var runnerIDValue interface{}
if runnerID != nil {
runnerIDValue = *runnerID
}
err := s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`INSERT INTO task_logs (task_id, runner_id, log_level, message, step_name, created_at)
VALUES (?, ?, ?, ?, ?, ?)`,
taskID, runnerIDValue, logLevel, message, stepName, time.Now(),
)
return err
})
if err != nil {
log.Printf("Failed to log task event for task %d: %v", taskID, err)
return
}
// Broadcast to frontend if there are connected clients
s.broadcastLogToFrontend(taskID, WSLogEntry{
TaskID: taskID,
LogLevel: string(logLevel),
Message: message,
StepName: stepName,
})
}
// cleanupOldOfflineRunners periodically deletes runners that have been offline for more than 1 month
func (s *Server) cleanupOldOfflineRunners() {
// Run cleanup every 24 hours
ticker := time.NewTicker(24 * time.Hour)
defer ticker.Stop()
// Run once immediately on startup
s.cleanupOldOfflineRunnersOnce()
for range ticker.C {
s.cleanupOldOfflineRunnersOnce()
}
}
// cleanupOldOfflineRunnersOnce finds and deletes runners that have been offline for more than 1 month
func (s *Server) cleanupOldOfflineRunnersOnce() {
defer func() {
if r := recover(); r != nil {
log.Printf("Panic in cleanupOldOfflineRunners: %v", r)
}
}()
// Find runners that:
// 1. Are offline
// 2. Haven't had a heartbeat in over 1 month
// 3. Are not currently connected via WebSocket
var rows *sql.Rows
err := s.db.With(func(conn *sql.DB) error {
var err error
rows, err = conn.Query(
`SELECT id, name FROM runners
WHERE status = ?
AND last_heartbeat < datetime('now', '-1 month')`,
types.RunnerStatusOffline,
)
return err
})
if err != nil {
log.Printf("Failed to query old offline runners: %v", err)
return
}
defer rows.Close()
type runnerInfo struct {
ID int64
Name string
}
var runnersToDelete []runnerInfo
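// Hold the connections read lock while scanning so the WebSocket connection map cannot change underneath us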
s.runnerConnsMu.RLock()
for rows.Next() {
var info runnerInfo
if err := rows.Scan(&info.ID, &info.Name); err == nil {
// Double-check runner is not connected via WebSocket
if _, connected := s.runnerConns[info.ID]; !connected {
runnersToDelete = append(runnersToDelete, info)
}
}
}
s.runnerConnsMu.RUnlock()
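// Close the result set now rather than waiting for the deferred Close, so the connection is free for the queries issued in the loop below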
rows.Close()
if len(runnersToDelete) == 0 {
return
}
log.Printf("Cleaning up %d old offline runners (offline for more than 1 month)", len(runnersToDelete))
// Delete each runner
for _, runner := range runnersToDelete {
// First, check if there are any tasks still assigned to this runner
// If so, reset them to pending before deleting the runner
var assignedTaskCount int
err := s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
`SELECT COUNT(*) FROM tasks WHERE runner_id = ? AND status IN (?, ?)`,
runner.ID, types.TaskStatusRunning, types.TaskStatusPending,
).Scan(&assignedTaskCount)
})
if err != nil {
log.Printf("Failed to check assigned tasks for runner %d: %v", runner.ID, err)
continue
}
if assignedTaskCount > 0 {
// Reset any tasks assigned to this runner
log.Printf("Resetting %d tasks assigned to runner %d before deletion", assignedTaskCount, runner.ID)
err = s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`UPDATE tasks SET runner_id = NULL, status = ? WHERE runner_id = ? AND status IN (?, ?)`,
types.TaskStatusPending, runner.ID, types.TaskStatusRunning, types.TaskStatusPending,
)
return err
})
if err != nil {
log.Printf("Failed to reset tasks for runner %d: %v", runner.ID, err)
continue
}
}
// Delete the runner
err = s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec("DELETE FROM runners WHERE id = ?", runner.ID)
return err
})
if err != nil {
log.Printf("Failed to delete runner %d (%s): %v", runner.ID, runner.Name, err)
continue
}
log.Printf("Deleted old offline runner: %d (%s)", runner.ID, runner.Name)
}
// Trigger task distribution if any tasks were reset
s.triggerTaskDistribution()
}