package api

import (
	"context"
	"crypto/subtle"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"mime"
	"net/http"
	"time"

	"fuego/internal/auth"
	"fuego/pkg/types"

	"github.com/go-chi/chi/v5"
	"github.com/gorilla/websocket"
)

// handleListRunners lists all registered runners, newest first.
//
// The caller must be an authenticated user (getUserID must succeed); the user
// ID itself is not used to filter the result.
func (s *Server) handleListRunners(w http.ResponseWriter, r *http.Request) {
	if _, err := getUserID(r); err != nil {
		s.respondError(w, http.StatusUnauthorized, err.Error())
		return
	}

	rows, err := s.db.Query(
		`SELECT id, name, hostname, ip_address, status, last_heartbeat, capabilities, created_at
		FROM runners ORDER BY created_at DESC`,
	)
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query runners: %v", err))
		return
	}
	defer rows.Close()

	// Non-nil so an empty result encodes as [] rather than null.
	runners := []types.Runner{}
	for rows.Next() {
		var runner types.Runner
		err := rows.Scan(
			&runner.ID, &runner.Name, &runner.Hostname, &runner.IPAddress,
			&runner.Status, &runner.LastHeartbeat, &runner.Capabilities, &runner.CreatedAt,
		)
		if err != nil {
			s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan runner: %v", err))
			return
		}
		runners = append(runners, runner)
	}
	// rows.Next returns false both at end of data and on an iteration error.
	if err := rows.Err(); err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to read runners: %v", err))
		return
	}

	s.respondJSON(w, http.StatusOK, runners)
}

// runnerAuthMiddleware authenticates runner requests with an HMAC signature.
//
// The runner identifies itself with the runner_id query parameter. Its stored
// secret is fetched from s.secrets and checked by auth.VerifyRequest with a
// 5-minute window. On success the runner ID (int64) is stored in the request
// context under the key "runner_id" and next is called.
func (s *Server) runnerAuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		runnerIDStr := r.URL.Query().Get("runner_id")
		if runnerIDStr == "" {
			s.respondError(w, http.StatusBadRequest, "runner_id required in query string")
			return
		}

		var runnerID int64
		if _, err := fmt.Sscanf(runnerIDStr, "%d", &runnerID); err != nil {
			s.respondError(w, http.StatusBadRequest, "invalid runner_id")
			return
		}

		runnerSecret, err := s.secrets.GetRunnerSecret(runnerID)
		if err != nil {
			s.respondError(w, http.StatusUnauthorized, "runner not found or not verified")
			return
		}

		valid, err := auth.VerifyRequest(r, runnerSecret, 5*time.Minute)
		if err != nil || !valid {
			s.respondError(w, http.StatusUnauthorized, "invalid signature")
			return
		}

		// NOTE(review): staticcheck (SA1029) flags plain string context keys
		// because they can collide across packages. The key is kept because
		// handlers such as handleUpdateTaskStep read the value by this exact
		// string. Migrate every reader to an unexported key type together.
		ctx := context.WithValue(r.Context(), "runner_id", runnerID)
		next(w, r.WithContext(ctx))
	}
}

// handleRegisterRunner registers a new runner using a one-time registration token.
//
// Requires a non-empty name and a registration_token accepted by
// s.secrets.ValidateRegistrationToken. The runner row is stored as online and
// verified. The response includes the newly generated runner secret and the
// manager secret, which the runner needs for signed communication.
func (s *Server) handleRegisterRunner(w http.ResponseWriter, r *http.Request) {
	var req struct {
		types.RegisterRunnerRequest
		RegistrationToken string `json:"registration_token"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		s.respondError(w, http.StatusBadRequest, "Invalid request body")
		return
	}

	if req.Name == "" {
		s.respondError(w, http.StatusBadRequest, "Runner name is required")
		return
	}
	if req.RegistrationToken == "" {
		s.respondError(w, http.StatusBadRequest, "Registration token is required")
		return
	}

	valid, err := s.secrets.ValidateRegistrationToken(req.RegistrationToken)
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to validate token: %v", err))
		return
	}
	if !valid {
		s.respondError(w, http.StatusUnauthorized, "Invalid or expired registration token")
		return
	}

	managerSecret, err := s.secrets.GetManagerSecret()
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, "Failed to get manager secret")
		return
	}

	runnerSecret, err := s.secrets.GenerateRunnerSecret()
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, "Failed to generate runner secret")
		return
	}

	result, err := s.db.Exec(
		`INSERT INTO runners (name, hostname, ip_address, status, last_heartbeat, capabilities,
		registration_token, runner_secret, manager_secret, verified)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		req.Name, req.Hostname, req.IPAddress, types.RunnerStatusOnline, time.Now(), req.Capabilities,
		req.RegistrationToken, runnerSecret, managerSecret, true,
	)
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to register runner: %v", err))
		return
	}

	// Without a valid ID the runner could not authenticate, so handing out the
	// secrets with id 0 would be useless.
	runnerID, err := result.LastInsertId()
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to get runner ID: %v", err))
		return
	}

	s.respondJSON(w, http.StatusCreated, map[string]interface{}{
		"id":             runnerID,
		"name":           req.Name,
		"hostname":       req.Hostname,
		"ip_address":     req.IPAddress,
		"status":         types.RunnerStatusOnline,
		"runner_secret":  runnerSecret,
		"manager_secret": managerSecret,
		"verified":       true,
	})
}

// handleUpdateTaskProgress accepts a progress report for a task.
//
// It only validates the {id} URL parameter and the JSON body; nothing is
// stored. Job progress is derived from completed tasks elsewhere, so this
// endpoint is mainly for logging/debugging.
func (s *Server) handleUpdateTaskProgress(w http.ResponseWriter, r *http.Request) {
	if _, err := parseID(r, "id"); err != nil {
		s.respondError(w, http.StatusBadRequest, err.Error())
		return
	}

	var req struct {
		Progress float64 `json:"progress"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		s.respondError(w, http.StatusBadRequest, "Invalid request body")
		return
	}

	s.respondJSON(w, http.StatusOK, map[string]string{"message": "Progress updated"})
}

// handleUpdateTaskStep records a step start/complete event reported by a runner.
//
// It must be wrapped by runnerAuthMiddleware, which stores the authenticated
// runner ID in the request context. The task ({id} URL parameter) must be
// assigned to that runner.
//
// A step is identified by (task_id, step_name). It is inserted on the first
// report and updated on later ones:
//   - started_at is set when the status is running, completed or failed. An
//     existing started_at is preserved.
//   - completed_at is set to now when the status is completed or failed.
func (s *Server) handleUpdateTaskStep(w http.ResponseWriter, r *http.Request) {
	runnerID, ok := r.Context().Value("runner_id").(int64)
	if !ok {
		s.respondError(w, http.StatusUnauthorized, "runner_id not found in context")
		return
	}

	taskID, err := parseID(r, "id")
	if err != nil {
		s.respondError(w, http.StatusBadRequest, err.Error())
		return
	}

	var req struct {
		StepName     string `json:"step_name"`
		Status       string `json:"status"` // "pending", "running", "completed", "failed", "skipped"
		DurationMs   *int   `json:"duration_ms,omitempty"`
		ErrorMessage string `json:"error_message,omitempty"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		s.respondError(w, http.StatusBadRequest, "Invalid request body")
		return
	}

	// The task must exist and be assigned to the calling runner.
	var taskRunnerID sql.NullInt64
	err = s.db.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", taskID).Scan(&taskRunnerID)
	if errors.Is(err, sql.ErrNoRows) {
		s.respondError(w, http.StatusNotFound, "Task not found")
		return
	}
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify task: %v", err))
		return
	}
	if !taskRunnerID.Valid || taskRunnerID.Int64 != runnerID {
		s.respondError(w, http.StatusForbidden, "Task does not belong to this runner")
		return
	}

	now := time.Now()
	hasStarted := req.Status == string(types.StepStatusRunning) ||
		req.Status == string(types.StepStatusCompleted) ||
		req.Status == string(types.StepStatusFailed)
	hasFinished := req.Status == string(types.StepStatusCompleted) ||
		req.Status == string(types.StepStatusFailed)

	var completedAt *time.Time
	if hasFinished {
		completedAt = &now
	}

	var existingStepID sql.NullInt64
	err = s.db.QueryRow(
		`SELECT id FROM task_steps WHERE task_id = ? AND step_name = ?`,
		taskID, req.StepName,
	).Scan(&existingStepID)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		// Any other error would otherwise fall into the update branch below and
		// target a bogus step id of 0.
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to look up step: %v", err))
		return
	}

	var stepID int64
	if err != nil || !existingStepID.Valid {
		// First report for this step: insert it.
		var startedAt *time.Time
		if hasStarted {
			startedAt = &now
		}

		result, err := s.db.Exec(
			`INSERT INTO task_steps (task_id, step_name, status, started_at, completed_at, duration_ms, error_message)
			VALUES (?, ?, ?, ?, ?, ?, ?)`,
			taskID, req.StepName, req.Status, startedAt, completedAt, req.DurationMs, req.ErrorMessage,
		)
		if err != nil {
			s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create step: %v", err))
			return
		}
		stepID, err = result.LastInsertId()
		if err != nil {
			s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to get step ID: %v", err))
			return
		}
	} else {
		// Step already recorded: update it in place.
		stepID = existingStepID.Int64

		var startedAt *time.Time
		if hasStarted {
			// Keep the original start time across repeated reports. Fall back
			// to now when none is stored (or it cannot be read).
			startedAt = &now
			var existingStartedAt sql.NullTime
			readErr := s.db.QueryRow(`SELECT started_at FROM task_steps WHERE id = ?`, stepID).Scan(&existingStartedAt)
			if readErr == nil && existingStartedAt.Valid {
				startedAt = &existingStartedAt.Time
			}
		}

		_, err = s.db.Exec(
			`UPDATE task_steps SET status = ?, started_at = ?, completed_at = ?, duration_ms = ?, error_message = ?
			WHERE id = ?`,
			req.Status, startedAt, completedAt, req.DurationMs, req.ErrorMessage, stepID,
		)
		if err != nil {
			s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to update step: %v", err))
			return
		}
	}

	s.respondJSON(w, http.StatusOK, map[string]interface{}{
		"step_id": stepID,
		"message": "Step updated successfully",
	})
}

// handleDownloadFileForRunner streams a job file ({jobId}/{fileName} URL
// parameters) to a runner as an attachment.
//
// The file is looked up in job_files and opened through s.storage. A missing
// DB row or missing stored file yields 404.
func (s *Server) handleDownloadFileForRunner(w http.ResponseWriter, r *http.Request) {
	jobID, err := parseID(r, "jobId")
	if err != nil {
		s.respondError(w, http.StatusBadRequest, err.Error())
		return
	}

	fileName := chi.URLParam(r, "fileName")

	var filePath string
	err = s.db.QueryRow(
		`SELECT file_path FROM job_files WHERE job_id = ? AND file_name = ?`,
		jobID, fileName,
	).Scan(&filePath)
	if errors.Is(err, sql.ErrNoRows) {
		s.respondError(w, http.StatusNotFound, "File not found")
		return
	}
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query file: %v", err))
		return
	}

	file, err := s.storage.GetFile(filePath)
	if err != nil {
		s.respondError(w, http.StatusNotFound, "File not found on disk")
		return
	}
	defer file.Close()

	// FormatMediaType quotes or RFC 2231-encodes the filename as needed, so
	// names with spaces, ';', quotes or non-ASCII characters stay well-formed.
	disposition := mime.FormatMediaType("attachment", map[string]string{"filename": fileName})
	if disposition == "" {
		disposition = "attachment"
	}
	w.Header().Set("Content-Type", "application/octet-stream")
	w.Header().Set("Content-Disposition", disposition)

	if _, err := io.Copy(w, file); err != nil {
		// Headers (and possibly part of the body) are already sent, so the
		// status can no longer change; just record the failure.
		log.Printf("Failed to send file %q for job %d: %v", fileName, jobID, err)
	}
}

// handleUploadFileFromRunner stores an output file uploaded by a runner for the
// {jobId} job.
//
// It expects a multipart form with a "file" field. The content is saved via
// s.storage.SaveOutput and recorded in job_files as an output file.
func (s *Server) handleUploadFileFromRunner(w http.ResponseWriter, r *http.Request) {
	jobID, err := parseID(r, "jobId")
	if err != nil {
		s.respondError(w, http.StatusBadRequest, err.Error())
		return
	}

	// Up to 100 MB of the form is held in memory; larger parts spill to temp files.
	if err := r.ParseMultipartForm(100 << 20); err != nil {
		s.respondError(w, http.StatusBadRequest, "Failed to parse form")
		return
	}

	file, header, err := r.FormFile("file")
	if err != nil {
		s.respondError(w, http.StatusBadRequest, "No file provided")
		return
	}
	defer file.Close()

	filePath, err := s.storage.SaveOutput(jobID, header.Filename, file)
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to save file: %v", err))
		return
	}

	_, err = s.db.Exec(
		`INSERT INTO job_files (job_id, file_type, file_path, file_name, file_size)
		VALUES (?, ?, ?, ?, ?)`,
		jobID, types.JobFileTypeOutput, filePath, header.Filename, header.Size,
	)
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to record file: %v", err))
		return
	}

	s.respondJSON(w, http.StatusCreated, map[string]interface{}{
		"file_path": filePath,
		"file_name": header.Filename,
	})
}

// generateMP4Video checks whether a job has PNG output frames and logs that it
// is ready for MP4 generation.
//
// It does not encode anything itself. Video generation is performed by
// runners via a dedicated video-generation task.
func (s *Server) generateMP4Video(jobID int64) {
	rows, err := s.db.Query(
		`SELECT file_path FROM job_files
		WHERE job_id = ? AND file_type = ? AND file_name LIKE '%.png'
		ORDER BY file_name`,
		jobID, types.JobFileTypeOutput,
	)
	if err != nil {
		log.Printf("Failed to query PNG files for job %d: %v", jobID, err)
		return
	}
	defer rows.Close()

	var pngFiles []string
	for rows.Next() {
		var filePath string
		if err := rows.Scan(&filePath); err != nil {
			log.Printf("Failed to scan PNG file for job %d: %v", jobID, err)
			continue
		}
		pngFiles = append(pngFiles, filePath)
	}
	if err := rows.Err(); err != nil {
		log.Printf("Failed to read PNG files for job %d: %v", jobID, err)
		return
	}

	if len(pngFiles) == 0 {
		log.Printf("No PNG files found for job %d", jobID)
		return
	}

	log.Printf("Job %d completed with %d PNG frames - ready for MP4 generation", jobID, len(pngFiles))
}

// handleGetJobStatusForRunner returns the {jobId} job record to a runner.
func (s *Server) handleGetJobStatusForRunner(w http.ResponseWriter, r *http.Request) {
	jobID, err := parseID(r, "jobId")
	if err != nil {
		s.respondError(w, http.StatusBadRequest, err.Error())
		return
	}

	var job types.Job
	var startedAt, completedAt sql.NullTime

	err = s.db.QueryRow(
		`SELECT id, user_id, name, status, progress, frame_start, frame_end, output_format,
		allow_parallel_runners, created_at, started_at, completed_at, error_message
		FROM jobs WHERE id = ?`,
		jobID,
	).Scan(
		&job.ID, &job.UserID, &job.Name, &job.Status, &job.Progress,
		&job.FrameStart, &job.FrameEnd, &job.OutputFormat,
		&job.AllowParallelRunners, &job.CreatedAt,
		&startedAt, &completedAt, &job.ErrorMessage,
	)
	if errors.Is(err, sql.ErrNoRows) {
		s.respondError(w, http.StatusNotFound, "Job not found")
		return
	}
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query job: %v", err))
		return
	}

	// Nullable timestamps map to nil pointers when unset.
	if startedAt.Valid {
		job.StartedAt = &startedAt.Time
	}
	if completedAt.Valid {
		job.CompletedAt = &completedAt.Time
	}

	s.respondJSON(w, http.StatusOK, job)
}

// handleGetJobFilesForRunner lists every file recorded for the {jobId} job,
// ordered by file name.
func (s *Server) handleGetJobFilesForRunner(w http.ResponseWriter, r *http.Request) {
	jobID, err := parseID(r, "jobId")
	if err != nil {
		s.respondError(w, http.StatusBadRequest, err.Error())
		return
	}

	rows, err := s.db.Query(
		`SELECT id, job_id, file_type, file_path, file_name, file_size, created_at
		FROM job_files WHERE job_id = ? ORDER BY file_name`,
		jobID,
	)
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query files: %v", err))
		return
	}
	defer rows.Close()

	// Non-nil so an empty result encodes as [] rather than null.
	files := []types.JobFile{}
	for rows.Next() {
		var file types.JobFile
		err := rows.Scan(
			&file.ID, &file.JobID, &file.FileType, &file.FilePath,
			&file.FileName, &file.FileSize, &file.CreatedAt,
		)
		if err != nil {
			s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan file: %v", err))
			return
		}
		files = append(files, file)
	}
	if err := rows.Err(); err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to read files: %v", err))
		return
	}

	s.respondJSON(w, http.StatusOK, files)
}

// WSMessage is the envelope for every runner WebSocket message. Type selects
// how Data is decoded.
type WSMessage struct {
	Type      string          `json:"type"`
	Data      json.RawMessage `json:"data"`
	Timestamp int64           `json:"timestamp"`
}

// WSTaskAssignment is the payload of a "task_assignment" message sent to a runner.
type WSTaskAssignment struct {
	TaskID       int64    `json:"task_id"`
	JobID        int64    `json:"job_id"`
	JobName      string   `json:"job_name"`
	OutputFormat string   `json:"output_format"`
	FrameStart   int      `json:"frame_start"`
	FrameEnd     int      `json:"frame_end"`
	TaskType     string   `json:"task_type"`
	InputFiles   []string `json:"input_files"`
}

// WSLogEntry is the payload of a "log_entry" message sent by a runner.
type WSLogEntry struct {
	TaskID   int64  `json:"task_id"`
	LogLevel string `json:"log_level"`
	Message  string `json:"message"`
	StepName string `json:"step_name,omitempty"`
}

// WSTaskUpdate is the payload of "task_update" and "task_complete" messages
// sent by a runner.
type WSTaskUpdate struct {
	TaskID     int64  `json:"task_id"`
	Status     string `json:"status"`
	OutputPath string `json:"output_path,omitempty"`
	Success    bool   `json:"success"`
	Error      string `json:"error,omitempty"`
}

// handleRunnerWebSocket authenticates a runner, upgrades the request to a
// WebSocket, and serves it until the connection drops.
//
// Authentication uses the runner_id, signature and timestamp query parameters:
//   - the timestamp must be at most 5 minutes old and at most 1 minute in the future;
//   - the signature must equal auth.SignRequest("GET", path, "", secret, timestamp).
//
// While connected:
//   - the runner is registered in s.runnerConns, replacing (and closing) any
//     older connection for the same runner;
//   - it is pinged every 30 seconds;
//   - its heartbeat is refreshed on every pong and "heartbeat" message.
func (s *Server) handleRunnerWebSocket(w http.ResponseWriter, r *http.Request) {
	query := r.URL.Query()
	runnerIDStr := query.Get("runner_id")
	signature := query.Get("signature")
	timestampStr := query.Get("timestamp")
	if runnerIDStr == "" || signature == "" || timestampStr == "" {
		s.respondError(w, http.StatusBadRequest, "runner_id, signature, and timestamp required")
		return
	}

	var runnerID int64
	if _, err := fmt.Sscanf(runnerIDStr, "%d", &runnerID); err != nil {
		s.respondError(w, http.StatusBadRequest, "invalid runner_id")
		return
	}

	runnerSecret, err := s.secrets.GetRunnerSecret(runnerID)
	if err != nil {
		s.respondError(w, http.StatusUnauthorized, "runner not found or not verified")
		return
	}

	var timestamp int64
	if _, err := fmt.Sscanf(timestampStr, "%d", &timestamp); err != nil {
		s.respondError(w, http.StatusBadRequest, "invalid timestamp")
		return
	}
	timestampTime := time.Unix(timestamp, 0)
	if time.Since(timestampTime) > 5*time.Minute {
		s.respondError(w, http.StatusUnauthorized, "timestamp too old")
		return
	}
	// Allow up to 1 minute of clock skew.
	if timestampTime.After(time.Now().Add(1 * time.Minute)) {
		s.respondError(w, http.StatusUnauthorized, "timestamp in future")
		return
	}

	// Same signing scheme as HTTP requests: GET, the request path, an empty body.
	expectedSig := auth.SignRequest("GET", r.URL.Path, "", runnerSecret, timestampTime)
	// Constant-time comparison avoids leaking the expected signature through timing.
	if subtle.ConstantTimeCompare([]byte(signature), []byte(expectedSig)) != 1 {
		s.respondError(w, http.StatusUnauthorized, "invalid signature")
		return
	}

	conn, err := s.wsUpgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Printf("Failed to upgrade WebSocket: %v", err)
		return
	}
	defer conn.Close()

	// Register this connection. A previous one for the same runner is closed,
	// which makes its handler's read loop exit.
	s.runnerConnsMu.Lock()
	if oldConn, exists := s.runnerConns[runnerID]; exists {
		oldConn.Close()
	}
	s.runnerConns[runnerID] = conn
	s.runnerConnsMu.Unlock()

	// Best-effort status update; a failure here must not drop the connection.
	_, _ = s.db.Exec(
		`UPDATE runners SET status = ?, last_heartbeat = ? WHERE id = ?`,
		types.RunnerStatusOnline, time.Now(), runnerID,
	)

	// done stops the ping goroutine when this handler returns.
	done := make(chan struct{})
	defer func() {
		close(done)

		// Only unregister (and mark offline) if the map still holds this
		// connection. After a reconnect it holds the newer one, which must not
		// be removed or have its runner flagged offline by this stale handler.
		s.runnerConnsMu.Lock()
		current, exists := s.runnerConns[runnerID]
		owned := exists && current == conn
		if owned {
			delete(s.runnerConns, runnerID)
		}
		s.runnerConnsMu.Unlock()

		if owned {
			_, _ = s.db.Exec(
				`UPDATE runners SET status = ? WHERE id = ?`,
				types.RunnerStatusOffline, runnerID,
			)
		}
	}()

	// Each pong answering our ping counts as a heartbeat.
	conn.SetPongHandler(func(string) error {
		_, _ = s.db.Exec(
			`UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`,
			time.Now(), types.RunnerStatusOnline, runnerID,
		)
		return nil
	})

	// Ping every 30 seconds until this handler exits or a ping write fails.
	// gorilla/websocket allows WriteControl concurrently with other methods.
	go func() {
		ticker := time.NewTicker(30 * time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-done:
				return
			case <-ticker.C:
				if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil {
					return
				}
			}
		}
	}()

	// NOTE(review): no read deadline is set, so a half-open TCP connection is
	// only detected when a write fails. Consider SetReadDeadline, extended in
	// the pong handler.
	for {
		var msg WSMessage
		if err := conn.ReadJSON(&msg); err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				log.Printf("WebSocket error for runner %d: %v", runnerID, err)
			}
			break
		}

		switch msg.Type {
		case "heartbeat":
			_, _ = s.db.Exec(
				`UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`,
				time.Now(), types.RunnerStatusOnline, runnerID,
			)

		case "log_entry":
			var logEntry WSLogEntry
			if err := json.Unmarshal(msg.Data, &logEntry); err == nil {
				s.handleWebSocketLog(runnerID, logEntry)
			}

		case "task_update":
			var taskUpdate WSTaskUpdate
			if err := json.Unmarshal(msg.Data, &taskUpdate); err == nil {
				s.handleWebSocketTaskUpdate(runnerID, taskUpdate)
			}

		case "task_complete":
			var taskUpdate WSTaskUpdate
			if err := json.Unmarshal(msg.Data, &taskUpdate); err == nil {
				s.handleWebSocketTaskComplete(runnerID, taskUpdate)
			}
		}
	}
}

// handleWebSocketLog stores a runner log entry in task_logs and forwards it to
// any frontend client watching the task.
func (s *Server) handleWebSocketLog(runnerID int64, logEntry WSLogEntry) {
	_, err := s.db.Exec(
		`INSERT INTO task_logs (task_id, runner_id, log_level, message, step_name, created_at)
		VALUES (?, ?, ?, ?, ?, ?)`,
		logEntry.TaskID, runnerID, logEntry.LogLevel, logEntry.Message, logEntry.StepName, time.Now(),
	)
	if err != nil {
		log.Printf("Failed to store log: %v", err)
		return
	}

	s.broadcastLogToFrontend(logEntry.TaskID, logEntry)
}

// handleWebSocketTaskUpdate handles an intermediate task status update.
// It currently only logs the update.
func (s *Server) handleWebSocketTaskUpdate(runnerID int64, taskUpdate WSTaskUpdate) {
	log.Printf("Task %d update from runner %d: %s", taskUpdate.TaskID, runnerID, taskUpdate.Status)
}

// handleWebSocketTaskComplete records a task's final outcome and updates its job.
//
// Reports for tasks not assigned to runnerID are ignored. The task is marked
// completed or failed according to taskUpdate.Success, then the owning job's
// progress and status are recomputed by updateJobProgress.
func (s *Server) handleWebSocketTaskComplete(runnerID int64, taskUpdate WSTaskUpdate) {
	var taskRunnerID sql.NullInt64
	err := s.db.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", taskUpdate.TaskID).Scan(&taskRunnerID)
	if err != nil || !taskRunnerID.Valid || taskRunnerID.Int64 != runnerID {
		log.Printf("Task %d does not belong to runner %d", taskUpdate.TaskID, runnerID)
		return
	}

	status := types.TaskStatusCompleted
	if !taskUpdate.Success {
		status = types.TaskStatusFailed
	}

	now := time.Now()
	_, err = s.db.Exec(
		`UPDATE tasks SET status = ?, output_path = ?, completed_at = ?, error_message = ? WHERE id = ?`,
		status, taskUpdate.OutputPath, now, taskUpdate.Error, taskUpdate.TaskID,
	)
	if err != nil {
		log.Printf("Failed to update task: %v", err)
		return
	}

	var jobID int64
	if err := s.db.QueryRow(`SELECT job_id FROM tasks WHERE id = ?`, taskUpdate.TaskID).Scan(&jobID); err != nil {
		log.Printf("Failed to look up job for task %d: %v", taskUpdate.TaskID, err)
		return
	}

	s.updateJobProgress(jobID, now)
}

// updateJobProgress recomputes a job's progress from its tasks.
//
// Progress is completed tasks / total tasks * 100. The total counts tasks in
// pending, running, completed and failed states; anything else (presumably
// only cancelled tasks) is left out. Failed tasks count toward the total, so
// a job can complete with progress < 100.
//
// When no task is pending or running, the job is marked completed, and MP4
// jobs get a single video-generation task queued. Otherwise the job is marked
// running and started_at is set if it was still NULL.
func (s *Server) updateJobProgress(jobID int64, now time.Time) {
	var totalTasks, completedTasks, pendingOrRunningTasks int
	if err := s.db.QueryRow(
		`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status IN (?, ?, ?, ?)`,
		jobID, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted, types.TaskStatusFailed,
	).Scan(&totalTasks); err != nil {
		log.Printf("Failed to count tasks for job %d: %v", jobID, err)
		return
	}
	if err := s.db.QueryRow(
		`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
		jobID, types.TaskStatusCompleted,
	).Scan(&completedTasks); err != nil {
		log.Printf("Failed to count completed tasks for job %d: %v", jobID, err)
		return
	}
	if err := s.db.QueryRow(
		`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status IN (?, ?)`,
		jobID, types.TaskStatusPending, types.TaskStatusRunning,
	).Scan(&pendingOrRunningTasks); err != nil {
		log.Printf("Failed to count active tasks for job %d: %v", jobID, err)
		return
	}

	// totalTasks == 0 means every task was cancelled (or none exist); avoid
	// dividing by zero.
	progress := 0.0
	if totalTasks > 0 {
		progress = float64(completedTasks) / float64(totalTasks) * 100.0
	}

	if pendingOrRunningTasks == 0 && totalTasks > 0 {
		// Every counted task has reached completed or failed.
		if _, err := s.db.Exec(
			`UPDATE jobs SET status = ?, progress = ?, completed_at = ? WHERE id = ?`,
			string(types.JobStatusCompleted), progress, now, jobID,
		); err != nil {
			log.Printf("Failed to mark job %d completed: %v", jobID, err)
		}

		var outputFormat string
		if err := s.db.QueryRow(`SELECT output_format FROM jobs WHERE id = ?`, jobID).Scan(&outputFormat); err != nil {
			log.Printf("Failed to read output format for job %d: %v", jobID, err)
			return
		}
		if outputFormat == "MP4" {
			s.enqueueVideoGenerationTask(jobID)
		}
		return
	}

	// COALESCE keeps an existing started_at and fills it only when NULL.
	if _, err := s.db.Exec(
		`UPDATE jobs SET status = ?, progress = ?, started_at = COALESCE(started_at, ?) WHERE id = ?`,
		string(types.JobStatusRunning), progress, now, jobID,
	); err != nil {
		log.Printf("Failed to update progress for job %d: %v", jobID, err)
	}
}

// enqueueVideoGenerationTask queues the MP4 assembly task for jobID, unless
// the job already has a video-generation task, then starts task distribution.
func (s *Server) enqueueVideoGenerationTask(jobID int64) {
	const videoTaskTimeout = 86400 // 24 hours for video generation

	// The NOT EXISTS guard makes this a single atomic, idempotent statement.
	// Without it, two frame tasks finishing at the same moment could each
	// insert a video task. Worse, the video task's own completion would see
	// "no pending/running tasks" again and enqueue another one, forever.
	result, err := s.db.Exec(
		`INSERT INTO tasks (job_id, frame_start, frame_end, task_type, status, timeout_seconds, max_retries)
		SELECT ?, ?, ?, ?, ?, ?, ?
		WHERE NOT EXISTS (SELECT 1 FROM tasks WHERE job_id = ? AND task_type = ?)`,
		jobID, 0, 0, types.TaskTypeVideoGeneration, types.TaskStatusPending, videoTaskTimeout, 1,
		jobID, types.TaskTypeVideoGeneration,
	)
	if err != nil {
		log.Printf("Failed to create video generation task for job %d: %v", jobID, err)
		return
	}
	if n, err := result.RowsAffected(); err == nil && n == 0 {
		return // a video-generation task already exists for this job
	}

	// Try to hand the new task to a runner right away.
	go s.distributeTasksToRunners()
}

// broadcastLogToFrontend forwards a stored task log entry to the frontend
// connection registered for "<jobID>:<taskID>", if there is one.
//
// The row is re-read from task_logs (latest row for this task with the same
// message) so the frontend receives the stored ID and timestamp.
func (s *Server) broadcastLogToFrontend(taskID int64, logEntry WSLogEntry) {
	var jobID int64
	if err := s.db.QueryRow("SELECT job_id FROM tasks WHERE id = ?", taskID).Scan(&jobID); err != nil {
		return
	}

	key := fmt.Sprintf("%d:%d", jobID, taskID)
	s.frontendConnsMu.RLock()
	conn, exists := s.frontendConns[key]
	s.frontendConnsMu.RUnlock()
	if !exists || conn == nil {
		return
	}

	var entry types.TaskLog
	var runnerID sql.NullInt64
	err := s.db.QueryRow(
		`SELECT id, task_id, runner_id, log_level, message, step_name, created_at
		FROM task_logs WHERE task_id = ? AND message = ? ORDER BY id DESC LIMIT 1`,
		taskID, logEntry.Message,
	).Scan(&entry.ID, &entry.TaskID, &runnerID, &entry.LogLevel, &entry.Message, &entry.StepName, &entry.CreatedAt)
	if err != nil {
		return
	}
	if runnerID.Valid {
		entry.RunnerID = &runnerID.Int64
	}

	msg := map[string]interface{}{
		"type":      "log",
		"data":      entry,
		"timestamp": time.Now().Unix(),
	}
	// NOTE(review): gorilla/websocket allows only one concurrent writer per
	// connection; confirm that writes to frontend connections are serialized.
	if err := conn.WriteJSON(msg); err != nil {
		log.Printf("Failed to send log for task %d to frontend: %v", taskID, err)
	}
}

// distributeTasksToRunners assigns pending tasks to idle connected runners and
// pushes them over WebSocket.
//
// It considers up to 100 of the oldest pending tasks of non-cancelled jobs. A
// runner is idle when it has no running task. For jobs that do not allow
// parallel runners, a task is held back while another runner still has
// pending/running tasks of the same job.
//
// Assignment is an UPDATE guarded by "runner_id IS NULL AND status = pending",
// so concurrent calls cannot assign the same task twice. If the WebSocket send
// fails, the assignment is rolled back to pending.
func (s *Server) distributeTasksToRunners() {
	type pendingTask struct {
		TaskID               int64
		JobID                int64
		FrameStart           int
		FrameEnd             int
		TaskType             string
		AllowParallelRunners bool
	}

	rows, err := s.db.Query(
		`SELECT t.id, t.job_id, t.frame_start, t.frame_end, t.task_type, j.allow_parallel_runners
		FROM tasks t
		JOIN jobs j ON t.job_id = j.id
		WHERE t.status = ? AND j.status != ?
		ORDER BY t.created_at ASC
		LIMIT 100`,
		types.TaskStatusPending, types.JobStatusCancelled,
	)
	if err != nil {
		log.Printf("Failed to query pending tasks: %v", err)
		return
	}
	defer rows.Close()

	var pendingTasks []pendingTask
	for rows.Next() {
		var t pendingTask
		var allowParallel int
		if err := rows.Scan(&t.TaskID, &t.JobID, &t.FrameStart, &t.FrameEnd, &t.TaskType, &allowParallel); err != nil {
			log.Printf("Failed to scan pending task: %v", err)
			continue
		}
		t.AllowParallelRunners = allowParallel == 1
		pendingTasks = append(pendingTasks, t)
	}
	if err := rows.Err(); err != nil {
		log.Printf("Failed to read pending tasks: %v", err)
		return
	}
	if len(pendingTasks) == 0 {
		return
	}

	// Snapshot connected runner IDs so the lock is not held during DB work.
	s.runnerConnsMu.RLock()
	connectedRunners := make([]int64, 0, len(s.runnerConns))
	for runnerID := range s.runnerConns {
		connectedRunners = append(connectedRunners, runnerID)
	}
	s.runnerConnsMu.RUnlock()
	if len(connectedRunners) == 0 {
		return
	}

	for _, task := range pendingTasks {
		var selectedRunnerID int64
		for _, runnerID := range connectedRunners {
			var runningCount int
			if err := s.db.QueryRow(
				`SELECT COUNT(*) FROM tasks WHERE runner_id = ? AND status = ?`,
				runnerID, types.TaskStatusRunning,
			).Scan(&runningCount); err != nil {
				// Unknown load: don't risk double-booking this runner.
				log.Printf("Failed to check load of runner %d: %v", runnerID, err)
				continue
			}
			if runningCount > 0 {
				continue // Runner is busy
			}

			if !task.AllowParallelRunners {
				// Only one runner may work on a non-parallel job at a time.
				var otherRunnerTasks int
				if err := s.db.QueryRow(
					`SELECT COUNT(*) FROM tasks
					WHERE job_id = ? AND runner_id IS NOT NULL AND runner_id != ? AND status IN (?, ?)`,
					task.JobID, runnerID, types.TaskStatusPending, types.TaskStatusRunning,
				).Scan(&otherRunnerTasks); err != nil {
					log.Printf("Failed to check runners of job %d: %v", task.JobID, err)
					continue
				}
				if otherRunnerTasks > 0 {
					continue // Another runner is working on this job
				}
			}

			selectedRunnerID = runnerID
			break
		}

		if selectedRunnerID == 0 {
			continue // No available runner
		}

		// The WHERE clause makes this assignment atomic with respect to other
		// concurrent distributions: only one can flip the row from pending.
		result, err := s.db.Exec(
			`UPDATE tasks SET runner_id = ?, status = ?, started_at = ?
			WHERE id = ? AND runner_id IS NULL AND status = ?`,
			selectedRunnerID, types.TaskStatusRunning, time.Now(), task.TaskID, types.TaskStatusPending,
		)
		if err != nil {
			log.Printf("Failed to atomically assign task %d: %v", task.TaskID, err)
			continue
		}
		rowsAffected, err := result.RowsAffected()
		if err != nil {
			log.Printf("Failed to get rows affected for task %d: %v", task.TaskID, err)
			continue
		}
		if rowsAffected == 0 {
			continue // Task was already assigned by another goroutine
		}

		if err := s.assignTaskToRunner(selectedRunnerID, task.TaskID); err != nil {
			log.Printf("Failed to send task %d to runner %d: %v", task.TaskID, selectedRunnerID, err)
			// Return the task to the pending pool so a later pass can retry it.
			if _, rbErr := s.db.Exec(
				`UPDATE tasks SET runner_id = NULL, status = ?, started_at = NULL WHERE id = ?`,
				types.TaskStatusPending, task.TaskID,
			); rbErr != nil {
				log.Printf("Failed to roll back assignment of task %d: %v", task.TaskID, rbErr)
			}
		}
	}
}

// assignTaskToRunner sends an already-assigned task to its runner as a
// "task_assignment" WebSocket message.
//
// The payload includes the job's name and output format and every input file
// path recorded for the job. It returns an error if:
//   - the runner is not connected;
//   - the task details or input files cannot be read;
//   - the task is no longer assigned to runnerID;
//   - the message cannot be encoded or written.
func (s *Server) assignTaskToRunner(runnerID int64, taskID int64) error {
	s.runnerConnsMu.RLock()
	conn, exists := s.runnerConns[runnerID]
	s.runnerConnsMu.RUnlock()
	if !exists {
		return fmt.Errorf("runner %d not connected", runnerID)
	}

	task := WSTaskAssignment{TaskID: taskID}
	err := s.db.QueryRow(
		`SELECT t.job_id, t.frame_start, t.frame_end, t.task_type, j.name, j.output_format
		FROM tasks t JOIN jobs j ON t.job_id = j.id WHERE t.id = ?`,
		taskID,
	).Scan(&task.JobID, &task.FrameStart, &task.FrameEnd, &task.TaskType, &task.JobName, &task.OutputFormat)
	if err != nil {
		return err
	}

	// A task sent without its inputs would fail on the runner. Report the error
	// so the caller can put the task back to pending instead.
	rows, err := s.db.Query(
		`SELECT file_path FROM job_files WHERE job_id = ? AND file_type = ?`,
		task.JobID, types.JobFileTypeInput,
	)
	if err != nil {
		return fmt.Errorf("query input files for job %d: %w", task.JobID, err)
	}
	defer rows.Close()
	for rows.Next() {
		var filePath string
		if err := rows.Scan(&filePath); err == nil {
			task.InputFiles = append(task.InputFiles, filePath)
		}
	}
	if err := rows.Err(); err != nil {
		return fmt.Errorf("read input files for job %d: %w", task.JobID, err)
	}

	// The assignment was made by distributeTasksToRunners. Confirm it still
	// holds, since the task may have been reassigned or rolled back meanwhile.
	var assignedRunnerID sql.NullInt64
	if err := s.db.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", taskID).Scan(&assignedRunnerID); err != nil {
		return fmt.Errorf("task not found: %w", err)
	}
	if !assignedRunnerID.Valid || assignedRunnerID.Int64 != runnerID {
		return fmt.Errorf("task %d is not assigned to runner %d", taskID, runnerID)
	}

	data, err := json.Marshal(task)
	if err != nil {
		return fmt.Errorf("encode task %d: %w", taskID, err)
	}
	msg := WSMessage{
		Type:      "task_assignment",
		Data:      data,
		Timestamp: time.Now().Unix(),
	}

	// NOTE(review): gorilla/websocket supports a single concurrent writer per
	// connection, but distributeTasksToRunners can run in several goroutines
	// at once. A per-connection write mutex would be needed to make this safe.
	return conn.WriteJSON(msg)
}