1989 lines
63 KiB
Go
1989 lines
63 KiB
Go
package api
|
|
|
|
import (
	"context"
	"crypto/subtle"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"math/rand"
	"net/http"
	"net/url"
	"path/filepath"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"jiggablend/pkg/types"

	"github.com/go-chi/chi/v5"
	"github.com/gorilla/websocket"
)
|
|
|
|
// contextKey is the string type used for keys of values this package stores
// in a request context (see runnerIDContextKey).
type contextKey string
|
|
|
|
// runnerIDContextKey is the context key under which runnerAuthMiddleware
// stores the authenticated runner's ID as an int64.
const runnerIDContextKey contextKey = "runner_id"
|
|
|
|
// runnerAuthMiddleware verifies runner requests using shared secret header
|
|
func (s *Server) runnerAuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
// Get runner ID from query string
|
|
runnerIDStr := r.URL.Query().Get("runner_id")
|
|
if runnerIDStr == "" {
|
|
s.respondError(w, http.StatusBadRequest, "runner_id required in query string")
|
|
return
|
|
}
|
|
|
|
var runnerID int64
|
|
_, err := fmt.Sscanf(runnerIDStr, "%d", &runnerID)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusBadRequest, "invalid runner_id")
|
|
return
|
|
}
|
|
|
|
// Get runner secret
|
|
runnerSecret, err := s.secrets.GetRunnerSecret(runnerID)
|
|
if err != nil {
|
|
log.Printf("Failed to get runner secret for runner %d: %v", runnerID, err)
|
|
s.respondError(w, http.StatusUnauthorized, "runner not found or not verified")
|
|
return
|
|
}
|
|
|
|
// Verify shared secret from header
|
|
providedSecret := r.Header.Get("X-Runner-Secret")
|
|
if providedSecret == "" {
|
|
s.respondError(w, http.StatusUnauthorized, "missing secret")
|
|
return
|
|
}
|
|
|
|
if providedSecret != runnerSecret {
|
|
s.respondError(w, http.StatusUnauthorized, "invalid secret")
|
|
return
|
|
}
|
|
|
|
// Add runner ID to context
|
|
ctx := r.Context()
|
|
ctx = context.WithValue(ctx, runnerIDContextKey, runnerID)
|
|
next(w, r.WithContext(ctx))
|
|
}
|
|
}
|
|
|
|
// handleRegisterRunner registers a new runner
|
|
// Note: Token expiration only affects whether the token can be used for registration.
|
|
// Once a runner is registered, it receives its own runner_secret and manager_secret
|
|
// and operates independently. The token expiration does not affect registered runners.
|
|
func (s *Server) handleRegisterRunner(w http.ResponseWriter, r *http.Request) {
|
|
var req struct {
|
|
types.RegisterRunnerRequest
|
|
RegistrationToken string `json:"registration_token"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
|
return
|
|
}
|
|
|
|
if req.Name == "" {
|
|
s.respondError(w, http.StatusBadRequest, "Runner name is required")
|
|
return
|
|
}
|
|
|
|
if req.RegistrationToken == "" {
|
|
s.respondError(w, http.StatusBadRequest, "Registration token is required")
|
|
return
|
|
}
|
|
|
|
// Validate registration token (expiration only affects token usability, not registered runners)
|
|
result, err := s.secrets.ValidateRegistrationTokenDetailed(req.RegistrationToken)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to validate token: %v", err))
|
|
return
|
|
}
|
|
if !result.Valid {
|
|
var errorMsg string
|
|
switch result.Reason {
|
|
case "already_used":
|
|
errorMsg = "Registration token has already been used"
|
|
case "expired":
|
|
errorMsg = "Registration token has expired"
|
|
case "not_found":
|
|
errorMsg = "Invalid registration token"
|
|
default:
|
|
errorMsg = "Invalid or expired registration token"
|
|
}
|
|
s.respondError(w, http.StatusUnauthorized, errorMsg)
|
|
return
|
|
}
|
|
|
|
// Get manager secret
|
|
managerSecret, err := s.secrets.GetManagerSecret()
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, "Failed to get manager secret")
|
|
return
|
|
}
|
|
|
|
// Generate runner secret (runner will use this for all future authentication, independent of token)
|
|
runnerSecret, err := s.secrets.GenerateRunnerSecret()
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, "Failed to generate runner secret")
|
|
return
|
|
}
|
|
|
|
// Set default priority if not provided
|
|
priority := 100
|
|
if req.Priority != nil {
|
|
priority = *req.Priority
|
|
}
|
|
|
|
// Register runner
|
|
var runnerID int64
|
|
err = s.db.QueryRow(
|
|
`INSERT INTO runners (name, hostname, ip_address, status, last_heartbeat, capabilities,
|
|
registration_token, runner_secret, manager_secret, verified, priority)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
RETURNING id`,
|
|
req.Name, req.Hostname, "", types.RunnerStatusOnline, time.Now(), req.Capabilities,
|
|
req.RegistrationToken, runnerSecret, managerSecret, true, priority,
|
|
).Scan(&runnerID)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to register runner: %v", err))
|
|
return
|
|
}
|
|
|
|
// Return runner info with secrets
|
|
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
|
|
"id": runnerID,
|
|
"name": req.Name,
|
|
"hostname": req.Hostname,
|
|
"status": types.RunnerStatusOnline,
|
|
"runner_secret": runnerSecret,
|
|
"manager_secret": managerSecret,
|
|
"verified": true,
|
|
})
|
|
}
|
|
|
|
// handleUpdateTaskProgress updates task progress
|
|
func (s *Server) handleUpdateTaskProgress(w http.ResponseWriter, r *http.Request) {
|
|
_, err := parseID(r, "id")
|
|
if err != nil {
|
|
s.respondError(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
|
|
var req struct {
|
|
Progress float64 `json:"progress"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
|
return
|
|
}
|
|
|
|
// This is mainly for logging/debugging, actual progress is calculated from completed tasks
|
|
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Progress updated"})
|
|
}
|
|
|
|
// handleUpdateTaskStep handles step start/complete events from runners
|
|
func (s *Server) handleUpdateTaskStep(w http.ResponseWriter, r *http.Request) {
|
|
// Get runner ID from context (set by runnerAuthMiddleware)
|
|
runnerID, ok := r.Context().Value(runnerIDContextKey).(int64)
|
|
if !ok {
|
|
s.respondError(w, http.StatusUnauthorized, "runner_id not found in context")
|
|
return
|
|
}
|
|
|
|
taskID, err := parseID(r, "id")
|
|
if err != nil {
|
|
s.respondError(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
|
|
var req struct {
|
|
StepName string `json:"step_name"`
|
|
Status string `json:"status"` // "pending", "running", "completed", "failed", "skipped"
|
|
DurationMs *int `json:"duration_ms,omitempty"`
|
|
ErrorMessage string `json:"error_message,omitempty"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
|
return
|
|
}
|
|
|
|
// Verify task belongs to runner
|
|
var taskRunnerID sql.NullInt64
|
|
err = s.db.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", taskID).Scan(&taskRunnerID)
|
|
if err == sql.ErrNoRows {
|
|
s.respondError(w, http.StatusNotFound, "Task not found")
|
|
return
|
|
}
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify task: %v", err))
|
|
return
|
|
}
|
|
if !taskRunnerID.Valid || taskRunnerID.Int64 != runnerID {
|
|
s.respondError(w, http.StatusForbidden, "Task does not belong to this runner")
|
|
return
|
|
}
|
|
|
|
now := time.Now()
|
|
var stepID int64
|
|
|
|
// Check if step already exists
|
|
var existingStepID sql.NullInt64
|
|
err = s.db.QueryRow(
|
|
`SELECT id FROM task_steps WHERE task_id = ? AND step_name = ?`,
|
|
taskID, req.StepName,
|
|
).Scan(&existingStepID)
|
|
|
|
if err == sql.ErrNoRows || !existingStepID.Valid {
|
|
// Create new step
|
|
var startedAt *time.Time
|
|
var completedAt *time.Time
|
|
if req.Status == string(types.StepStatusRunning) || req.Status == string(types.StepStatusCompleted) || req.Status == string(types.StepStatusFailed) {
|
|
startedAt = &now
|
|
}
|
|
if req.Status == string(types.StepStatusCompleted) || req.Status == string(types.StepStatusFailed) {
|
|
completedAt = &now
|
|
}
|
|
|
|
err = s.db.QueryRow(
|
|
`INSERT INTO task_steps (task_id, step_name, status, started_at, completed_at, duration_ms, error_message)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
|
RETURNING id`,
|
|
taskID, req.StepName, req.Status, startedAt, completedAt, req.DurationMs, req.ErrorMessage,
|
|
).Scan(&stepID)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create step: %v", err))
|
|
return
|
|
}
|
|
} else {
|
|
// Update existing step
|
|
stepID = existingStepID.Int64
|
|
var startedAt *time.Time
|
|
var completedAt *time.Time
|
|
|
|
// Get existing started_at if status is running/completed/failed
|
|
if req.Status == string(types.StepStatusRunning) || req.Status == string(types.StepStatusCompleted) || req.Status == string(types.StepStatusFailed) {
|
|
var existingStartedAt sql.NullTime
|
|
s.db.QueryRow(`SELECT started_at FROM task_steps WHERE id = ?`, stepID).Scan(&existingStartedAt)
|
|
if existingStartedAt.Valid {
|
|
startedAt = &existingStartedAt.Time
|
|
} else {
|
|
startedAt = &now
|
|
}
|
|
}
|
|
|
|
if req.Status == string(types.StepStatusCompleted) || req.Status == string(types.StepStatusFailed) {
|
|
completedAt = &now
|
|
}
|
|
|
|
_, err = s.db.Exec(
|
|
`UPDATE task_steps SET status = ?, started_at = ?, completed_at = ?, duration_ms = ?, error_message = ?
|
|
WHERE id = ?`,
|
|
req.Status, startedAt, completedAt, req.DurationMs, req.ErrorMessage, stepID,
|
|
)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to update step: %v", err))
|
|
return
|
|
}
|
|
}
|
|
|
|
// Get job ID for broadcasting
|
|
var jobID int64
|
|
err = s.db.QueryRow("SELECT job_id FROM tasks WHERE id = ?", taskID).Scan(&jobID)
|
|
if err == nil {
|
|
// Broadcast step update to frontend
|
|
s.broadcastTaskUpdate(jobID, taskID, "step_update", map[string]interface{}{
|
|
"step_id": stepID,
|
|
"step_name": req.StepName,
|
|
"status": req.Status,
|
|
"duration_ms": req.DurationMs,
|
|
"error_message": req.ErrorMessage,
|
|
})
|
|
}
|
|
|
|
s.respondJSON(w, http.StatusOK, map[string]interface{}{
|
|
"step_id": stepID,
|
|
"message": "Step updated successfully",
|
|
})
|
|
}
|
|
|
|
// handleDownloadJobContext allows runners to download the job context tar
|
|
func (s *Server) handleDownloadJobContext(w http.ResponseWriter, r *http.Request) {
|
|
jobID, err := parseID(r, "jobId")
|
|
if err != nil {
|
|
s.respondError(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
|
|
// Construct the context file path
|
|
contextPath := filepath.Join(s.storage.JobPath(jobID), "context.tar")
|
|
|
|
// Check if context file exists
|
|
if !s.storage.FileExists(contextPath) {
|
|
log.Printf("Context archive not found for job %d", jobID)
|
|
s.respondError(w, http.StatusNotFound, "Context archive not found. The file may not have been uploaded successfully.")
|
|
return
|
|
}
|
|
|
|
// Open and serve file
|
|
file, err := s.storage.GetFile(contextPath)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusNotFound, "Context file not found on disk")
|
|
return
|
|
}
|
|
defer file.Close()
|
|
|
|
// Set appropriate headers for tar file
|
|
w.Header().Set("Content-Type", "application/x-tar")
|
|
w.Header().Set("Content-Disposition", "attachment; filename=context.tar")
|
|
|
|
// Stream the file to the response
|
|
io.Copy(w, file)
|
|
}
|
|
|
|
// handleUploadFileFromRunner allows runners to upload output files
|
|
func (s *Server) handleUploadFileFromRunner(w http.ResponseWriter, r *http.Request) {
|
|
jobID, err := parseID(r, "jobId")
|
|
if err != nil {
|
|
s.respondError(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
|
|
err = r.ParseMultipartForm(50 << 30) // 50 GB (for large output files)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusBadRequest, "Failed to parse form")
|
|
return
|
|
}
|
|
|
|
file, header, err := r.FormFile("file")
|
|
if err != nil {
|
|
s.respondError(w, http.StatusBadRequest, "No file provided")
|
|
return
|
|
}
|
|
defer file.Close()
|
|
|
|
// Save file
|
|
filePath, err := s.storage.SaveOutput(jobID, header.Filename, file)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to save file: %v", err))
|
|
return
|
|
}
|
|
|
|
// Record in database
|
|
var fileID int64
|
|
err = s.db.QueryRow(
|
|
`INSERT INTO job_files (job_id, file_type, file_path, file_name, file_size)
|
|
VALUES (?, ?, ?, ?, ?)
|
|
RETURNING id`,
|
|
jobID, types.JobFileTypeOutput, filePath, header.Filename, header.Size,
|
|
).Scan(&fileID)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to record file: %v", err))
|
|
return
|
|
}
|
|
|
|
// Broadcast file addition
|
|
s.broadcastJobUpdate(jobID, "file_added", map[string]interface{}{
|
|
"file_id": fileID,
|
|
"file_type": types.JobFileTypeOutput,
|
|
"file_name": header.Filename,
|
|
"file_size": header.Size,
|
|
})
|
|
|
|
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
|
|
"file_path": filePath,
|
|
"file_name": header.Filename,
|
|
})
|
|
}
|
|
|
|
// handleGetJobStatusForRunner allows runners to check job status
|
|
func (s *Server) handleGetJobStatusForRunner(w http.ResponseWriter, r *http.Request) {
|
|
jobID, err := parseID(r, "jobId")
|
|
if err != nil {
|
|
s.respondError(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
|
|
var job types.Job
|
|
var startedAt, completedAt sql.NullTime
|
|
var errorMessage sql.NullString
|
|
|
|
var jobType string
|
|
var frameStart, frameEnd sql.NullInt64
|
|
var outputFormat sql.NullString
|
|
var allowParallelRunners sql.NullBool
|
|
err = s.db.QueryRow(
|
|
`SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format,
|
|
allow_parallel_runners, created_at, started_at, completed_at, error_message
|
|
FROM jobs WHERE id = ?`,
|
|
jobID,
|
|
).Scan(
|
|
&job.ID, &job.UserID, &jobType, &job.Name, &job.Status, &job.Progress,
|
|
&frameStart, &frameEnd, &outputFormat, &allowParallelRunners,
|
|
&job.CreatedAt, &startedAt, &completedAt, &errorMessage,
|
|
)
|
|
|
|
job.JobType = types.JobType(jobType)
|
|
if frameStart.Valid {
|
|
fs := int(frameStart.Int64)
|
|
job.FrameStart = &fs
|
|
}
|
|
if frameEnd.Valid {
|
|
fe := int(frameEnd.Int64)
|
|
job.FrameEnd = &fe
|
|
}
|
|
if outputFormat.Valid {
|
|
job.OutputFormat = &outputFormat.String
|
|
}
|
|
if allowParallelRunners.Valid {
|
|
job.AllowParallelRunners = &allowParallelRunners.Bool
|
|
}
|
|
|
|
if err == sql.ErrNoRows {
|
|
s.respondError(w, http.StatusNotFound, "Job not found")
|
|
return
|
|
}
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query job: %v", err))
|
|
return
|
|
}
|
|
|
|
if startedAt.Valid {
|
|
job.StartedAt = &startedAt.Time
|
|
}
|
|
if completedAt.Valid {
|
|
job.CompletedAt = &completedAt.Time
|
|
}
|
|
if errorMessage.Valid {
|
|
job.ErrorMessage = errorMessage.String
|
|
}
|
|
|
|
s.respondJSON(w, http.StatusOK, job)
|
|
}
|
|
|
|
// handleGetJobFilesForRunner allows runners to get job files
|
|
func (s *Server) handleGetJobFilesForRunner(w http.ResponseWriter, r *http.Request) {
|
|
jobID, err := parseID(r, "jobId")
|
|
if err != nil {
|
|
s.respondError(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
|
|
rows, err := s.db.Query(
|
|
`SELECT id, job_id, file_type, file_path, file_name, file_size, created_at
|
|
FROM job_files WHERE job_id = ? ORDER BY file_name`,
|
|
jobID,
|
|
)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query files: %v", err))
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
files := []types.JobFile{}
|
|
for rows.Next() {
|
|
var file types.JobFile
|
|
err := rows.Scan(
|
|
&file.ID, &file.JobID, &file.FileType, &file.FilePath,
|
|
&file.FileName, &file.FileSize, &file.CreatedAt,
|
|
)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan file: %v", err))
|
|
return
|
|
}
|
|
files = append(files, file)
|
|
}
|
|
|
|
s.respondJSON(w, http.StatusOK, files)
|
|
}
|
|
|
|
// handleGetJobMetadataForRunner allows runners to get job metadata
|
|
func (s *Server) handleGetJobMetadataForRunner(w http.ResponseWriter, r *http.Request) {
|
|
jobID, err := parseID(r, "jobId")
|
|
if err != nil {
|
|
s.respondError(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
|
|
var blendMetadataJSON sql.NullString
|
|
err = s.db.QueryRow(
|
|
`SELECT blend_metadata FROM jobs WHERE id = ?`,
|
|
jobID,
|
|
).Scan(&blendMetadataJSON)
|
|
|
|
if err == sql.ErrNoRows {
|
|
s.respondError(w, http.StatusNotFound, "Job not found")
|
|
return
|
|
}
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query job: %v", err))
|
|
return
|
|
}
|
|
|
|
if !blendMetadataJSON.Valid || blendMetadataJSON.String == "" {
|
|
s.respondJSON(w, http.StatusOK, nil)
|
|
return
|
|
}
|
|
|
|
var metadata types.BlendMetadata
|
|
if err := json.Unmarshal([]byte(blendMetadataJSON.String), &metadata); err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, "Failed to parse metadata")
|
|
return
|
|
}
|
|
|
|
s.respondJSON(w, http.StatusOK, metadata)
|
|
}
|
|
|
|
// handleDownloadFileForRunner allows runners to download a file by fileName
|
|
func (s *Server) handleDownloadFileForRunner(w http.ResponseWriter, r *http.Request) {
|
|
jobID, err := parseID(r, "jobId")
|
|
if err != nil {
|
|
s.respondError(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
|
|
// Get fileName from URL path (may need URL decoding)
|
|
fileName := chi.URLParam(r, "fileName")
|
|
if fileName == "" {
|
|
s.respondError(w, http.StatusBadRequest, "fileName is required")
|
|
return
|
|
}
|
|
|
|
// URL decode the fileName in case it contains encoded characters
|
|
decodedFileName, err := url.QueryUnescape(fileName)
|
|
if err != nil {
|
|
// If decoding fails, use original fileName
|
|
decodedFileName = fileName
|
|
}
|
|
|
|
// Get file info from database
|
|
var filePath string
|
|
err = s.db.QueryRow(
|
|
`SELECT file_path FROM job_files WHERE job_id = ? AND file_name = ?`,
|
|
jobID, decodedFileName,
|
|
).Scan(&filePath)
|
|
if err == sql.ErrNoRows {
|
|
s.respondError(w, http.StatusNotFound, "File not found")
|
|
return
|
|
}
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query file: %v", err))
|
|
return
|
|
}
|
|
|
|
// Open file
|
|
file, err := s.storage.GetFile(filePath)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusNotFound, "File not found on disk")
|
|
return
|
|
}
|
|
defer file.Close()
|
|
|
|
// Determine content type based on file extension
|
|
contentType := "application/octet-stream"
|
|
fileNameLower := strings.ToLower(decodedFileName)
|
|
switch {
|
|
case strings.HasSuffix(fileNameLower, ".png"):
|
|
contentType = "image/png"
|
|
case strings.HasSuffix(fileNameLower, ".jpg") || strings.HasSuffix(fileNameLower, ".jpeg"):
|
|
contentType = "image/jpeg"
|
|
case strings.HasSuffix(fileNameLower, ".gif"):
|
|
contentType = "image/gif"
|
|
case strings.HasSuffix(fileNameLower, ".webp"):
|
|
contentType = "image/webp"
|
|
case strings.HasSuffix(fileNameLower, ".exr") || strings.HasSuffix(fileNameLower, ".EXR"):
|
|
contentType = "image/x-exr"
|
|
case strings.HasSuffix(fileNameLower, ".mp4"):
|
|
contentType = "video/mp4"
|
|
case strings.HasSuffix(fileNameLower, ".webm"):
|
|
contentType = "video/webm"
|
|
}
|
|
|
|
// Set headers
|
|
w.Header().Set("Content-Type", contentType)
|
|
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", decodedFileName))
|
|
|
|
// Stream file
|
|
io.Copy(w, file)
|
|
}
|
|
|
|
// WebSocket message types.

// WSMessage is the envelope for every message read from a runner WebSocket.
// Type selects how Data is decoded; handleRunnerWebSocket dispatches on
// "heartbeat", "log_entry", "task_update" and "task_complete".
type WSMessage struct {
	// Type is the message discriminator.
	Type string `json:"type"`
	// Data is the still-encoded payload; its shape depends on Type.
	Data json.RawMessage `json:"data"`
	// Timestamp is supplied by the sender; presumably a Unix time, unit not
	// established in this file.
	Timestamp int64 `json:"timestamp"`
}
|
|
|
|
// WSTaskAssignment is the JSON payload describing a task handed to a runner:
// the task and its job, the frame range, the output format, the task type and
// the input files.
type WSTaskAssignment struct {
	TaskID       int64    `json:"task_id"`
	JobID        int64    `json:"job_id"`
	JobName      string   `json:"job_name"`
	OutputFormat string   `json:"output_format"`
	FrameStart   int      `json:"frame_start"`
	FrameEnd     int      `json:"frame_end"`
	TaskType     string   `json:"task_type"`
	InputFiles   []string `json:"input_files"`
}
|
|
|
|
// WSLogEntry is the payload of a "log_entry" message: one log line emitted by
// a runner for a task. StepName is optional and omitted from JSON when empty.
type WSLogEntry struct {
	TaskID   int64  `json:"task_id"`
	LogLevel string `json:"log_level"`
	Message  string `json:"message"`
	StepName string `json:"step_name,omitempty"`
}
|
|
|
|
// WSTaskUpdate is the payload of "task_update" and "task_complete" messages.
// For completions, Success decides whether the task is recorded as completed
// or failed; OutputPath and Error are optional and omitted from JSON when
// empty.
type WSTaskUpdate struct {
	TaskID     int64  `json:"task_id"`
	Status     string `json:"status"`
	OutputPath string `json:"output_path,omitempty"`
	Success    bool   `json:"success"`
	Error      string `json:"error,omitempty"`
}
|
|
|
|
// handleRunnerWebSocket handles WebSocket connections from runners
|
|
func (s *Server) handleRunnerWebSocket(w http.ResponseWriter, r *http.Request) {
|
|
// Get runner ID and secret from query params
|
|
runnerIDStr := r.URL.Query().Get("runner_id")
|
|
providedSecret := r.URL.Query().Get("secret")
|
|
|
|
if runnerIDStr == "" || providedSecret == "" {
|
|
s.respondError(w, http.StatusBadRequest, "runner_id and secret required")
|
|
return
|
|
}
|
|
|
|
var runnerID int64
|
|
_, err := fmt.Sscanf(runnerIDStr, "%d", &runnerID)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusBadRequest, "invalid runner_id")
|
|
return
|
|
}
|
|
|
|
// Get runner secret
|
|
runnerSecret, err := s.secrets.GetRunnerSecret(runnerID)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusUnauthorized, "runner not found or not verified")
|
|
return
|
|
}
|
|
|
|
// Verify shared secret
|
|
if providedSecret != runnerSecret {
|
|
s.respondError(w, http.StatusUnauthorized, "invalid secret")
|
|
return
|
|
}
|
|
|
|
// Upgrade to WebSocket
|
|
conn, err := s.wsUpgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
log.Printf("Failed to upgrade WebSocket: %v", err)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
// Register connection (must be done before any distribution checks)
|
|
// Close old connection outside lock to avoid blocking
|
|
var oldConn *websocket.Conn
|
|
s.runnerConnsMu.Lock()
|
|
if existingConn, exists := s.runnerConns[runnerID]; exists {
|
|
oldConn = existingConn
|
|
}
|
|
s.runnerConns[runnerID] = conn
|
|
s.runnerConnsMu.Unlock()
|
|
|
|
// Close old connection outside lock (if it existed)
|
|
if oldConn != nil {
|
|
oldConn.Close()
|
|
}
|
|
|
|
// Create a write mutex for this connection
|
|
s.runnerConnsWriteMuMu.Lock()
|
|
s.runnerConnsWriteMu[runnerID] = &sync.Mutex{}
|
|
s.runnerConnsWriteMuMu.Unlock()
|
|
|
|
// Update runner status to online
|
|
_, _ = s.db.Exec(
|
|
`UPDATE runners SET status = ?, last_heartbeat = ? WHERE id = ?`,
|
|
types.RunnerStatusOnline, time.Now(), runnerID,
|
|
)
|
|
|
|
// Immediately try to distribute pending tasks to this newly connected runner
|
|
log.Printf("Runner %d connected, distributing pending tasks", runnerID)
|
|
s.triggerTaskDistribution()
|
|
|
|
// Note: We don't log to task logs here because we don't know which tasks will be assigned yet
|
|
// Task assignment logging happens in distributeTasksToRunners
|
|
|
|
// Cleanup on disconnect
|
|
defer func() {
|
|
s.runnerConnsMu.Lock()
|
|
delete(s.runnerConns, runnerID)
|
|
s.runnerConnsMu.Unlock()
|
|
s.runnerConnsWriteMuMu.Lock()
|
|
delete(s.runnerConnsWriteMu, runnerID)
|
|
s.runnerConnsWriteMuMu.Unlock()
|
|
_, _ = s.db.Exec(
|
|
`UPDATE runners SET status = ? WHERE id = ?`,
|
|
types.RunnerStatusOffline, runnerID,
|
|
)
|
|
|
|
// Immediately redistribute tasks that were assigned to this runner
|
|
log.Printf("Runner %d disconnected, redistributing its tasks", runnerID)
|
|
s.redistributeRunnerTasks(runnerID)
|
|
}()
|
|
|
|
// Set pong handler to update heartbeat when we receive pong responses from runner
|
|
// Also reset read deadline to keep connection alive
|
|
conn.SetPongHandler(func(string) error {
|
|
conn.SetReadDeadline(time.Now().Add(90 * time.Second)) // Increased to 90 seconds
|
|
_, _ = s.db.Exec(
|
|
`UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`,
|
|
time.Now(), types.RunnerStatusOnline, runnerID,
|
|
)
|
|
return nil
|
|
})
|
|
|
|
// Set read deadline to ensure we process control frames (like pong)
|
|
conn.SetReadDeadline(time.Now().Add(90 * time.Second)) // Increased to 90 seconds
|
|
|
|
// Send ping every 30 seconds to trigger pong responses
|
|
go func() {
|
|
ticker := time.NewTicker(30 * time.Second)
|
|
defer ticker.Stop()
|
|
for range ticker.C {
|
|
s.runnerConnsMu.RLock()
|
|
currentConn, exists := s.runnerConns[runnerID]
|
|
s.runnerConnsMu.RUnlock()
|
|
if !exists || currentConn != conn {
|
|
// Connection was replaced or removed
|
|
return
|
|
}
|
|
// Get write mutex for this connection
|
|
s.runnerConnsWriteMuMu.RLock()
|
|
writeMu, hasMu := s.runnerConnsWriteMu[runnerID]
|
|
s.runnerConnsWriteMuMu.RUnlock()
|
|
if !hasMu || writeMu == nil {
|
|
return
|
|
}
|
|
// Send ping - runner should respond with pong automatically
|
|
// Reset read deadline before sending ping to ensure we can receive pong
|
|
conn.SetReadDeadline(time.Now().Add(90 * time.Second)) // Increased to 90 seconds
|
|
writeMu.Lock()
|
|
err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second))
|
|
writeMu.Unlock()
|
|
if err != nil {
|
|
// Write failed - connection is likely dead, read loop will detect and cleanup
|
|
log.Printf("Failed to send ping to runner %d: %v", runnerID, err)
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Handle incoming messages
|
|
for {
|
|
// Reset read deadline for each message - this is critical to keep connection alive
|
|
conn.SetReadDeadline(time.Now().Add(90 * time.Second)) // Increased to 90 seconds for safety
|
|
|
|
var msg WSMessage
|
|
err := conn.ReadJSON(&msg)
|
|
if err != nil {
|
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
|
log.Printf("WebSocket error for runner %d: %v", runnerID, err)
|
|
}
|
|
break
|
|
}
|
|
|
|
// Reset read deadline after successfully reading a message
|
|
// This ensures the connection stays alive as long as we're receiving messages
|
|
conn.SetReadDeadline(time.Now().Add(90 * time.Second))
|
|
|
|
switch msg.Type {
|
|
case "heartbeat":
|
|
// Update heartbeat from explicit heartbeat message
|
|
// Reset read deadline to keep connection alive
|
|
conn.SetReadDeadline(time.Now().Add(90 * time.Second))
|
|
_, _ = s.db.Exec(
|
|
`UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`,
|
|
time.Now(), types.RunnerStatusOnline, runnerID,
|
|
)
|
|
|
|
case "log_entry":
|
|
var logEntry WSLogEntry
|
|
if err := json.Unmarshal(msg.Data, &logEntry); err == nil {
|
|
s.handleWebSocketLog(runnerID, logEntry)
|
|
}
|
|
|
|
case "task_update":
|
|
var taskUpdate WSTaskUpdate
|
|
if err := json.Unmarshal(msg.Data, &taskUpdate); err == nil {
|
|
s.handleWebSocketTaskUpdate(runnerID, taskUpdate)
|
|
}
|
|
|
|
case "task_complete":
|
|
var taskUpdate WSTaskUpdate
|
|
if err := json.Unmarshal(msg.Data, &taskUpdate); err == nil {
|
|
s.handleWebSocketTaskComplete(runnerID, taskUpdate)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// handleWebSocketLog handles log entries from WebSocket
|
|
func (s *Server) handleWebSocketLog(runnerID int64, logEntry WSLogEntry) {
|
|
// Store log in database
|
|
_, err := s.db.Exec(
|
|
`INSERT INTO task_logs (task_id, runner_id, log_level, message, step_name, created_at)
|
|
VALUES (?, ?, ?, ?, ?, ?)`,
|
|
logEntry.TaskID, runnerID, logEntry.LogLevel, logEntry.Message, logEntry.StepName, time.Now(),
|
|
)
|
|
if err != nil {
|
|
log.Printf("Failed to store log: %v", err)
|
|
return
|
|
}
|
|
|
|
// Broadcast to frontend clients
|
|
s.broadcastLogToFrontend(logEntry.TaskID, logEntry)
|
|
|
|
// If this log contains a frame number (Fra:), update progress for single-runner render jobs
|
|
if strings.Contains(logEntry.Message, "Fra:") {
|
|
// Get job ID from task
|
|
var jobID int64
|
|
err := s.db.QueryRow("SELECT job_id FROM tasks WHERE id = ?", logEntry.TaskID).Scan(&jobID)
|
|
if err == nil {
|
|
// Throttle progress updates (max once per 2 seconds per job)
|
|
s.progressUpdateTimesMu.RLock()
|
|
lastUpdate, exists := s.progressUpdateTimes[jobID]
|
|
s.progressUpdateTimesMu.RUnlock()
|
|
|
|
shouldUpdate := !exists || time.Since(lastUpdate) >= 2*time.Second
|
|
if shouldUpdate {
|
|
s.progressUpdateTimesMu.Lock()
|
|
s.progressUpdateTimes[jobID] = time.Now()
|
|
s.progressUpdateTimesMu.Unlock()
|
|
|
|
// Update progress in background to avoid blocking log processing
|
|
go s.updateJobStatusFromTasks(jobID)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// handleWebSocketTaskUpdate handles task status updates from WebSocket
|
|
func (s *Server) handleWebSocketTaskUpdate(runnerID int64, taskUpdate WSTaskUpdate) {
|
|
// This can be used for progress updates
|
|
// For now, we'll just log it
|
|
log.Printf("Task %d update from runner %d: %s", taskUpdate.TaskID, runnerID, taskUpdate.Status)
|
|
}
|
|
|
|
// handleWebSocketTaskComplete handles task completion from WebSocket
|
|
func (s *Server) handleWebSocketTaskComplete(runnerID int64, taskUpdate WSTaskUpdate) {
|
|
// Verify task belongs to runner
|
|
var taskRunnerID sql.NullInt64
|
|
err := s.db.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", taskUpdate.TaskID).Scan(&taskRunnerID)
|
|
if err != nil || !taskRunnerID.Valid || taskRunnerID.Int64 != runnerID {
|
|
log.Printf("Task %d does not belong to runner %d", taskUpdate.TaskID, runnerID)
|
|
return
|
|
}
|
|
|
|
status := types.TaskStatusCompleted
|
|
if !taskUpdate.Success {
|
|
status = types.TaskStatusFailed
|
|
}
|
|
|
|
now := time.Now()
|
|
_, err = s.db.Exec(
|
|
`UPDATE tasks SET status = ?, output_path = ?, completed_at = ?, error_message = ? WHERE id = ?`,
|
|
status, taskUpdate.OutputPath, now, taskUpdate.Error, taskUpdate.TaskID,
|
|
)
|
|
if err != nil {
|
|
log.Printf("Failed to update task: %v", err)
|
|
return
|
|
}
|
|
|
|
// Update job status and progress
|
|
var jobID int64
|
|
err = s.db.QueryRow(
|
|
`SELECT job_id FROM tasks WHERE id = ?`,
|
|
taskUpdate.TaskID,
|
|
).Scan(&jobID)
|
|
if err == nil {
|
|
// Broadcast task update
|
|
s.broadcastTaskUpdate(jobID, taskUpdate.TaskID, "task_update", map[string]interface{}{
|
|
"status": status,
|
|
"output_path": taskUpdate.OutputPath,
|
|
"completed_at": now,
|
|
"error": taskUpdate.Error,
|
|
})
|
|
s.updateJobStatusFromTasks(jobID)
|
|
}
|
|
}
|
|
|
|
// parseBlenderFrame extracts a frame number from a Blender log message.
//
// It looks for the marker "Fra:" followed by optional spaces/tabs and a run of
// decimal digits (e.g. "Fra:2470" or "Fra: 2470") and returns that number with
// true. Occurrences of the marker that are not followed by a valid number are
// skipped, so a later valid occurrence in the same message is still found.
// It returns (0, false) when no occurrence yields a number.
func parseBlenderFrame(logMessage string) (int, bool) {
	const marker = "Fra:"

	rest := logMessage
	for {
		idx := strings.Index(rest, marker)
		if idx == -1 {
			return 0, false
		}
		// Continue after this marker; rest shrinks on every iteration, so the
		// loop terminates.
		rest = rest[idx+len(marker):]

		candidate := strings.TrimLeft(rest, " \t")
		end := 0
		for end < len(candidate) && candidate[end] >= '0' && candidate[end] <= '9' {
			end++
		}
		if end == 0 {
			continue
		}
		// Atoi can still fail when the digit run overflows int.
		if frame, err := strconv.Atoi(candidate[:end]); err == nil {
			return frame, true
		}
	}
}
|
|
|
|
// getCurrentFrameFromLogs gets the highest frame number found in logs for a job's render tasks
|
|
func (s *Server) getCurrentFrameFromLogs(jobID int64) (int, bool) {
|
|
// Get all render tasks for this job
|
|
rows, err := s.db.Query(
|
|
`SELECT id FROM tasks WHERE job_id = ? AND task_type = ? AND status = ?`,
|
|
jobID, types.TaskTypeRender, types.TaskStatusRunning,
|
|
)
|
|
if err != nil {
|
|
return 0, false
|
|
}
|
|
defer rows.Close()
|
|
|
|
maxFrame := 0
|
|
found := false
|
|
|
|
for rows.Next() {
|
|
var taskID int64
|
|
if err := rows.Scan(&taskID); err != nil {
|
|
log.Printf("Failed to scan task ID in getCurrentFrameFromLogs: %v", err)
|
|
continue
|
|
}
|
|
|
|
// Get the most recent log entries for this task (last 100 to avoid scanning all logs)
|
|
logRows, err := s.db.Query(
|
|
`SELECT message FROM task_logs
|
|
WHERE task_id = ? AND message LIKE '%Fra:%'
|
|
ORDER BY id DESC LIMIT 100`,
|
|
taskID,
|
|
)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
for logRows.Next() {
|
|
var message string
|
|
if err := logRows.Scan(&message); err != nil {
|
|
continue
|
|
}
|
|
|
|
if frame, ok := parseBlenderFrame(message); ok {
|
|
if frame > maxFrame {
|
|
maxFrame = frame
|
|
found = true
|
|
}
|
|
}
|
|
}
|
|
logRows.Close()
|
|
}
|
|
|
|
return maxFrame, found
|
|
}
|
|
|
|
// updateJobStatusFromTasks updates job status and progress based on task states
|
|
func (s *Server) updateJobStatusFromTasks(jobID int64) {
|
|
now := time.Now()
|
|
|
|
// Get job info to check if it's a render job without parallel runners
|
|
var jobType string
|
|
var frameStart, frameEnd sql.NullInt64
|
|
var allowParallelRunners sql.NullBool
|
|
err := s.db.QueryRow(
|
|
`SELECT job_type, frame_start, frame_end, allow_parallel_runners FROM jobs WHERE id = ?`,
|
|
jobID,
|
|
).Scan(&jobType, &frameStart, &frameEnd, &allowParallelRunners)
|
|
if err != nil {
|
|
log.Printf("Failed to get job info for job %d: %v", jobID, err)
|
|
return
|
|
}
|
|
|
|
// Check if we should use frame-based progress (render job, single runner)
|
|
useFrameProgress := jobType == string(types.JobTypeRender) &&
|
|
allowParallelRunners.Valid && !allowParallelRunners.Bool &&
|
|
frameStart.Valid && frameEnd.Valid
|
|
|
|
// Get current job status to detect changes
|
|
var currentStatus string
|
|
err = s.db.QueryRow(`SELECT status FROM jobs WHERE id = ?`, jobID).Scan(¤tStatus)
|
|
if err != nil {
|
|
log.Printf("Failed to get current job status for job %d: %v", jobID, err)
|
|
return
|
|
}
|
|
|
|
// Count total tasks and completed tasks
|
|
var totalTasks, completedTasks int
|
|
err = s.db.QueryRow(
|
|
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status IN (?, ?, ?, ?)`,
|
|
jobID, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted, types.TaskStatusFailed,
|
|
).Scan(&totalTasks)
|
|
if err != nil {
|
|
log.Printf("Failed to count total tasks for job %d: %v", jobID, err)
|
|
return
|
|
}
|
|
err = s.db.QueryRow(
|
|
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
|
|
jobID, types.TaskStatusCompleted,
|
|
).Scan(&completedTasks)
|
|
if err != nil {
|
|
log.Printf("Failed to count completed tasks for job %d: %v", jobID, err)
|
|
return
|
|
}
|
|
|
|
// Calculate progress
|
|
var progress float64
|
|
if totalTasks == 0 {
|
|
// All tasks cancelled or no tasks, set progress to 0
|
|
progress = 0.0
|
|
} else if useFrameProgress {
|
|
// For single-runner render jobs, use frame-based progress from logs
|
|
currentFrame, frameFound := s.getCurrentFrameFromLogs(jobID)
|
|
frameStartVal := int(frameStart.Int64)
|
|
frameEndVal := int(frameEnd.Int64)
|
|
totalFrames := frameEndVal - frameStartVal + 1
|
|
|
|
// Count non-render tasks (like video generation) separately
|
|
var nonRenderTasks, nonRenderCompleted int
|
|
s.db.QueryRow(
|
|
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND task_type != ? AND status IN (?, ?, ?, ?)`,
|
|
jobID, types.TaskTypeRender, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted, types.TaskStatusFailed,
|
|
).Scan(&nonRenderTasks)
|
|
s.db.QueryRow(
|
|
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND task_type != ? AND status = ?`,
|
|
jobID, types.TaskTypeRender, types.TaskStatusCompleted,
|
|
).Scan(&nonRenderCompleted)
|
|
|
|
// Calculate render task progress from frames
|
|
var renderProgress float64
|
|
if frameFound && totalFrames > 0 {
|
|
// Calculate how many frames have been rendered (current - start + 1)
|
|
// But cap at frame_end to handle cases where logs show frames beyond end
|
|
renderedFrames := currentFrame - frameStartVal + 1
|
|
if currentFrame > frameEndVal {
|
|
renderedFrames = totalFrames
|
|
} else if renderedFrames < 0 {
|
|
renderedFrames = 0
|
|
}
|
|
if renderedFrames > totalFrames {
|
|
renderedFrames = totalFrames
|
|
}
|
|
renderProgress = float64(renderedFrames) / float64(totalFrames) * 100.0
|
|
} else {
|
|
// Fall back to task-based progress for render tasks
|
|
var renderTasks, renderCompleted int
|
|
s.db.QueryRow(
|
|
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND task_type = ? AND status IN (?, ?, ?, ?)`,
|
|
jobID, types.TaskTypeRender, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted, types.TaskStatusFailed,
|
|
).Scan(&renderTasks)
|
|
s.db.QueryRow(
|
|
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND task_type = ? AND status = ?`,
|
|
jobID, types.TaskTypeRender, types.TaskStatusCompleted,
|
|
).Scan(&renderCompleted)
|
|
if renderTasks > 0 {
|
|
renderProgress = float64(renderCompleted) / float64(renderTasks) * 100.0
|
|
}
|
|
}
|
|
|
|
// Combine render progress with non-render task progress
|
|
// Weight: render tasks contribute 90%, other tasks contribute 10% (adjust as needed)
|
|
var nonRenderProgress float64
|
|
if nonRenderTasks > 0 {
|
|
nonRenderProgress = float64(nonRenderCompleted) / float64(nonRenderTasks) * 100.0
|
|
}
|
|
|
|
// Weighted average: render progress is most important
|
|
if totalTasks > 0 {
|
|
renderWeight := 0.9
|
|
nonRenderWeight := 0.1
|
|
progress = renderProgress*renderWeight + nonRenderProgress*nonRenderWeight
|
|
} else {
|
|
progress = renderProgress
|
|
}
|
|
} else {
|
|
// Standard task-based progress
|
|
progress = float64(completedTasks) / float64(totalTasks) * 100.0
|
|
}
|
|
|
|
var jobStatus string
|
|
var outputFormat sql.NullString
|
|
s.db.QueryRow(`SELECT output_format FROM jobs WHERE id = ?`, jobID).Scan(&outputFormat)
|
|
outputFormatStr := ""
|
|
if outputFormat.Valid {
|
|
outputFormatStr = outputFormat.String
|
|
}
|
|
|
|
// Check if all non-cancelled tasks are completed
|
|
var pendingOrRunningTasks int
|
|
err = s.db.QueryRow(
|
|
`SELECT COUNT(*) FROM tasks
|
|
WHERE job_id = ? AND status IN (?, ?)`,
|
|
jobID, types.TaskStatusPending, types.TaskStatusRunning,
|
|
).Scan(&pendingOrRunningTasks)
|
|
if err != nil {
|
|
log.Printf("Failed to count pending/running tasks for job %d: %v", jobID, err)
|
|
return
|
|
}
|
|
|
|
if pendingOrRunningTasks == 0 && totalTasks > 0 {
|
|
// All tasks are either completed or failed/cancelled
|
|
// Check if any tasks failed
|
|
var failedTasks int
|
|
s.db.QueryRow(
|
|
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
|
|
jobID, types.TaskStatusFailed,
|
|
).Scan(&failedTasks)
|
|
|
|
if failedTasks > 0 {
|
|
// Some tasks failed - mark job as failed
|
|
jobStatus = string(types.JobStatusFailed)
|
|
} else {
|
|
// All tasks completed successfully
|
|
jobStatus = string(types.JobStatusCompleted)
|
|
progress = 100.0 // Ensure progress is 100% when all tasks complete
|
|
}
|
|
_, err := s.db.Exec(
|
|
`UPDATE jobs SET status = ?, progress = ?, completed_at = ? WHERE id = ?`,
|
|
jobStatus, progress, now, jobID,
|
|
)
|
|
if err != nil {
|
|
log.Printf("Failed to update job %d status to %s: %v", jobID, jobStatus, err)
|
|
} else {
|
|
// Only log if status actually changed
|
|
if currentStatus != jobStatus {
|
|
log.Printf("Updated job %d status from %s to %s (progress: %.1f%%, completed tasks: %d/%d)", jobID, currentStatus, jobStatus, progress, completedTasks, totalTasks)
|
|
}
|
|
// Broadcast job update via WebSocket
|
|
s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{
|
|
"status": jobStatus,
|
|
"progress": progress,
|
|
"completed_at": now,
|
|
})
|
|
}
|
|
|
|
if outputFormatStr == "EXR_264_MP4" || outputFormatStr == "EXR_AV1_MP4" {
|
|
// Check if a video generation task already exists for this job (any status)
|
|
var existingVideoTask int
|
|
s.db.QueryRow(
|
|
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND task_type = ?`,
|
|
jobID, types.TaskTypeVideoGeneration,
|
|
).Scan(&existingVideoTask)
|
|
|
|
if existingVideoTask == 0 {
|
|
// Create a video generation task instead of calling generateMP4Video directly
|
|
// This prevents race conditions when multiple runners complete frames simultaneously
|
|
videoTaskTimeout := 86400 // 24 hours for video generation
|
|
var videoTaskID int64
|
|
err := s.db.QueryRow(
|
|
`INSERT INTO tasks (job_id, frame_start, frame_end, task_type, status, timeout_seconds, max_retries)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
|
RETURNING id`,
|
|
jobID, 0, 0, types.TaskTypeVideoGeneration, types.TaskStatusPending, videoTaskTimeout, 1,
|
|
).Scan(&videoTaskID)
|
|
if err != nil {
|
|
log.Printf("Failed to create video generation task for job %d: %v", jobID, err)
|
|
} else {
|
|
// Broadcast that a new task was added
|
|
log.Printf("Broadcasting task_added for job %d: video generation task %d", jobID, videoTaskID)
|
|
s.broadcastTaskUpdate(jobID, videoTaskID, "task_added", map[string]interface{}{
|
|
"task_id": videoTaskID,
|
|
"task_type": types.TaskTypeVideoGeneration,
|
|
})
|
|
// Update job status to ensure it's marked as running (has pending video task)
|
|
s.updateJobStatusFromTasks(jobID)
|
|
// Try to distribute the task immediately
|
|
s.triggerTaskDistribution()
|
|
}
|
|
} else {
|
|
log.Printf("Skipping video generation task creation for job %d (video task already exists)", jobID)
|
|
}
|
|
}
|
|
} else {
|
|
// Job has pending or running tasks - determine if it's running or still pending
|
|
var runningTasks int
|
|
s.db.QueryRow(
|
|
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
|
|
jobID, types.TaskStatusRunning,
|
|
).Scan(&runningTasks)
|
|
|
|
if runningTasks > 0 {
|
|
// Has running tasks - job is running
|
|
jobStatus = string(types.JobStatusRunning)
|
|
var startedAt sql.NullTime
|
|
s.db.QueryRow(`SELECT started_at FROM jobs WHERE id = ?`, jobID).Scan(&startedAt)
|
|
if !startedAt.Valid {
|
|
s.db.Exec(`UPDATE jobs SET started_at = ? WHERE id = ?`, now, jobID)
|
|
}
|
|
} else {
|
|
// All tasks are pending - job is pending
|
|
jobStatus = string(types.JobStatusPending)
|
|
}
|
|
|
|
_, err := s.db.Exec(
|
|
`UPDATE jobs SET status = ?, progress = ? WHERE id = ?`,
|
|
jobStatus, progress, jobID,
|
|
)
|
|
if err != nil {
|
|
log.Printf("Failed to update job %d status to %s: %v", jobID, jobStatus, err)
|
|
} else {
|
|
// Only log if status actually changed
|
|
if currentStatus != jobStatus {
|
|
log.Printf("Updated job %d status from %s to %s (progress: %.1f%%, completed: %d/%d, pending: %d, running: %d)", jobID, currentStatus, jobStatus, progress, completedTasks, totalTasks, pendingOrRunningTasks-runningTasks, runningTasks)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// broadcastLogToFrontend broadcasts log to connected frontend clients
|
|
func (s *Server) broadcastLogToFrontend(taskID int64, logEntry WSLogEntry) {
|
|
// Get job_id from task
|
|
var jobID int64
|
|
err := s.db.QueryRow("SELECT job_id FROM tasks WHERE id = ?", taskID).Scan(&jobID)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
key := fmt.Sprintf("%d:%d", jobID, taskID)
|
|
s.frontendConnsMu.RLock()
|
|
conn, exists := s.frontendConns[key]
|
|
s.frontendConnsMu.RUnlock()
|
|
|
|
if exists && conn != nil {
|
|
// Get full log entry from database for consistency
|
|
// Use a more reliable query that gets the most recent log with matching message
|
|
// This avoids race conditions with concurrent inserts
|
|
var log types.TaskLog
|
|
var runnerID sql.NullInt64
|
|
err := s.db.QueryRow(
|
|
`SELECT id, task_id, runner_id, log_level, message, step_name, created_at
|
|
FROM task_logs WHERE task_id = ? AND message = ? ORDER BY id DESC LIMIT 1`,
|
|
taskID, logEntry.Message,
|
|
).Scan(&log.ID, &log.TaskID, &runnerID, &log.LogLevel, &log.Message, &log.StepName, &log.CreatedAt)
|
|
if err == nil {
|
|
if runnerID.Valid {
|
|
log.RunnerID = &runnerID.Int64
|
|
}
|
|
msg := map[string]interface{}{
|
|
"type": "log",
|
|
"data": log,
|
|
"timestamp": time.Now().Unix(),
|
|
}
|
|
// Serialize writes to prevent concurrent write panics
|
|
s.frontendConnsWriteMuMu.RLock()
|
|
writeMu, hasMu := s.frontendConnsWriteMu[key]
|
|
s.frontendConnsWriteMuMu.RUnlock()
|
|
|
|
if hasMu && writeMu != nil {
|
|
writeMu.Lock()
|
|
conn.WriteJSON(msg)
|
|
writeMu.Unlock()
|
|
} else {
|
|
// Fallback if mutex doesn't exist yet (shouldn't happen, but be safe)
|
|
conn.WriteJSON(msg)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// triggerTaskDistribution triggers task distribution in a serialized manner
|
|
func (s *Server) triggerTaskDistribution() {
|
|
go func() {
|
|
// Try to acquire lock - if already running, skip
|
|
if !s.taskDistMu.TryLock() {
|
|
return // Distribution already in progress
|
|
}
|
|
defer s.taskDistMu.Unlock()
|
|
s.distributeTasksToRunners()
|
|
}()
|
|
}
|
|
|
|
// distributeTasksToRunners pushes available tasks to connected runners
|
|
// This function should only be called while holding taskDistMu lock
|
|
func (s *Server) distributeTasksToRunners() {
|
|
// Quick check: if there are no pending tasks, skip the expensive query
|
|
var pendingCount int
|
|
err := s.db.QueryRow(
|
|
`SELECT COUNT(*) FROM tasks t
|
|
JOIN jobs j ON t.job_id = j.id
|
|
WHERE t.status = ? AND j.status != ?`,
|
|
types.TaskStatusPending, types.JobStatusCancelled,
|
|
).Scan(&pendingCount)
|
|
if err != nil {
|
|
log.Printf("Failed to check pending tasks count: %v", err)
|
|
return
|
|
}
|
|
if pendingCount == 0 {
|
|
// No pending tasks, nothing to distribute
|
|
return
|
|
}
|
|
|
|
// Get all pending tasks
|
|
rows, err := s.db.Query(
|
|
`SELECT t.id, t.job_id, t.frame_start, t.frame_end, t.task_type, j.allow_parallel_runners, j.status as job_status, j.name as job_name
|
|
FROM tasks t
|
|
JOIN jobs j ON t.job_id = j.id
|
|
WHERE t.status = ? AND j.status != ?
|
|
ORDER BY t.created_at ASC`,
|
|
types.TaskStatusPending, types.JobStatusCancelled,
|
|
)
|
|
if err != nil {
|
|
log.Printf("Failed to query pending tasks: %v", err)
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
var pendingTasks []struct {
|
|
TaskID int64
|
|
JobID int64
|
|
FrameStart int
|
|
FrameEnd int
|
|
TaskType string
|
|
AllowParallelRunners bool
|
|
JobName string
|
|
JobStatus string
|
|
}
|
|
|
|
for rows.Next() {
|
|
var t struct {
|
|
TaskID int64
|
|
JobID int64
|
|
FrameStart int
|
|
FrameEnd int
|
|
TaskType string
|
|
AllowParallelRunners bool
|
|
JobName string
|
|
JobStatus string
|
|
}
|
|
var allowParallel sql.NullBool
|
|
err := rows.Scan(&t.TaskID, &t.JobID, &t.FrameStart, &t.FrameEnd, &t.TaskType, &allowParallel, &t.JobStatus, &t.JobName)
|
|
if err != nil {
|
|
log.Printf("Failed to scan pending task: %v", err)
|
|
continue
|
|
}
|
|
// Default to true if NULL (for metadata jobs or legacy data)
|
|
if allowParallel.Valid {
|
|
t.AllowParallelRunners = allowParallel.Bool
|
|
} else {
|
|
t.AllowParallelRunners = true
|
|
}
|
|
pendingTasks = append(pendingTasks, t)
|
|
}
|
|
|
|
if len(pendingTasks) == 0 {
|
|
log.Printf("No pending tasks found for distribution")
|
|
return
|
|
}
|
|
|
|
log.Printf("Found %d pending tasks for distribution", len(pendingTasks))
|
|
|
|
// Get connected runners (WebSocket connection is source of truth)
|
|
// Use a read lock to safely read the map
|
|
s.runnerConnsMu.RLock()
|
|
connectedRunners := make([]int64, 0, len(s.runnerConns))
|
|
for runnerID := range s.runnerConns {
|
|
// Verify connection is still valid (not closed)
|
|
conn := s.runnerConns[runnerID]
|
|
if conn != nil {
|
|
connectedRunners = append(connectedRunners, runnerID)
|
|
}
|
|
}
|
|
s.runnerConnsMu.RUnlock()
|
|
|
|
// Get runner priorities and capabilities for all connected runners
|
|
runnerPriorities := make(map[int64]int)
|
|
runnerCapabilities := make(map[int64]map[string]interface{})
|
|
for _, runnerID := range connectedRunners {
|
|
var priority int
|
|
var capabilitiesJSON string
|
|
err := s.db.QueryRow("SELECT priority, capabilities FROM runners WHERE id = ?", runnerID).Scan(&priority, &capabilitiesJSON)
|
|
if err != nil {
|
|
// Default to 100 if priority not found
|
|
priority = 100
|
|
capabilitiesJSON = "{}"
|
|
}
|
|
runnerPriorities[runnerID] = priority
|
|
|
|
// Parse capabilities JSON (can contain both bools and numbers)
|
|
var capabilities map[string]interface{}
|
|
if err := json.Unmarshal([]byte(capabilitiesJSON), &capabilities); err != nil {
|
|
// If parsing fails, try old format (map[string]bool) for backward compatibility
|
|
var oldCapabilities map[string]bool
|
|
if err2 := json.Unmarshal([]byte(capabilitiesJSON), &oldCapabilities); err2 == nil {
|
|
// Convert old format to new format
|
|
capabilities = make(map[string]interface{})
|
|
for k, v := range oldCapabilities {
|
|
capabilities[k] = v
|
|
}
|
|
} else {
|
|
// Both formats failed, assume no capabilities
|
|
capabilities = make(map[string]interface{})
|
|
}
|
|
}
|
|
runnerCapabilities[runnerID] = capabilities
|
|
}
|
|
|
|
// Update database status for all connected runners (outside the lock to avoid holding it too long)
|
|
for _, runnerID := range connectedRunners {
|
|
// Ensure database status matches WebSocket connection
|
|
// Update status to online if it's not already
|
|
_, _ = s.db.Exec(
|
|
`UPDATE runners SET status = ?, last_heartbeat = ? WHERE id = ? AND status != ?`,
|
|
types.RunnerStatusOnline, time.Now(), runnerID, types.RunnerStatusOnline,
|
|
)
|
|
}
|
|
|
|
if len(connectedRunners) == 0 {
|
|
log.Printf("No connected runners available for task distribution (checked WebSocket connections)")
|
|
// Log to task logs that no runners are available
|
|
for _, task := range pendingTasks {
|
|
if task.TaskType == string(types.TaskTypeMetadata) {
|
|
s.logTaskEvent(task.TaskID, nil, types.LogLevelWarn, "No connected runners available for task assignment", "")
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// Log task types being distributed
|
|
taskTypes := make(map[string]int)
|
|
for _, task := range pendingTasks {
|
|
taskTypes[task.TaskType]++
|
|
}
|
|
log.Printf("Distributing %d pending tasks (%v) to %d connected runners: %v", len(pendingTasks), taskTypes, len(connectedRunners), connectedRunners)
|
|
|
|
// Distribute tasks to runners
|
|
// Sort tasks to prioritize metadata tasks
|
|
sort.Slice(pendingTasks, func(i, j int) bool {
|
|
// Metadata tasks first
|
|
if pendingTasks[i].TaskType == string(types.TaskTypeMetadata) && pendingTasks[j].TaskType != string(types.TaskTypeMetadata) {
|
|
return true
|
|
}
|
|
if pendingTasks[i].TaskType != string(types.TaskTypeMetadata) && pendingTasks[j].TaskType == string(types.TaskTypeMetadata) {
|
|
return false
|
|
}
|
|
return false // Keep original order for same type
|
|
})
|
|
|
|
// Track how many tasks each runner has been assigned in this distribution cycle
|
|
runnerTaskCounts := make(map[int64]int)
|
|
|
|
for _, task := range pendingTasks {
|
|
// Determine required capability for this task
|
|
var requiredCapability string
|
|
switch task.TaskType {
|
|
case string(types.TaskTypeRender), string(types.TaskTypeMetadata):
|
|
requiredCapability = "blender"
|
|
case string(types.TaskTypeVideoGeneration):
|
|
requiredCapability = "ffmpeg"
|
|
default:
|
|
requiredCapability = "" // Unknown task type
|
|
}
|
|
|
|
// Find available runner
|
|
var selectedRunnerID int64
|
|
var bestRunnerID int64
|
|
var bestCapabilityMatch int = -1 // 0 = only required, 1 = required + others, 2 = no match
|
|
var bestPriority int = -1
|
|
var bestTaskCount int = -1
|
|
var bestRandom float64 = -1 // Random tie-breaker
|
|
|
|
isMetadataTask := task.TaskType == string(types.TaskTypeMetadata)
|
|
|
|
// Try to find the best runner for this task
|
|
for _, runnerID := range connectedRunners {
|
|
// Check if runner has required capability
|
|
capabilities := runnerCapabilities[runnerID]
|
|
hasRequired := false
|
|
if reqVal, ok := capabilities[requiredCapability]; ok {
|
|
if reqBool, ok := reqVal.(bool); ok {
|
|
hasRequired = reqBool
|
|
} else if reqFloat, ok := reqVal.(float64); ok {
|
|
hasRequired = reqFloat > 0
|
|
} else if reqInt, ok := reqVal.(int); ok {
|
|
hasRequired = reqInt > 0
|
|
}
|
|
}
|
|
if !hasRequired && requiredCapability != "" {
|
|
continue // Runner doesn't have required capability
|
|
}
|
|
|
|
// For video generation tasks, check GPU availability and ensure no blender tasks are running
|
|
if task.TaskType == string(types.TaskTypeVideoGeneration) {
|
|
// Check if runner has any blender/render tasks running (mutual exclusion)
|
|
var runningBlenderTasks int
|
|
s.db.QueryRow(
|
|
`SELECT COUNT(*) FROM tasks WHERE runner_id = ? AND status = ? AND task_type = ?`,
|
|
runnerID, types.TaskStatusRunning, types.TaskTypeRender,
|
|
).Scan(&runningBlenderTasks)
|
|
|
|
if runningBlenderTasks > 0 {
|
|
continue // Runner is busy with blender tasks, cannot run video tasks simultaneously
|
|
}
|
|
|
|
// Get GPU count from capabilities
|
|
var gpuCount int
|
|
if videoGPUs, ok := capabilities["video_gpu_count"]; ok {
|
|
if count, ok := videoGPUs.(float64); ok {
|
|
gpuCount = int(count)
|
|
} else if count, ok := videoGPUs.(int); ok {
|
|
gpuCount = count
|
|
}
|
|
}
|
|
|
|
// Count how many video generation tasks are currently running on this runner
|
|
var runningVideoTasks int
|
|
s.db.QueryRow(
|
|
`SELECT COUNT(*) FROM tasks WHERE runner_id = ? AND status = ? AND task_type = ?`,
|
|
runnerID, types.TaskStatusRunning, types.TaskTypeVideoGeneration,
|
|
).Scan(&runningVideoTasks)
|
|
|
|
// If all GPUs are in use, skip this runner
|
|
if gpuCount > 0 && runningVideoTasks >= gpuCount {
|
|
continue // All GPUs are busy
|
|
}
|
|
}
|
|
|
|
// For render/blender tasks, check if runner is busy and ensure no video tasks are running
|
|
if !isMetadataTask && task.TaskType != string(types.TaskTypeVideoGeneration) {
|
|
// Check if runner has any video generation tasks running (mutual exclusion)
|
|
var runningVideoTasks int
|
|
s.db.QueryRow(
|
|
`SELECT COUNT(*) FROM tasks WHERE runner_id = ? AND status = ? AND task_type = ?`,
|
|
runnerID, types.TaskStatusRunning, types.TaskTypeVideoGeneration,
|
|
).Scan(&runningVideoTasks)
|
|
|
|
if runningVideoTasks > 0 {
|
|
continue // Runner is busy with video tasks, cannot run blender tasks simultaneously
|
|
}
|
|
|
|
// Check if runner is busy (has running render tasks) - only for non-metadata, non-video tasks
|
|
var runningCount int
|
|
s.db.QueryRow(
|
|
`SELECT COUNT(*) FROM tasks WHERE runner_id = ? AND status = ? AND task_type NOT IN (?, ?)`,
|
|
runnerID, types.TaskStatusRunning, types.TaskTypeMetadata, types.TaskTypeVideoGeneration,
|
|
).Scan(&runningCount)
|
|
|
|
if runningCount > 0 {
|
|
continue // Runner is busy with render tasks
|
|
}
|
|
}
|
|
|
|
// For non-parallel jobs, check if runner already has tasks from this job
|
|
if !task.AllowParallelRunners {
|
|
var jobTaskCount int
|
|
s.db.QueryRow(
|
|
`SELECT COUNT(*) FROM tasks
|
|
WHERE job_id = ? AND runner_id = ? AND status IN (?, ?)`,
|
|
task.JobID, runnerID, types.TaskStatusPending, types.TaskStatusRunning,
|
|
).Scan(&jobTaskCount)
|
|
if jobTaskCount > 0 {
|
|
continue // Another runner is working on this job
|
|
}
|
|
}
|
|
|
|
// Determine capability match type
|
|
// Count how many capabilities the runner has
|
|
capabilityCount := 0
|
|
hasBlender := false
|
|
hasFFmpeg := false
|
|
if blenderVal, ok := capabilities["blender"]; ok {
|
|
if b, ok := blenderVal.(bool); ok {
|
|
hasBlender = b
|
|
}
|
|
}
|
|
if ffmpegVal, ok := capabilities["ffmpeg"]; ok {
|
|
if f, ok := ffmpegVal.(bool); ok {
|
|
hasFFmpeg = f
|
|
}
|
|
}
|
|
if hasBlender {
|
|
capabilityCount++
|
|
}
|
|
if hasFFmpeg {
|
|
capabilityCount++
|
|
}
|
|
|
|
// Determine match type: 0 = only required capability, 1 = required + others
|
|
var capabilityMatch int
|
|
if capabilityCount == 1 {
|
|
capabilityMatch = 0 // Only has the required capability
|
|
} else {
|
|
capabilityMatch = 1 // Has required + other capabilities
|
|
}
|
|
|
|
// Get runner priority and task count
|
|
priority := runnerPriorities[runnerID]
|
|
currentTaskCount := runnerTaskCounts[runnerID]
|
|
// Generate a small random value for absolute tie-breaking
|
|
randomValue := rand.Float64()
|
|
|
|
// Selection priority:
|
|
// 1. Capability match (0 = only required, 1 = required + others)
|
|
// 2. Priority (higher is better)
|
|
// 3. Task count (fewer is better)
|
|
// 4. Random value (absolute tie-breaker)
|
|
isBetter := false
|
|
if bestRunnerID == 0 {
|
|
isBetter = true
|
|
} else if capabilityMatch < bestCapabilityMatch {
|
|
// Prefer runners with only the required capability
|
|
isBetter = true
|
|
} else if capabilityMatch == bestCapabilityMatch {
|
|
if priority > bestPriority {
|
|
// Same capability match, but higher priority
|
|
isBetter = true
|
|
} else if priority == bestPriority {
|
|
if currentTaskCount < bestTaskCount {
|
|
// Same capability match and priority, but fewer tasks
|
|
isBetter = true
|
|
} else if currentTaskCount == bestTaskCount {
|
|
// Absolute tie - use random value as tie-breaker
|
|
if randomValue > bestRandom {
|
|
isBetter = true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if isBetter {
|
|
bestRunnerID = runnerID
|
|
bestCapabilityMatch = capabilityMatch
|
|
bestPriority = priority
|
|
bestTaskCount = currentTaskCount
|
|
bestRandom = randomValue
|
|
}
|
|
}
|
|
|
|
// Use the best runner we found (prioritized by capability match, then priority, then load balanced)
|
|
if bestRunnerID != 0 {
|
|
selectedRunnerID = bestRunnerID
|
|
}
|
|
|
|
if selectedRunnerID == 0 {
|
|
if task.TaskType == string(types.TaskTypeMetadata) {
|
|
log.Printf("Warning: No available runner for metadata task %d (job %d)", task.TaskID, task.JobID)
|
|
// Log that no runner is available
|
|
s.logTaskEvent(task.TaskID, nil, types.LogLevelWarn, "No available runner for task assignment", "")
|
|
}
|
|
continue // No available runner - task stays in queue
|
|
}
|
|
|
|
// Track assignment for load balancing
|
|
runnerTaskCounts[selectedRunnerID]++
|
|
|
|
// Atomically assign task to runner using UPDATE with WHERE runner_id IS NULL
|
|
// This prevents race conditions when multiple goroutines try to assign the same task
|
|
// Use a transaction to ensure atomicity and handle DuckDB's foreign key constraints
|
|
now := time.Now()
|
|
tx, err := s.db.Begin()
|
|
if err != nil {
|
|
log.Printf("Failed to begin transaction for task %d: %v", task.TaskID, err)
|
|
continue
|
|
}
|
|
|
|
result, err := tx.Exec(
|
|
`UPDATE tasks SET runner_id = ?, status = ?, started_at = ?
|
|
WHERE id = ? AND runner_id IS NULL AND status = ?`,
|
|
selectedRunnerID, types.TaskStatusRunning, now, task.TaskID, types.TaskStatusPending,
|
|
)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
log.Printf("Failed to atomically assign task %d: %v", task.TaskID, err)
|
|
continue
|
|
}
|
|
|
|
// Check if the update actually affected a row (task was successfully assigned)
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
tx.Rollback()
|
|
log.Printf("Failed to get rows affected for task %d: %v", task.TaskID, err)
|
|
continue
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
// Task was already assigned by another goroutine, skip
|
|
tx.Rollback()
|
|
continue
|
|
}
|
|
|
|
// Commit the assignment before attempting WebSocket send
|
|
// If send fails, we'll rollback in a separate transaction
|
|
err = tx.Commit()
|
|
if err != nil {
|
|
log.Printf("Failed to commit transaction for task %d: %v", task.TaskID, err)
|
|
continue
|
|
}
|
|
|
|
// Broadcast task assignment
|
|
s.broadcastTaskUpdate(task.JobID, task.TaskID, "task_update", map[string]interface{}{
|
|
"status": types.TaskStatusRunning,
|
|
"runner_id": selectedRunnerID,
|
|
"started_at": now,
|
|
})
|
|
|
|
// Task was successfully assigned in database, now send via WebSocket
|
|
log.Printf("Assigned task %d (type: %s, job: %d) to runner %d", task.TaskID, task.TaskType, task.JobID, selectedRunnerID)
|
|
|
|
// Log runner assignment to task logs
|
|
s.logTaskEvent(task.TaskID, nil, types.LogLevelInfo, fmt.Sprintf("Task assigned to runner %d", selectedRunnerID), "")
|
|
|
|
// Attempt to send task to runner via WebSocket
|
|
if err := s.assignTaskToRunner(selectedRunnerID, task.TaskID); err != nil {
|
|
log.Printf("Failed to send task %d to runner %d: %v", task.TaskID, selectedRunnerID, err)
|
|
// Log assignment failure
|
|
s.logTaskEvent(task.TaskID, nil, types.LogLevelError, fmt.Sprintf("Failed to send task to runner %d: %v", selectedRunnerID, err), "")
|
|
// Rollback the assignment if WebSocket send fails using a new transaction
|
|
rollbackTx, rollbackErr := s.db.Begin()
|
|
if rollbackErr == nil {
|
|
_, rollbackErr = rollbackTx.Exec(
|
|
`UPDATE tasks SET runner_id = NULL, status = ?, started_at = NULL
|
|
WHERE id = ? AND runner_id = ?`,
|
|
types.TaskStatusPending, task.TaskID, selectedRunnerID,
|
|
)
|
|
if rollbackErr == nil {
|
|
rollbackTx.Commit()
|
|
// Log rollback
|
|
s.logTaskEvent(task.TaskID, nil, types.LogLevelWarn, fmt.Sprintf("Task assignment rolled back - runner %d connection failed", selectedRunnerID), "")
|
|
// Update job status after rollback
|
|
s.updateJobStatusFromTasks(task.JobID)
|
|
// Trigger redistribution
|
|
s.triggerTaskDistribution()
|
|
} else {
|
|
rollbackTx.Rollback()
|
|
log.Printf("Failed to rollback task %d assignment: %v", task.TaskID, rollbackErr)
|
|
}
|
|
}
|
|
} else {
|
|
// WebSocket send succeeded, update job status
|
|
s.updateJobStatusFromTasks(task.JobID)
|
|
}
|
|
}
|
|
}
|
|
|
|
// assignTaskToRunner sends a task to a runner via WebSocket
|
|
func (s *Server) assignTaskToRunner(runnerID int64, taskID int64) error {
|
|
s.runnerConnsMu.RLock()
|
|
conn, exists := s.runnerConns[runnerID]
|
|
s.runnerConnsMu.RUnlock()
|
|
|
|
if !exists {
|
|
return fmt.Errorf("runner %d not connected", runnerID)
|
|
}
|
|
|
|
// Get task details
|
|
var task WSTaskAssignment
|
|
var jobName string
|
|
var outputFormat sql.NullString
|
|
var taskType string
|
|
err := s.db.QueryRow(
|
|
`SELECT t.job_id, t.frame_start, t.frame_end, t.task_type, j.name, j.output_format
|
|
FROM tasks t JOIN jobs j ON t.job_id = j.id WHERE t.id = ?`,
|
|
taskID,
|
|
).Scan(&task.JobID, &task.FrameStart, &task.FrameEnd, &taskType, &jobName, &outputFormat)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
task.TaskID = taskID
|
|
task.JobName = jobName
|
|
if outputFormat.Valid {
|
|
task.OutputFormat = outputFormat.String
|
|
log.Printf("Task %d assigned with output_format: '%s' (from job %d)", taskID, outputFormat.String, task.JobID)
|
|
} else {
|
|
log.Printf("Task %d assigned with no output_format (job %d)", taskID, task.JobID)
|
|
}
|
|
task.TaskType = taskType
|
|
|
|
// Get input files
|
|
rows, err := s.db.Query(
|
|
`SELECT file_path FROM job_files WHERE job_id = ? AND file_type = ?`,
|
|
task.JobID, types.JobFileTypeInput,
|
|
)
|
|
if err == nil {
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
var filePath string
|
|
if err := rows.Scan(&filePath); err == nil {
|
|
task.InputFiles = append(task.InputFiles, filePath)
|
|
} else {
|
|
log.Printf("Failed to scan input file path for task %d: %v", taskID, err)
|
|
}
|
|
}
|
|
} else {
|
|
log.Printf("Warning: Failed to query input files for task %d (job %d): %v", taskID, task.JobID, err)
|
|
}
|
|
|
|
if len(task.InputFiles) == 0 {
|
|
errMsg := fmt.Sprintf("No input files found for task %d (job %d). Cannot assign task without input files.", taskID, task.JobID)
|
|
log.Printf("ERROR: %s", errMsg)
|
|
// Don't send the task - it will fail anyway
|
|
// Rollback the assignment
|
|
s.db.Exec(
|
|
`UPDATE tasks SET runner_id = NULL, status = ?, started_at = NULL
|
|
WHERE id = ?`,
|
|
types.TaskStatusPending, taskID,
|
|
)
|
|
s.logTaskEvent(taskID, nil, types.LogLevelError, errMsg, "")
|
|
return errors.New(errMsg)
|
|
}
|
|
|
|
// Note: Task is already assigned in database by the atomic update in distributeTasksToRunners
|
|
// We just need to verify it's still assigned to this runner
|
|
var assignedRunnerID sql.NullInt64
|
|
err = s.db.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", taskID).Scan(&assignedRunnerID)
|
|
if err != nil {
|
|
return fmt.Errorf("task not found: %w", err)
|
|
}
|
|
if !assignedRunnerID.Valid || assignedRunnerID.Int64 != runnerID {
|
|
return fmt.Errorf("task %d is not assigned to runner %d", taskID, runnerID)
|
|
}
|
|
|
|
// Send task via WebSocket with write mutex protection
|
|
msg := WSMessage{
|
|
Type: "task_assignment",
|
|
Timestamp: time.Now().Unix(),
|
|
}
|
|
msg.Data, _ = json.Marshal(task)
|
|
|
|
// Get write mutex for this connection
|
|
s.runnerConnsWriteMuMu.RLock()
|
|
writeMu, hasMu := s.runnerConnsWriteMu[runnerID]
|
|
s.runnerConnsWriteMuMu.RUnlock()
|
|
|
|
if !hasMu || writeMu == nil {
|
|
return fmt.Errorf("runner %d write mutex not found", runnerID)
|
|
}
|
|
|
|
// Re-check connection is still valid before writing
|
|
s.runnerConnsMu.RLock()
|
|
_, stillExists := s.runnerConns[runnerID]
|
|
s.runnerConnsMu.RUnlock()
|
|
if !stillExists {
|
|
return fmt.Errorf("runner %d disconnected", runnerID)
|
|
}
|
|
|
|
writeMu.Lock()
|
|
err = conn.WriteJSON(msg)
|
|
writeMu.Unlock()
|
|
return err
|
|
}
|
|
|
|
// redistributeRunnerTasks resets tasks assigned to a disconnected/dead runner and redistributes them
|
|
func (s *Server) redistributeRunnerTasks(runnerID int64) {
|
|
// Get tasks assigned to this runner
|
|
taskRows, err := s.db.Query(
|
|
`SELECT id, retry_count, max_retries FROM tasks
|
|
WHERE runner_id = ? AND status = ?`,
|
|
runnerID, types.TaskStatusRunning,
|
|
)
|
|
if err != nil {
|
|
log.Printf("Failed to query tasks for runner %d: %v", runnerID, err)
|
|
return
|
|
}
|
|
defer taskRows.Close()
|
|
|
|
var tasksToReset []struct {
|
|
ID int64
|
|
RetryCount int
|
|
MaxRetries int
|
|
}
|
|
|
|
for taskRows.Next() {
|
|
var t struct {
|
|
ID int64
|
|
RetryCount int
|
|
MaxRetries int
|
|
}
|
|
if err := taskRows.Scan(&t.ID, &t.RetryCount, &t.MaxRetries); err == nil {
|
|
tasksToReset = append(tasksToReset, t)
|
|
}
|
|
}
|
|
|
|
if len(tasksToReset) == 0 {
|
|
return // No tasks to redistribute
|
|
}
|
|
|
|
log.Printf("Redistributing %d tasks from runner %d", len(tasksToReset), runnerID)
|
|
|
|
// Reset or fail tasks
|
|
for _, task := range tasksToReset {
|
|
if task.RetryCount >= task.MaxRetries {
|
|
// Mark as failed
|
|
_, err = s.db.Exec(
|
|
`UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL
|
|
WHERE id = ?`,
|
|
types.TaskStatusFailed, "Runner died, max retries exceeded", task.ID,
|
|
)
|
|
if err != nil {
|
|
log.Printf("Failed to mark task %d as failed: %v", task.ID, err)
|
|
} else {
|
|
// Log task failure
|
|
s.logTaskEvent(task.ID, &runnerID, types.LogLevelError, fmt.Sprintf("Task failed - runner %d disconnected, max retries (%d) exceeded", runnerID, task.MaxRetries), "")
|
|
}
|
|
} else {
|
|
// Reset to pending so it can be redistributed
|
|
_, err = s.db.Exec(
|
|
`UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL,
|
|
retry_count = retry_count + 1 WHERE id = ?`,
|
|
types.TaskStatusPending, task.ID,
|
|
)
|
|
if err != nil {
|
|
log.Printf("Failed to reset task %d: %v", task.ID, err)
|
|
} else {
|
|
// Log task reset for redistribution
|
|
s.logTaskEvent(task.ID, &runnerID, types.LogLevelWarn, fmt.Sprintf("Runner %d disconnected, task reset for redistribution (retry %d/%d)", runnerID, task.RetryCount+1, task.MaxRetries), "")
|
|
}
|
|
}
|
|
}
|
|
|
|
// Immediately redistribute the reset tasks
|
|
s.triggerTaskDistribution()
|
|
}
|
|
|
|
// logTaskEvent logs an event to a task's log (manager-side logging)
|
|
func (s *Server) logTaskEvent(taskID int64, runnerID *int64, logLevel types.LogLevel, message, stepName string) {
|
|
var runnerIDValue interface{}
|
|
if runnerID != nil {
|
|
runnerIDValue = *runnerID
|
|
}
|
|
|
|
_, err := s.db.Exec(
|
|
`INSERT INTO task_logs (task_id, runner_id, log_level, message, step_name, created_at)
|
|
VALUES (?, ?, ?, ?, ?, ?)`,
|
|
taskID, runnerIDValue, logLevel, message, stepName, time.Now(),
|
|
)
|
|
if err != nil {
|
|
log.Printf("Failed to log task event for task %d: %v", taskID, err)
|
|
return
|
|
}
|
|
|
|
// Broadcast to frontend if there are connected clients
|
|
s.broadcastLogToFrontend(taskID, WSLogEntry{
|
|
TaskID: taskID,
|
|
LogLevel: string(logLevel),
|
|
Message: message,
|
|
StepName: stepName,
|
|
})
|
|
}
|