diff --git a/cmd/runner/main.go b/cmd/runner/main.go index 0eb6eb1..1ac2bfd 100644 --- a/cmd/runner/main.go +++ b/cmd/runner/main.go @@ -81,11 +81,18 @@ func main() { } } - // Start heartbeat loop + // Start WebSocket connection with reconnection + go client.ConnectWebSocketWithReconnect() + + // Start heartbeat loop (for WebSocket ping/pong and HTTP fallback) go client.HeartbeatLoop() - // Start task processing loop - client.ProcessTasks() + // ProcessTasks is now handled via WebSocket, but kept for HTTP fallback + // WebSocket will handle task assignment automatically + log.Printf("Runner started, connecting to manager via WebSocket...") + + // Block forever + select {} } func loadSecrets(path string) (*SecretsFile, error) { diff --git a/go.mod b/go.mod index c0756b2..698388e 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/go-chi/chi/v5 v5.2.3 // indirect github.com/go-chi/cors v1.2.2 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/gorilla/websocket v1.5.3 // indirect github.com/mattn/go-sqlite3 v1.14.32 // indirect golang.org/x/oauth2 v0.33.0 // indirect ) diff --git a/go.sum b/go.sum index 95b0f42..0ab934d 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/go-chi/cors v1.2.2 h1:Jmey33TE+b+rB7fT8MUy1u0I4L+NARQlK6LhzKPSyQE= github.com/go-chi/cors v1.2.2/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= golang.org/x/oauth2 v0.33.0 h1:4Q+qn+E5z8gPRJfmRy7C2gGG3T4jIprK6aSYgTXGRpo= diff --git 
a/internal/api/jobs.go b/internal/api/jobs.go index 6e61dc4..308454b 100644 --- a/internal/api/jobs.go +++ b/internal/api/jobs.go @@ -5,10 +5,14 @@ import ( "encoding/json" "fmt" "io" + "log" "net/http" + "strconv" "time" "fuego/pkg/types" + + "github.com/go-chi/chi/v5" ) // handleCreateJob creates a new job @@ -39,10 +43,19 @@ func (s *Server) handleCreateJob(w http.ResponseWriter, r *http.Request) { req.OutputFormat = "PNG" } + // Default allow_parallel_runners to true if not provided + allowParallelRunners := true + if req.AllowParallelRunners != nil { + allowParallelRunners = *req.AllowParallelRunners + } + + // Set job timeout to 24 hours (86400 seconds) + jobTimeout := 86400 + result, err := s.db.Exec( - `INSERT INTO jobs (user_id, name, status, progress, frame_start, frame_end, output_format) - VALUES (?, ?, ?, ?, ?, ?, ?)`, - userID, req.Name, types.JobStatusPending, 0.0, req.FrameStart, req.FrameEnd, req.OutputFormat, + `INSERT INTO jobs (user_id, name, status, progress, frame_start, frame_end, output_format, allow_parallel_runners, timeout_seconds) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, + userID, req.Name, types.JobStatusPending, 0.0, req.FrameStart, req.FrameEnd, req.OutputFormat, allowParallelRunners, jobTimeout, ) if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create job: %v", err)) @@ -51,11 +64,21 @@ func (s *Server) handleCreateJob(w http.ResponseWriter, r *http.Request) { jobID, _ := result.LastInsertId() + // Determine task timeout based on output format + // 5 minutes (300 seconds) for frame tasks, 24 hours (86400 seconds) for FFmpeg video generation + taskTimeout := 300 // Default: 5 minutes for frame rendering + if req.OutputFormat == "MP4" { + // For MP4, we'll create frame tasks with 5 min timeout + // Video generation tasks will be created later with 24h timeout + taskTimeout = 300 + } + // Create tasks for the job (one task per frame for simplicity, could be batched) for frame := req.FrameStart; 
frame <= req.FrameEnd; frame++ { _, err = s.db.Exec( - `INSERT INTO tasks (job_id, frame_start, frame_end, status) VALUES (?, ?, ?, ?)`, - jobID, frame, frame, types.TaskStatusPending, + `INSERT INTO tasks (job_id, frame_start, frame_end, status, timeout_seconds, max_retries) + VALUES (?, ?, ?, ?, ?, ?)`, + jobID, frame, frame, types.TaskStatusPending, taskTimeout, 3, ) if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create tasks: %v", err)) @@ -64,17 +87,22 @@ func (s *Server) handleCreateJob(w http.ResponseWriter, r *http.Request) { } job := types.Job{ - ID: jobID, - UserID: userID, - Name: req.Name, - Status: types.JobStatusPending, - Progress: 0.0, - FrameStart: req.FrameStart, - FrameEnd: req.FrameEnd, - OutputFormat: req.OutputFormat, - CreatedAt: time.Now(), + ID: jobID, + UserID: userID, + Name: req.Name, + Status: types.JobStatusPending, + Progress: 0.0, + FrameStart: req.FrameStart, + FrameEnd: req.FrameEnd, + OutputFormat: req.OutputFormat, + AllowParallelRunners: allowParallelRunners, + TimeoutSeconds: jobTimeout, + CreatedAt: time.Now(), } + // Immediately try to distribute tasks to connected runners + go s.distributeTasksToRunners() + s.respondJSON(w, http.StatusCreated, job) } @@ -88,7 +116,7 @@ func (s *Server) handleListJobs(w http.ResponseWriter, r *http.Request) { rows, err := s.db.Query( `SELECT id, user_id, name, status, progress, frame_start, frame_end, output_format, - created_at, started_at, completed_at, error_message + allow_parallel_runners, timeout_seconds, created_at, started_at, completed_at, error_message FROM jobs WHERE user_id = ? 
ORDER BY created_at DESC`, userID, ) @@ -105,7 +133,7 @@ func (s *Server) handleListJobs(w http.ResponseWriter, r *http.Request) { err := rows.Scan( &job.ID, &job.UserID, &job.Name, &job.Status, &job.Progress, - &job.FrameStart, &job.FrameEnd, &job.OutputFormat, + &job.FrameStart, &job.FrameEnd, &job.OutputFormat, &job.AllowParallelRunners, &job.TimeoutSeconds, &job.CreatedAt, &startedAt, &completedAt, &job.ErrorMessage, ) if err != nil { @@ -145,12 +173,12 @@ func (s *Server) handleGetJob(w http.ResponseWriter, r *http.Request) { err = s.db.QueryRow( `SELECT id, user_id, name, status, progress, frame_start, frame_end, output_format, - created_at, started_at, completed_at, error_message + allow_parallel_runners, timeout_seconds, created_at, started_at, completed_at, error_message FROM jobs WHERE id = ? AND user_id = ?`, jobID, userID, ).Scan( &job.ID, &job.UserID, &job.Name, &job.Status, &job.Progress, - &job.FrameStart, &job.FrameEnd, &job.OutputFormat, + &job.FrameStart, &job.FrameEnd, &job.OutputFormat, &job.AllowParallelRunners, &job.TimeoutSeconds, &job.CreatedAt, &startedAt, &completedAt, &job.ErrorMessage, ) @@ -280,7 +308,7 @@ func (s *Server) handleUploadJobFile(w http.ResponseWriter, r *http.Request) { fileID, _ := result.LastInsertId() s.respondJSON(w, http.StatusCreated, map[string]interface{}{ - "id": fileID, + "id": fileID, "file_name": header.Filename, "file_path": filePath, "file_size": header.Size, @@ -496,3 +524,454 @@ func (s *Server) handleStreamVideo(w http.ResponseWriter, r *http.Request) { } } +// handleGetTaskLogs retrieves logs for a specific task +func (s *Server) handleGetTaskLogs(w http.ResponseWriter, r *http.Request) { + userID, err := getUserID(r) + if err != nil { + s.respondError(w, http.StatusUnauthorized, err.Error()) + return + } + + jobID, err := parseID(r, "id") + if err != nil { + s.respondError(w, http.StatusBadRequest, err.Error()) + return + } + + taskIDStr := chi.URLParam(r, "taskId") + taskID, err := 
strconv.ParseInt(taskIDStr, 10, 64) + if err != nil { + s.respondError(w, http.StatusBadRequest, "Invalid task ID") + return + } + + // Verify job belongs to user + var jobUserID int64 + err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + if err == sql.ErrNoRows { + s.respondError(w, http.StatusNotFound, "Job not found") + return + } + if err != nil { + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify job: %v", err)) + return + } + if jobUserID != userID { + s.respondError(w, http.StatusForbidden, "Access denied") + return + } + + // Verify task belongs to job + var taskJobID int64 + err = s.db.QueryRow("SELECT job_id FROM tasks WHERE id = ?", taskID).Scan(&taskJobID) + if err == sql.ErrNoRows { + s.respondError(w, http.StatusNotFound, "Task not found") + return + } + if err != nil { + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify task: %v", err)) + return + } + if taskJobID != jobID { + s.respondError(w, http.StatusBadRequest, "Task does not belong to this job") + return + } + + // Get query parameters for filtering + stepName := r.URL.Query().Get("step_name") + logLevel := r.URL.Query().Get("log_level") + limitStr := r.URL.Query().Get("limit") + limit := 1000 // default + if limitStr != "" { + if l, err := strconv.Atoi(limitStr); err == nil && l > 0 { + limit = l + } + } + + // Build query + query := `SELECT id, task_id, runner_id, log_level, message, step_name, created_at + FROM task_logs WHERE task_id = ?` + args := []interface{}{taskID} + if stepName != "" { + query += " AND step_name = ?" + args = append(args, stepName) + } + if logLevel != "" { + query += " AND log_level = ?" + args = append(args, logLevel) + } + query += " ORDER BY created_at ASC LIMIT ?" + args = append(args, limit) + + rows, err := s.db.Query(query, args...) 
+ if err != nil { + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query logs: %v", err)) + return + } + defer rows.Close() + + logs := []types.TaskLog{} + for rows.Next() { + var log types.TaskLog + var runnerID sql.NullInt64 + err := rows.Scan( + &log.ID, &log.TaskID, &runnerID, &log.LogLevel, &log.Message, + &log.StepName, &log.CreatedAt, + ) + if err != nil { + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan log: %v", err)) + return + } + if runnerID.Valid { + log.RunnerID = &runnerID.Int64 + } + logs = append(logs, log) + } + + s.respondJSON(w, http.StatusOK, logs) +} + +// handleGetTaskSteps retrieves step timeline for a specific task +func (s *Server) handleGetTaskSteps(w http.ResponseWriter, r *http.Request) { + userID, err := getUserID(r) + if err != nil { + s.respondError(w, http.StatusUnauthorized, err.Error()) + return + } + + jobID, err := parseID(r, "id") + if err != nil { + s.respondError(w, http.StatusBadRequest, err.Error()) + return + } + + taskIDStr := chi.URLParam(r, "taskId") + taskID, err := strconv.ParseInt(taskIDStr, 10, 64) + if err != nil { + s.respondError(w, http.StatusBadRequest, "Invalid task ID") + return + } + + // Verify job belongs to user + var jobUserID int64 + err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + if err == sql.ErrNoRows { + s.respondError(w, http.StatusNotFound, "Job not found") + return + } + if err != nil { + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify job: %v", err)) + return + } + if jobUserID != userID { + s.respondError(w, http.StatusForbidden, "Access denied") + return + } + + // Verify task belongs to job + var taskJobID int64 + err = s.db.QueryRow("SELECT job_id FROM tasks WHERE id = ?", taskID).Scan(&taskJobID) + if err == sql.ErrNoRows { + s.respondError(w, http.StatusNotFound, "Task not found") + return + } + if err != nil { + s.respondError(w, http.StatusInternalServerError, 
fmt.Sprintf("Failed to verify task: %v", err)) + return + } + if taskJobID != jobID { + s.respondError(w, http.StatusBadRequest, "Task does not belong to this job") + return + } + + rows, err := s.db.Query( + `SELECT id, task_id, step_name, status, started_at, completed_at, duration_ms, error_message + FROM task_steps WHERE task_id = ? ORDER BY started_at ASC`, + taskID, + ) + if err != nil { + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query steps: %v", err)) + return + } + defer rows.Close() + + steps := []types.TaskStep{} + for rows.Next() { + var step types.TaskStep + var startedAt, completedAt sql.NullTime + var durationMs sql.NullInt64 + err := rows.Scan( + &step.ID, &step.TaskID, &step.StepName, &step.Status, + &startedAt, &completedAt, &durationMs, &step.ErrorMessage, + ) + if err != nil { + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan step: %v", err)) + return + } + if startedAt.Valid { + step.StartedAt = &startedAt.Time + } + if completedAt.Valid { + step.CompletedAt = &completedAt.Time + } + if durationMs.Valid { + duration := int(durationMs.Int64) + step.DurationMs = &duration + } + steps = append(steps, step) + } + + s.respondJSON(w, http.StatusOK, steps) +} + +// handleRetryTask retries a failed task +func (s *Server) handleRetryTask(w http.ResponseWriter, r *http.Request) { + userID, err := getUserID(r) + if err != nil { + s.respondError(w, http.StatusUnauthorized, err.Error()) + return + } + + jobID, err := parseID(r, "id") + if err != nil { + s.respondError(w, http.StatusBadRequest, err.Error()) + return + } + + taskIDStr := chi.URLParam(r, "taskId") + taskID, err := strconv.ParseInt(taskIDStr, 10, 64) + if err != nil { + s.respondError(w, http.StatusBadRequest, "Invalid task ID") + return + } + + // Verify job belongs to user + var jobUserID int64 + err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + if err == sql.ErrNoRows { + s.respondError(w, 
http.StatusNotFound, "Job not found") + return + } + if err != nil { + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify job: %v", err)) + return + } + if jobUserID != userID { + s.respondError(w, http.StatusForbidden, "Access denied") + return + } + + // Verify task belongs to job and is in a retryable state + var taskJobID int64 + var taskStatus string + var retryCount, maxRetries int + err = s.db.QueryRow( + "SELECT job_id, status, retry_count, max_retries FROM tasks WHERE id = ?", + taskID, + ).Scan(&taskJobID, &taskStatus, &retryCount, &maxRetries) + if err == sql.ErrNoRows { + s.respondError(w, http.StatusNotFound, "Task not found") + return + } + if err != nil { + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify task: %v", err)) + return + } + if taskJobID != jobID { + s.respondError(w, http.StatusBadRequest, "Task does not belong to this job") + return + } + + if taskStatus != string(types.TaskStatusFailed) { + s.respondError(w, http.StatusBadRequest, "Task is not in failed state") + return + } + + if retryCount >= maxRetries { + s.respondError(w, http.StatusBadRequest, "Maximum retries exceeded") + return + } + + // Reset task to pending + _, err = s.db.Exec( + `UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL, + error_message = NULL, started_at = NULL, completed_at = NULL + WHERE id = ?`, + types.TaskStatusPending, taskID, + ) + if err != nil { + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to retry task: %v", err)) + return + } + + s.respondJSON(w, http.StatusOK, map[string]string{"message": "Task queued for retry"}) +} + +// handleStreamTaskLogsWebSocket streams task logs via WebSocket +// Note: This is called after auth middleware, so userID is already verified +func (s *Server) handleStreamTaskLogsWebSocket(w http.ResponseWriter, r *http.Request) { + userID, err := getUserID(r) + if err != nil { + http.Error(w, "Unauthorized", 
http.StatusUnauthorized) + return + } + + jobID, err := parseID(r, "id") + if err != nil { + s.respondError(w, http.StatusBadRequest, err.Error()) + return + } + + taskIDStr := chi.URLParam(r, "taskId") + taskID, err := strconv.ParseInt(taskIDStr, 10, 64) + if err != nil { + s.respondError(w, http.StatusBadRequest, "Invalid task ID") + return + } + + // Verify job belongs to user + var jobUserID int64 + err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + if err == sql.ErrNoRows { + s.respondError(w, http.StatusNotFound, "Job not found") + return + } + if err != nil { + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify job: %v", err)) + return + } + if jobUserID != userID { + s.respondError(w, http.StatusForbidden, "Access denied") + return + } + + // Verify task belongs to job + var taskJobID int64 + err = s.db.QueryRow("SELECT job_id FROM tasks WHERE id = ?", taskID).Scan(&taskJobID) + if err == sql.ErrNoRows { + s.respondError(w, http.StatusNotFound, "Task not found") + return + } + if err != nil { + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify task: %v", err)) + return + } + if taskJobID != jobID { + s.respondError(w, http.StatusBadRequest, "Task does not belong to this job") + return + } + + // Upgrade to WebSocket + conn, err := s.wsUpgrader.Upgrade(w, r, nil) + if err != nil { + log.Printf("Failed to upgrade WebSocket: %v", err) + return + } + defer conn.Close() + + key := fmt.Sprintf("%d:%d", jobID, taskID) + s.frontendConnsMu.Lock() + s.frontendConns[key] = conn + s.frontendConnsMu.Unlock() + + defer func() { + s.frontendConnsMu.Lock() + delete(s.frontendConns, key) + s.frontendConnsMu.Unlock() + }() + + // Send initial connection message + conn.WriteJSON(map[string]interface{}{ + "type": "connected", + "timestamp": time.Now().Unix(), + }) + + // Get last log ID to start streaming from + lastIDStr := r.URL.Query().Get("last_id") + lastID := int64(0) + if 
lastIDStr != "" { + if id, err := strconv.ParseInt(lastIDStr, 10, 64); err == nil { + lastID = id + } + } + + // Send existing logs + rows, err := s.db.Query( + `SELECT id, task_id, runner_id, log_level, message, step_name, created_at + FROM task_logs WHERE task_id = ? AND id > ? ORDER BY created_at ASC LIMIT 100`, + taskID, lastID, + ) + if err == nil { + defer rows.Close() + for rows.Next() { + var log types.TaskLog + var runnerID sql.NullInt64 + err := rows.Scan( + &log.ID, &log.TaskID, &runnerID, &log.LogLevel, &log.Message, + &log.StepName, &log.CreatedAt, + ) + if err != nil { + continue + } + if runnerID.Valid { + log.RunnerID = &runnerID.Int64 + } + if log.ID > lastID { + lastID = log.ID + } + + conn.WriteJSON(map[string]interface{}{ + "type": "log", + "data": log, + "timestamp": time.Now().Unix(), + }) + } + } + + // Poll for new logs and send them + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + ctx := r.Context() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + rows, err := s.db.Query( + `SELECT id, task_id, runner_id, log_level, message, step_name, created_at + FROM task_logs WHERE task_id = ? AND id > ? 
ORDER BY created_at ASC LIMIT 100`, + taskID, lastID, + ) + if err != nil { + continue + } + + for rows.Next() { + var log types.TaskLog + var runnerID sql.NullInt64 + err := rows.Scan( + &log.ID, &log.TaskID, &runnerID, &log.LogLevel, &log.Message, + &log.StepName, &log.CreatedAt, + ) + if err != nil { + rows.Close() + continue + } + if runnerID.Valid { + log.RunnerID = &runnerID.Int64 + } + if log.ID > lastID { + lastID = log.ID + } + + conn.WriteJSON(map[string]interface{}{ + "type": "log", + "data": log, + "timestamp": time.Now().Unix(), + }) + } + rows.Close() + } + } +} diff --git a/internal/api/runners.go b/internal/api/runners.go index 6306c24..d0c4d41 100644 --- a/internal/api/runners.go +++ b/internal/api/runners.go @@ -2,18 +2,20 @@ package api import ( "context" + "crypto/subtle" "database/sql" "encoding/json" "fmt" "io" "log" "net/http" - "strings" "time" - "github.com/go-chi/chi/v5" "fuego/internal/auth" "fuego/pkg/types" + + "github.com/go-chi/chi/v5" + "github.com/gorilla/websocket" ) // handleListRunners lists all runners @@ -154,215 +156,15 @@ func (s *Server) handleRegisterRunner(w http.ResponseWriter, r *http.Request) { s.respondJSON(w, http.StatusCreated, map[string]interface{}{ "id": runnerID, "name": req.Name, - "hostname": req.Hostname, - "ip_address": req.IPAddress, - "status": types.RunnerStatusOnline, - "runner_secret": runnerSecret, + "hostname": req.Hostname, + "ip_address": req.IPAddress, + "status": types.RunnerStatusOnline, + "runner_secret": runnerSecret, "manager_secret": managerSecret, - "verified": true, + "verified": true, }) } -// handleRunnerHeartbeat updates runner heartbeat -func (s *Server) handleRunnerHeartbeat(w http.ResponseWriter, r *http.Request) { - runnerID, ok := r.Context().Value("runner_id").(int64) - if !ok { - s.respondError(w, http.StatusBadRequest, "runner_id not found in context") - return - } - - _, err := s.db.Exec( - `UPDATE runners SET last_heartbeat = ?, status = ? 
WHERE id = ?`, - time.Now(), types.RunnerStatusOnline, runnerID, - ) - if err != nil { - s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to update heartbeat: %v", err)) - return - } - - s.respondJSON(w, http.StatusOK, map[string]string{"message": "Heartbeat updated"}) -} - -// handleGetRunnerTasks gets pending tasks for a runner -func (s *Server) handleGetRunnerTasks(w http.ResponseWriter, r *http.Request) { - runnerID, ok := r.Context().Value("runner_id").(int64) - if !ok { - s.respondError(w, http.StatusBadRequest, "runner_id not found in context") - return - } - - // Get pending tasks - rows, err := s.db.Query( - `SELECT t.id, t.job_id, t.runner_id, t.frame_start, t.frame_end, t.status, t.output_path, - t.created_at, t.started_at, t.completed_at, t.error_message, - j.name as job_name, j.output_format - FROM tasks t - JOIN jobs j ON t.job_id = j.id - WHERE t.status = ? AND j.status != ? - ORDER BY t.created_at ASC - LIMIT 10`, - types.TaskStatusPending, types.JobStatusCancelled, - ) - if err != nil { - s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query tasks: %v", err)) - return - } - defer rows.Close() - - tasks := []map[string]interface{}{} - for rows.Next() { - var task types.Task - var runnerID sql.NullInt64 - var startedAt, completedAt sql.NullTime - var jobName, outputFormat string - - err := rows.Scan( - &task.ID, &task.JobID, &runnerID, &task.FrameStart, &task.FrameEnd, - &task.Status, &task.OutputPath, &task.CreatedAt, - &startedAt, &completedAt, &task.ErrorMessage, - &jobName, &outputFormat, - ) - if err != nil { - s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan task: %v", err)) - return - } - - if runnerID.Valid { - task.RunnerID = &runnerID.Int64 - } - if startedAt.Valid { - task.StartedAt = &startedAt.Time - } - if completedAt.Valid { - task.CompletedAt = &completedAt.Time - } - - // Get input files for the job - var inputFiles []string - fileRows, err := s.db.Query( - 
`SELECT file_path FROM job_files WHERE job_id = ? AND file_type = ?`, - task.JobID, types.JobFileTypeInput, - ) - if err == nil { - for fileRows.Next() { - var filePath string - if err := fileRows.Scan(&filePath); err == nil { - inputFiles = append(inputFiles, filePath) - } - } - fileRows.Close() - } - - tasks = append(tasks, map[string]interface{}{ - "task": task, - "job_name": jobName, - "output_format": outputFormat, - "input_files": inputFiles, - }) - - // Assign task to runner - _, err = s.db.Exec( - `UPDATE tasks SET runner_id = ?, status = ? WHERE id = ?`, - runnerID, types.TaskStatusRunning, task.ID, - ) - if err != nil { - s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to assign task: %v", err)) - return - } - } - - s.respondJSON(w, http.StatusOK, tasks) -} - -// handleCompleteTask marks a task as completed -func (s *Server) handleCompleteTask(w http.ResponseWriter, r *http.Request) { - taskID, err := parseID(r, "id") - if err != nil { - s.respondError(w, http.StatusBadRequest, err.Error()) - return - } - - var req struct { - OutputPath string `json:"output_path"` - Success bool `json:"success"` - Error string `json:"error,omitempty"` - } - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - s.respondError(w, http.StatusBadRequest, "Invalid request body") - return - } - - status := types.TaskStatusCompleted - if !req.Success { - status = types.TaskStatusFailed - } - - now := time.Now() - _, err = s.db.Exec( - `UPDATE tasks SET status = ?, output_path = ?, completed_at = ?, error_message = ? 
WHERE id = ?`, - status, req.OutputPath, now, req.Error, taskID, - ) - if err != nil { - s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to update task: %v", err)) - return - } - - // Update job progress - var jobID int64 - var frameStart, frameEnd int - err = s.db.QueryRow( - `SELECT job_id, frame_start, frame_end FROM tasks WHERE id = ?`, - taskID, - ).Scan(&jobID, &frameStart, &frameEnd) - if err == nil { - // Count completed tasks - var totalTasks, completedTasks int - s.db.QueryRow( - `SELECT COUNT(*) FROM tasks WHERE job_id = ?`, - jobID, - ).Scan(&totalTasks) - s.db.QueryRow( - `SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`, - jobID, types.TaskStatusCompleted, - ).Scan(&completedTasks) - - progress := float64(completedTasks) / float64(totalTasks) * 100.0 - - // Update job status - var jobStatus string - var outputFormat string - s.db.QueryRow(`SELECT output_format FROM jobs WHERE id = ?`, jobID).Scan(&outputFormat) - - if completedTasks == totalTasks { - jobStatus = string(types.JobStatusCompleted) - now := time.Now() - s.db.Exec( - `UPDATE jobs SET status = ?, progress = ?, completed_at = ? WHERE id = ?`, - jobStatus, progress, now, jobID, - ) - - // For MP4 jobs, create a video generation task - if outputFormat == "MP4" { - go s.generateMP4Video(jobID) - } - } else { - jobStatus = string(types.JobStatusRunning) - var startedAt sql.NullTime - s.db.QueryRow(`SELECT started_at FROM jobs WHERE id = ?`, jobID).Scan(&startedAt) - if !startedAt.Valid { - now := time.Now() - s.db.Exec(`UPDATE jobs SET started_at = ? WHERE id = ?`, now, jobID) - } - s.db.Exec( - `UPDATE jobs SET status = ?, progress = ? 
WHERE id = ?`, - jobStatus, progress, jobID, - ) - } - } - - s.respondJSON(w, http.StatusOK, map[string]string{"message": "Task completed"}) -} - // handleUpdateTaskProgress updates task progress func (s *Server) handleUpdateTaskProgress(w http.ResponseWriter, r *http.Request) { _, err := parseID(r, "id") @@ -471,7 +273,7 @@ func (s *Server) generateMP4Video(jobID int64) { // This would be called by a runner or external process // For now, we'll create a special task that runners can pick up // In a production system, you might want to use a job queue or have a dedicated video processor - + // Get all PNG output files for this job rows, err := s.db.Query( `SELECT file_path, file_name FROM job_files @@ -516,12 +318,12 @@ func (s *Server) handleGetJobStatusForRunner(w http.ResponseWriter, r *http.Requ err = s.db.QueryRow( `SELECT id, user_id, name, status, progress, frame_start, frame_end, output_format, - created_at, started_at, completed_at, error_message + allow_parallel_runners, created_at, started_at, completed_at, error_message FROM jobs WHERE id = ?`, jobID, ).Scan( &job.ID, &job.UserID, &job.Name, &job.Status, &job.Progress, - &job.FrameStart, &job.FrameEnd, &job.OutputFormat, + &job.FrameStart, &job.FrameEnd, &job.OutputFormat, &job.AllowParallelRunners, &job.CreatedAt, &startedAt, &completedAt, &job.ErrorMessage, ) @@ -580,3 +382,500 @@ func (s *Server) handleGetJobFilesForRunner(w http.ResponseWriter, r *http.Reque s.respondJSON(w, http.StatusOK, files) } +// WebSocket message types +type WSMessage struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` + Timestamp int64 `json:"timestamp"` +} + +type WSTaskAssignment struct { + TaskID int64 `json:"task_id"` + JobID int64 `json:"job_id"` + JobName string `json:"job_name"` + OutputFormat string `json:"output_format"` + FrameStart int `json:"frame_start"` + FrameEnd int `json:"frame_end"` + InputFiles []string `json:"input_files"` +} + +type WSLogEntry struct { + TaskID int64 
`json:"task_id"` + LogLevel string `json:"log_level"` + Message string `json:"message"` + StepName string `json:"step_name,omitempty"` +} + +type WSTaskUpdate struct { + TaskID int64 `json:"task_id"` + Status string `json:"status"` + OutputPath string `json:"output_path,omitempty"` + Success bool `json:"success"` + Error string `json:"error,omitempty"` +} + +// handleRunnerWebSocket handles WebSocket connections from runners +func (s *Server) handleRunnerWebSocket(w http.ResponseWriter, r *http.Request) { + // Get runner ID and signature from query params + runnerIDStr := r.URL.Query().Get("runner_id") + signature := r.URL.Query().Get("signature") + timestampStr := r.URL.Query().Get("timestamp") + + if runnerIDStr == "" || signature == "" || timestampStr == "" { + s.respondError(w, http.StatusBadRequest, "runner_id, signature, and timestamp required") + return + } + + var runnerID int64 + _, err := fmt.Sscanf(runnerIDStr, "%d", &runnerID) + if err != nil { + s.respondError(w, http.StatusBadRequest, "invalid runner_id") + return + } + + // Get runner secret + runnerSecret, err := s.secrets.GetRunnerSecret(runnerID) + if err != nil { + s.respondError(w, http.StatusUnauthorized, "runner not found or not verified") + return + } + + // Verify signature + var timestamp int64 + _, err = fmt.Sscanf(timestampStr, "%d", ×tamp) + if err != nil { + s.respondError(w, http.StatusBadRequest, "invalid timestamp") + return + } + + // Verify signature manually (similar to HTTP auth) + timestampTime := time.Unix(timestamp, 0) + + // Check timestamp is not too old + if time.Since(timestampTime) > 5*time.Minute { + s.respondError(w, http.StatusUnauthorized, "timestamp too old") + return + } + + // Check timestamp is not in the future (allow 1 minute clock skew) + if timestampTime.After(time.Now().Add(1 * time.Minute)) { + s.respondError(w, http.StatusUnauthorized, "timestamp in future") + return + } + + // Build the message that should be signed + path := r.URL.Path + expectedSig := 
auth.SignRequest("GET", path, "", runnerSecret, timestampTime) + + // Compare signatures (constant-time) + if subtle.ConstantTimeCompare([]byte(signature), []byte(expectedSig)) != 1 { + s.respondError(w, http.StatusUnauthorized, "invalid signature") + return + } + + // Upgrade to WebSocket + conn, err := s.wsUpgrader.Upgrade(w, r, nil) + if err != nil { + log.Printf("Failed to upgrade WebSocket: %v", err) + return + } + defer conn.Close() + + // Register connection + s.runnerConnsMu.Lock() + // Remove old connection if exists + if oldConn, exists := s.runnerConns[runnerID]; exists { + oldConn.Close() + } + s.runnerConns[runnerID] = conn + s.runnerConnsMu.Unlock() + + // Update runner status to online + _, _ = s.db.Exec( + `UPDATE runners SET status = ?, last_heartbeat = ? WHERE id = ?`, + types.RunnerStatusOnline, time.Now(), runnerID, + ) + + // Cleanup on disconnect + defer func() { + s.runnerConnsMu.Lock() + delete(s.runnerConns, runnerID) + s.runnerConnsMu.Unlock() + _, _ = s.db.Exec( + `UPDATE runners SET status = ? WHERE id = ?`, + types.RunnerStatusOffline, runnerID, + ) + }() + + // Set ping handler to update heartbeat + conn.SetPongHandler(func(string) error { + _, _ = s.db.Exec( + `UPDATE runners SET last_heartbeat = ?, status = ? 
WHERE id = ?`, + time.Now(), types.RunnerStatusOnline, runnerID, + ) + return nil + }) + + // Send ping every 30 seconds + go func() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + s.runnerConnsMu.RLock() + conn, exists := s.runnerConns[runnerID] + s.runnerConnsMu.RUnlock() + if !exists { + return + } + if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil { + return + } + } + } + }() + + // Handle incoming messages + for { + var msg WSMessage + err := conn.ReadJSON(&msg) + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + log.Printf("WebSocket error for runner %d: %v", runnerID, err) + } + break + } + + switch msg.Type { + case "heartbeat": + // Update heartbeat + _, _ = s.db.Exec( + `UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`, + time.Now(), types.RunnerStatusOnline, runnerID, + ) + + case "log_entry": + var logEntry WSLogEntry + if err := json.Unmarshal(msg.Data, &logEntry); err == nil { + s.handleWebSocketLog(runnerID, logEntry) + } + + case "task_update": + var taskUpdate WSTaskUpdate + if err := json.Unmarshal(msg.Data, &taskUpdate); err == nil { + s.handleWebSocketTaskUpdate(runnerID, taskUpdate) + } + + case "task_complete": + var taskUpdate WSTaskUpdate + if err := json.Unmarshal(msg.Data, &taskUpdate); err == nil { + s.handleWebSocketTaskComplete(runnerID, taskUpdate) + } + } + } +} + +// handleWebSocketLog handles log entries from WebSocket +func (s *Server) handleWebSocketLog(runnerID int64, logEntry WSLogEntry) { + // Store log in database + _, err := s.db.Exec( + `INSERT INTO task_logs (task_id, runner_id, log_level, message, step_name, created_at) + VALUES (?, ?, ?, ?, ?, ?)`, + logEntry.TaskID, runnerID, logEntry.LogLevel, logEntry.Message, logEntry.StepName, time.Now(), + ) + if err != nil { + log.Printf("Failed to store log: %v", err) + return + } 
+ + // Broadcast to frontend clients + s.broadcastLogToFrontend(logEntry.TaskID, logEntry) +} + +// handleWebSocketTaskUpdate handles task status updates from WebSocket +func (s *Server) handleWebSocketTaskUpdate(runnerID int64, taskUpdate WSTaskUpdate) { + // This can be used for progress updates + // For now, we'll just log it + log.Printf("Task %d update from runner %d: %s", taskUpdate.TaskID, runnerID, taskUpdate.Status) +} + +// handleWebSocketTaskComplete handles task completion from WebSocket +func (s *Server) handleWebSocketTaskComplete(runnerID int64, taskUpdate WSTaskUpdate) { + // Verify task belongs to runner + var taskRunnerID sql.NullInt64 + err := s.db.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", taskUpdate.TaskID).Scan(&taskRunnerID) + if err != nil || !taskRunnerID.Valid || taskRunnerID.Int64 != runnerID { + log.Printf("Task %d does not belong to runner %d", taskUpdate.TaskID, runnerID) + return + } + + status := types.TaskStatusCompleted + if !taskUpdate.Success { + status = types.TaskStatusFailed + } + + now := time.Now() + _, err = s.db.Exec( + `UPDATE tasks SET status = ?, output_path = ?, completed_at = ?, error_message = ? WHERE id = ?`, + status, taskUpdate.OutputPath, now, taskUpdate.Error, taskUpdate.TaskID, + ) + if err != nil { + log.Printf("Failed to update task: %v", err) + return + } + + // Update job progress + var jobID int64 + var frameStart, frameEnd int + err = s.db.QueryRow( + `SELECT job_id, frame_start, frame_end FROM tasks WHERE id = ?`, + taskUpdate.TaskID, + ).Scan(&jobID, &frameStart, &frameEnd) + if err == nil { + var totalTasks, completedTasks int + s.db.QueryRow(`SELECT COUNT(*) FROM tasks WHERE job_id = ?`, jobID).Scan(&totalTasks) + s.db.QueryRow( + `SELECT COUNT(*) FROM tasks WHERE job_id = ? 
AND status = ?`, + jobID, types.TaskStatusCompleted, + ).Scan(&completedTasks) + + progress := float64(completedTasks) / float64(totalTasks) * 100.0 + + var jobStatus string + var outputFormat string + s.db.QueryRow(`SELECT output_format FROM jobs WHERE id = ?`, jobID).Scan(&outputFormat) + + if completedTasks == totalTasks { + jobStatus = string(types.JobStatusCompleted) + s.db.Exec( + `UPDATE jobs SET status = ?, progress = ?, completed_at = ? WHERE id = ?`, + jobStatus, progress, now, jobID, + ) + + if outputFormat == "MP4" { + go s.generateMP4Video(jobID) + } + } else { + jobStatus = string(types.JobStatusRunning) + var startedAt sql.NullTime + s.db.QueryRow(`SELECT started_at FROM jobs WHERE id = ?`, jobID).Scan(&startedAt) + if !startedAt.Valid { + s.db.Exec(`UPDATE jobs SET started_at = ? WHERE id = ?`, now, jobID) + } + s.db.Exec( + `UPDATE jobs SET status = ?, progress = ? WHERE id = ?`, + jobStatus, progress, jobID, + ) + } + } +} + +// broadcastLogToFrontend broadcasts log to connected frontend clients +func (s *Server) broadcastLogToFrontend(taskID int64, logEntry WSLogEntry) { + // Get job_id from task + var jobID int64 + err := s.db.QueryRow("SELECT job_id FROM tasks WHERE id = ?", taskID).Scan(&jobID) + if err != nil { + return + } + + key := fmt.Sprintf("%d:%d", jobID, taskID) + s.frontendConnsMu.RLock() + conn, exists := s.frontendConns[key] + s.frontendConnsMu.RUnlock() + + if exists && conn != nil { + // Get full log entry from database for consistency + var log types.TaskLog + var runnerID sql.NullInt64 + err := s.db.QueryRow( + `SELECT id, task_id, runner_id, log_level, message, step_name, created_at + FROM task_logs WHERE task_id = ? AND message = ? 
ORDER BY id DESC LIMIT 1`, + taskID, logEntry.Message, + ).Scan(&log.ID, &log.TaskID, &runnerID, &log.LogLevel, &log.Message, &log.StepName, &log.CreatedAt) + if err == nil { + if runnerID.Valid { + log.RunnerID = &runnerID.Int64 + } + msg := map[string]interface{}{ + "type": "log", + "data": log, + "timestamp": time.Now().Unix(), + } + conn.WriteJSON(msg) + } + } +} + +// distributeTasksToRunners pushes available tasks to connected runners +func (s *Server) distributeTasksToRunners() { + // Get all pending tasks + rows, err := s.db.Query( + `SELECT t.id, t.job_id, t.frame_start, t.frame_end, j.allow_parallel_runners, j.status as job_status + FROM tasks t + JOIN jobs j ON t.job_id = j.id + WHERE t.status = ? AND j.status != ? + ORDER BY t.created_at ASC + LIMIT 100`, + types.TaskStatusPending, types.JobStatusCancelled, + ) + if err != nil { + log.Printf("Failed to query pending tasks: %v", err) + return + } + defer rows.Close() + + var pendingTasks []struct { + TaskID int64 + JobID int64 + FrameStart int + FrameEnd int + AllowParallelRunners bool + } + + for rows.Next() { + var t struct { + TaskID int64 + JobID int64 + FrameStart int + FrameEnd int + AllowParallelRunners bool + } + var allowParallel int + var jobStatus string + err := rows.Scan(&t.TaskID, &t.JobID, &t.FrameStart, &t.FrameEnd, &allowParallel, &jobStatus) + if err != nil { + continue + } + t.AllowParallelRunners = allowParallel == 1 + pendingTasks = append(pendingTasks, t) + } + + if len(pendingTasks) == 0 { + return + } + + // Get connected runners + s.runnerConnsMu.RLock() + connectedRunners := make([]int64, 0, len(s.runnerConns)) + for runnerID := range s.runnerConns { + connectedRunners = append(connectedRunners, runnerID) + } + s.runnerConnsMu.RUnlock() + + if len(connectedRunners) == 0 { + return + } + + // Distribute tasks to runners + for _, task := range pendingTasks { + // Check if task is already assigned + var assignedRunnerID sql.NullInt64 + err := s.db.QueryRow("SELECT runner_id FROM 
tasks WHERE id = ?", task.TaskID).Scan(&assignedRunnerID) + if err == nil && assignedRunnerID.Valid { + continue // Already assigned + } + + // Find available runner + var selectedRunnerID int64 + for _, runnerID := range connectedRunners { + // Check if runner is busy (has running tasks) + var runningCount int + s.db.QueryRow( + `SELECT COUNT(*) FROM tasks WHERE runner_id = ? AND status = ?`, + runnerID, types.TaskStatusRunning, + ).Scan(&runningCount) + + if runningCount > 0 { + continue // Runner is busy + } + + // For non-parallel jobs, check if runner already has tasks from this job + if !task.AllowParallelRunners { + var jobTaskCount int + s.db.QueryRow( + `SELECT COUNT(*) FROM tasks + WHERE job_id = ? AND runner_id = ? AND status IN (?, ?)`, + task.JobID, runnerID, types.TaskStatusPending, types.TaskStatusRunning, + ).Scan(&jobTaskCount) + if jobTaskCount > 0 { + continue // Another runner is working on this job + } + } + + selectedRunnerID = runnerID + break + } + + if selectedRunnerID == 0 { + continue // No available runner + } + + // Assign task to runner + if err := s.assignTaskToRunner(selectedRunnerID, task.TaskID); err != nil { + log.Printf("Failed to assign task %d to runner %d: %v", task.TaskID, selectedRunnerID, err) + } + } +} + +// assignTaskToRunner sends a task to a runner via WebSocket +func (s *Server) assignTaskToRunner(runnerID int64, taskID int64) error { + s.runnerConnsMu.RLock() + conn, exists := s.runnerConns[runnerID] + s.runnerConnsMu.RUnlock() + + if !exists { + return fmt.Errorf("runner %d not connected", runnerID) + } + + // Get task details + var task WSTaskAssignment + var jobName, outputFormat string + err := s.db.QueryRow( + `SELECT t.job_id, t.frame_start, t.frame_end, j.name, j.output_format + FROM tasks t JOIN jobs j ON t.job_id = j.id WHERE t.id = ?`, + taskID, + ).Scan(&task.JobID, &task.FrameStart, &task.FrameEnd, &jobName, &outputFormat) + if err != nil { + return err + } + + task.TaskID = taskID + task.JobID = 
task.JobID + task.JobName = jobName + task.OutputFormat = outputFormat + + // Get input files + rows, err := s.db.Query( + `SELECT file_path FROM job_files WHERE job_id = ? AND file_type = ?`, + task.JobID, types.JobFileTypeInput, + ) + if err == nil { + defer rows.Close() + for rows.Next() { + var filePath string + if err := rows.Scan(&filePath); err == nil { + task.InputFiles = append(task.InputFiles, filePath) + } + } + } + + // Assign task to runner in database + now := time.Now() + _, err = s.db.Exec( + `UPDATE tasks SET runner_id = ?, status = ?, started_at = ? WHERE id = ?`, + runnerID, types.TaskStatusRunning, now, taskID, + ) + if err != nil { + return err + } + + // Send task via WebSocket + msg := WSMessage{ + Type: "task_assignment", + Timestamp: time.Now().Unix(), + } + msg.Data, _ = json.Marshal(task) + return conn.WriteJSON(msg) +} diff --git a/internal/api/server.go b/internal/api/server.go index 5e66444..6778706 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -1,19 +1,23 @@ package api import ( + "database/sql" "encoding/json" "fmt" "log" "net/http" "strconv" + "sync" "time" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/cors" + "github.com/gorilla/websocket" "fuego/internal/auth" "fuego/internal/database" "fuego/internal/storage" + "fuego/pkg/types" ) // Server represents the API server @@ -23,6 +27,13 @@ type Server struct { secrets *auth.Secrets storage *storage.Storage router *chi.Mux + + // WebSocket connections + wsUpgrader websocket.Upgrader + runnerConns map[int64]*websocket.Conn + runnerConnsMu sync.RWMutex + frontendConns map[string]*websocket.Conn // key: "jobId:taskId" + frontendConnsMu sync.RWMutex } // NewServer creates a new API server @@ -38,10 +49,20 @@ func NewServer(db *database.DB, auth *auth.Auth, storage *storage.Storage) (*Ser secrets: secrets, storage: storage, router: chi.NewRouter(), + wsUpgrader: websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + 
return true // Allow all origins for now + }, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + }, + runnerConns: make(map[int64]*websocket.Conn), + frontendConns: make(map[string]*websocket.Conn), } s.setupMiddleware() s.setupRoutes() + s.StartBackgroundTasks() return s, nil } @@ -87,6 +108,10 @@ func (s *Server) setupRoutes() { r.Get("/{id}/files", s.handleListJobFiles) r.Get("/{id}/files/{fileId}/download", s.handleDownloadJobFile) r.Get("/{id}/video", s.handleStreamVideo) + r.Get("/{id}/tasks/{taskId}/logs", s.handleGetTaskLogs) + r.Get("/{id}/tasks/{taskId}/logs/ws", s.handleStreamTaskLogsWebSocket) + r.Get("/{id}/tasks/{taskId}/steps", s.handleGetTaskSteps) + r.Post("/{id}/tasks/{taskId}/retry", s.handleRetryTask) }) s.router.Route("/api/runners", func(r chi.Router) { @@ -118,14 +143,14 @@ func (s *Server) setupRoutes() { // Registration doesn't require auth (uses token) r.Post("/register", s.handleRegisterRunner) - // All other endpoints require runner authentication + // WebSocket endpoint (auth handled in handler) + r.Get("/ws", s.handleRunnerWebSocket) + + // File operations still use HTTP (WebSocket not suitable for large files) r.Group(func(r chi.Router) { r.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(s.runnerAuthMiddleware(next.ServeHTTP)) }) - r.Post("/heartbeat", s.handleRunnerHeartbeat) - r.Get("/tasks", s.handleGetRunnerTasks) - r.Post("/tasks/{id}/complete", s.handleCompleteTask) r.Post("/tasks/{id}/progress", s.handleUpdateTaskProgress) r.Get("/files/{jobId}/{fileName}", s.handleDownloadFileForRunner) r.Post("/files/{jobId}/upload", s.handleUploadFileFromRunner) @@ -282,3 +307,207 @@ func parseID(r *http.Request, param string) (int64, error) { return id, nil } +// StartBackgroundTasks starts background goroutines for error recovery +func (s *Server) StartBackgroundTasks() { + go s.recoverStuckTasks() +} + +// recoverStuckTasks periodically checks for dead runners and stuck tasks +func (s *Server) recoverStuckTasks() { + 
ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + // Also distribute tasks every 5 seconds + distributeTicker := time.NewTicker(5 * time.Second) + defer distributeTicker.Stop() + + go func() { + for range distributeTicker.C { + s.distributeTasksToRunners() + } + }() + + for range ticker.C { + func() { + defer func() { + if r := recover(); r != nil { + log.Printf("Panic in recoverStuckTasks: %v", r) + } + }() + + // Find dead runners (no heartbeat for 90 seconds) + rows, err := s.db.Query( + `SELECT id FROM runners + WHERE last_heartbeat < datetime('now', '-90 seconds') + AND status = ?`, + types.RunnerStatusOnline, + ) + if err != nil { + log.Printf("Failed to query dead runners: %v", err) + return + } + defer rows.Close() + + var deadRunnerIDs []int64 + for rows.Next() { + var runnerID int64 + if err := rows.Scan(&runnerID); err == nil { + deadRunnerIDs = append(deadRunnerIDs, runnerID) + } + } + rows.Close() + + if len(deadRunnerIDs) == 0 { + // Check for task timeouts + s.recoverTaskTimeouts() + return + } + + // Reset tasks assigned to dead runners + for _, runnerID := range deadRunnerIDs { + // Get tasks assigned to this runner + taskRows, err := s.db.Query( + `SELECT id, retry_count, max_retries FROM tasks + WHERE runner_id = ? 
AND status = ?`, + runnerID, types.TaskStatusRunning, + ) + if err != nil { + log.Printf("Failed to query tasks for runner %d: %v", runnerID, err) + continue + } + + var tasksToReset []struct { + ID int64 + RetryCount int + MaxRetries int + } + + for taskRows.Next() { + var t struct { + ID int64 + RetryCount int + MaxRetries int + } + if err := taskRows.Scan(&t.ID, &t.RetryCount, &t.MaxRetries); err == nil { + tasksToReset = append(tasksToReset, t) + } + } + taskRows.Close() + + // Reset or fail tasks + for _, task := range tasksToReset { + if task.RetryCount >= task.MaxRetries { + // Mark as failed + _, err = s.db.Exec( + `UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL + WHERE id = ?`, + types.TaskStatusFailed, "Runner died, max retries exceeded", task.ID, + ) + if err != nil { + log.Printf("Failed to mark task %d as failed: %v", task.ID, err) + } + } else { + // Reset to pending + _, err = s.db.Exec( + `UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL, + retry_count = retry_count + 1 WHERE id = ?`, + types.TaskStatusPending, task.ID, + ) + if err != nil { + log.Printf("Failed to reset task %d: %v", task.ID, err) + } else { + // Add log entry + _, _ = s.db.Exec( + `INSERT INTO task_logs (task_id, log_level, message, step_name, created_at) + VALUES (?, ?, ?, ?, ?)`, + task.ID, types.LogLevelWarn, fmt.Sprintf("Runner died, task reset (retry %d/%d)", task.RetryCount+1, task.MaxRetries), + "", time.Now(), + ) + } + } + } + + // Mark runner as offline + _, _ = s.db.Exec( + `UPDATE runners SET status = ? 
WHERE id = ?`, + types.RunnerStatusOffline, runnerID, + ) + } + + // Check for task timeouts + s.recoverTaskTimeouts() + + // Distribute newly recovered tasks + s.distributeTasksToRunners() + }() + } +} + +// recoverTaskTimeouts handles tasks that have exceeded their timeout +func (s *Server) recoverTaskTimeouts() { + // Find tasks running longer than their timeout + rows, err := s.db.Query( + `SELECT t.id, t.runner_id, t.retry_count, t.max_retries, t.timeout_seconds, t.started_at + FROM tasks t + WHERE t.status = ? + AND t.started_at IS NOT NULL + AND (t.timeout_seconds IS NULL OR + datetime(t.started_at, '+' || t.timeout_seconds || ' seconds') < datetime('now'))`, + types.TaskStatusRunning, + ) + if err != nil { + log.Printf("Failed to query timed out tasks: %v", err) + return + } + defer rows.Close() + + for rows.Next() { + var taskID int64 + var runnerID sql.NullInt64 + var retryCount, maxRetries int + var timeoutSeconds sql.NullInt64 + var startedAt time.Time + + err := rows.Scan(&taskID, &runnerID, &retryCount, &maxRetries, &timeoutSeconds, &startedAt) + if err != nil { + continue + } + + // Use default timeout if not set (5 minutes for frame tasks, 24 hours for FFmpeg) + timeout := 300 // 5 minutes default + if timeoutSeconds.Valid { + timeout = int(timeoutSeconds.Int64) + } + + // Check if actually timed out + if time.Since(startedAt).Seconds() < float64(timeout) { + continue + } + + if retryCount >= maxRetries { + // Mark as failed + _, err = s.db.Exec( + `UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL + WHERE id = ?`, + types.TaskStatusFailed, "Task timeout exceeded, max retries reached", taskID, + ) + } else { + // Reset to pending + _, err = s.db.Exec( + `UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL, + retry_count = retry_count + 1 WHERE id = ?`, + types.TaskStatusPending, taskID, + ) + if err == nil { + // Add log entry + _, _ = s.db.Exec( + `INSERT INTO task_logs (task_id, log_level, message, step_name, 
created_at) + VALUES (?, ?, ?, ?, ?)`, + taskID, types.LogLevelWarn, fmt.Sprintf("Task timeout exceeded, resetting (retry %d/%d)", retryCount+1, maxRetries), + "", time.Now(), + ) + } + } + } +} + diff --git a/internal/database/schema.go b/internal/database/schema.go index 5f1b0e1..f418936 100644 --- a/internal/database/schema.go +++ b/internal/database/schema.go @@ -119,14 +119,43 @@ func (db *DB) migrate() error { FOREIGN KEY (created_by) REFERENCES users(id) ON DELETE SET NULL ); + CREATE TABLE IF NOT EXISTS task_logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + task_id INTEGER NOT NULL, + runner_id INTEGER, + log_level TEXT NOT NULL, + message TEXT NOT NULL, + step_name TEXT, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (task_id) REFERENCES tasks(id) ON DELETE CASCADE, + FOREIGN KEY (runner_id) REFERENCES runners(id) ON DELETE SET NULL + ); + + CREATE TABLE IF NOT EXISTS task_steps ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + task_id INTEGER NOT NULL, + step_name TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'pending', + started_at DATETIME, + completed_at DATETIME, + duration_ms INTEGER, + error_message TEXT, + FOREIGN KEY (task_id) REFERENCES tasks(id) ON DELETE CASCADE + ); + CREATE INDEX IF NOT EXISTS idx_jobs_user_id ON jobs(user_id); CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status); CREATE INDEX IF NOT EXISTS idx_tasks_job_id ON tasks(job_id); CREATE INDEX IF NOT EXISTS idx_tasks_runner_id ON tasks(runner_id); CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status); + CREATE INDEX IF NOT EXISTS idx_tasks_started_at ON tasks(started_at); CREATE INDEX IF NOT EXISTS idx_job_files_job_id ON job_files(job_id); CREATE INDEX IF NOT EXISTS idx_registration_tokens_token ON registration_tokens(token); CREATE INDEX IF NOT EXISTS idx_registration_tokens_expires_at ON registration_tokens(expires_at); + CREATE INDEX IF NOT EXISTS idx_task_logs_task_id_created_at ON task_logs(task_id, created_at); + CREATE INDEX IF NOT EXISTS 
idx_task_logs_runner_id ON task_logs(runner_id); + CREATE INDEX IF NOT EXISTS idx_task_steps_task_id ON task_steps(task_id); + CREATE INDEX IF NOT EXISTS idx_runners_last_heartbeat ON runners(last_heartbeat); ` if _, err := db.Exec(schema); err != nil { @@ -142,6 +171,15 @@ func (db *DB) migrate() error { `ALTER TABLE runners ADD COLUMN runner_secret TEXT`, `ALTER TABLE runners ADD COLUMN manager_secret TEXT`, `ALTER TABLE runners ADD COLUMN verified BOOLEAN NOT NULL DEFAULT 0`, + // Add allow_parallel_runners to jobs if it doesn't exist + `ALTER TABLE jobs ADD COLUMN allow_parallel_runners BOOLEAN NOT NULL DEFAULT 1`, + // Add timeout_seconds to jobs if it doesn't exist + `ALTER TABLE jobs ADD COLUMN timeout_seconds INTEGER DEFAULT 86400`, + // Add new columns to tasks if they don't exist + `ALTER TABLE tasks ADD COLUMN current_step TEXT`, + `ALTER TABLE tasks ADD COLUMN retry_count INTEGER DEFAULT 0`, + `ALTER TABLE tasks ADD COLUMN max_retries INTEGER DEFAULT 3`, + `ALTER TABLE tasks ADD COLUMN timeout_seconds INTEGER`, } for _, migration := range migrations { diff --git a/internal/runner/client.go b/internal/runner/client.go index 605c4cf..4588ee1 100644 --- a/internal/runner/client.go +++ b/internal/runner/client.go @@ -11,12 +11,17 @@ import ( "log" "mime/multipart" "net/http" + "net/url" "os" "os/exec" "path/filepath" "sort" "strings" + "sync" "time" + + "github.com/gorilla/websocket" + "fuego/pkg/types" ) // Client represents a runner client @@ -29,6 +34,9 @@ type Client struct { runnerID int64 runnerSecret string managerSecret string + wsConn *websocket.Conn + wsConnMu sync.RWMutex + stopChan chan struct{} } // NewClient creates a new runner client @@ -39,6 +47,7 @@ func NewClient(managerURL, name, hostname, ipAddress string) *Client { hostname: hostname, ipAddress: ipAddress, httpClient: &http.Client{Timeout: 30 * time.Second}, + stopChan: make(chan struct{}), } } @@ -121,81 +130,219 @@ func (c *Client) doSignedRequest(method, path string, body []byte) 
(*http.Respon return c.httpClient.Do(req) } -// HeartbeatLoop sends periodic heartbeats to the manager +// ConnectWebSocket establishes a WebSocket connection to the manager +func (c *Client) ConnectWebSocket() error { + if c.runnerID == 0 || c.runnerSecret == "" { + return fmt.Errorf("runner not authenticated") + } + + // Build WebSocket URL with authentication + timestamp := time.Now().Unix() + path := "/api/runner/ws" + // Sign the request + message := fmt.Sprintf("GET\n%s\n\n%d", path, timestamp) + h := hmac.New(sha256.New, []byte(c.runnerSecret)) + h.Write([]byte(message)) + signature := hex.EncodeToString(h.Sum(nil)) + + // Convert HTTP URL to WebSocket URL + wsURL := strings.Replace(c.managerURL, "http://", "ws://", 1) + wsURL = strings.Replace(wsURL, "https://", "wss://", 1) + wsURL = fmt.Sprintf("%s%s?runner_id=%d&signature=%s×tamp=%d", + wsURL, path, c.runnerID, signature, timestamp) + + // Parse URL + u, err := url.Parse(wsURL) + if err != nil { + return fmt.Errorf("invalid WebSocket URL: %w", err) + } + + // Connect + dialer := websocket.Dialer{ + HandshakeTimeout: 10 * time.Second, + } + conn, _, err := dialer.Dial(u.String(), nil) + if err != nil { + return fmt.Errorf("failed to connect WebSocket: %w", err) + } + + c.wsConnMu.Lock() + if c.wsConn != nil { + c.wsConn.Close() + } + c.wsConn = conn + c.wsConnMu.Unlock() + + log.Printf("WebSocket connected to manager") + return nil +} + +// ConnectWebSocketWithReconnect connects with automatic reconnection +func (c *Client) ConnectWebSocketWithReconnect() { + backoff := 1 * time.Second + maxBackoff := 60 * time.Second + + for { + err := c.ConnectWebSocket() + if err == nil { + backoff = 1 * time.Second // Reset on success + c.HandleWebSocketMessages() + } else { + log.Printf("WebSocket connection failed: %v, retrying in %v", err, backoff) + time.Sleep(backoff) + backoff *= 2 + if backoff > maxBackoff { + backoff = maxBackoff + } + } + + // Check if we should stop + select { + case <-c.stopChan: + return + 
default: + } + } +} + +// HandleWebSocketMessages handles incoming WebSocket messages +func (c *Client) HandleWebSocketMessages() { + c.wsConnMu.Lock() + conn := c.wsConn + c.wsConnMu.Unlock() + + if conn == nil { + return + } + + // Set pong handler + conn.SetPongHandler(func(string) error { + return nil + }) + + // Handle messages + for { + var msg map[string]interface{} + err := conn.ReadJSON(&msg) + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + log.Printf("WebSocket error: %v", err) + } + c.wsConnMu.Lock() + c.wsConn = nil + c.wsConnMu.Unlock() + return + } + + msgType, _ := msg["type"].(string) + switch msgType { + case "task_assignment": + c.handleTaskAssignment(msg) + case "ping": + // Respond to ping with pong (automatic) + } + } +} + +// handleTaskAssignment handles a task assignment message +func (c *Client) handleTaskAssignment(msg map[string]interface{}) { + data, ok := msg["data"].(map[string]interface{}) + if !ok { + log.Printf("Invalid task assignment message") + return + } + + taskID, _ := data["task_id"].(float64) + jobID, _ := data["job_id"].(float64) + jobName, _ := data["job_name"].(string) + outputFormat, _ := data["output_format"].(string) + frameStart, _ := data["frame_start"].(float64) + frameEnd, _ := data["frame_end"].(float64) + inputFilesRaw, _ := data["input_files"].([]interface{}) + + if len(inputFilesRaw) == 0 { + log.Printf("No input files for task %v", taskID) + c.sendTaskComplete(int64(taskID), "", false, "No input files") + return + } + + // Convert to task map format + taskMap := map[string]interface{}{ + "id": taskID, + "job_id": jobID, + "frame_start": frameStart, + "frame_end": frameEnd, + } + + // Process the task + go func() { + if err := c.processTask(taskMap, jobName, outputFormat, inputFilesRaw); err != nil { + log.Printf("Failed to process task %v: %v", taskID, err) + c.sendTaskComplete(int64(taskID), "", false, err.Error()) + } + }() +} + +// 
HeartbeatLoop sends periodic heartbeats via WebSocket func (c *Client) HeartbeatLoop() { ticker := time.NewTicker(30 * time.Second) defer ticker.Stop() for range ticker.C { - req := map[string]interface{}{} - body, _ := json.Marshal(req) + c.wsConnMu.RLock() + conn := c.wsConn + c.wsConnMu.RUnlock() - resp, err := c.doSignedRequest("POST", "/api/runner/heartbeat?runner_id="+fmt.Sprintf("%d", c.runnerID), body) - if err != nil { - log.Printf("Heartbeat failed: %v", err) - continue - } - resp.Body.Close() - } -} - -// ProcessTasks polls for tasks and processes them -func (c *Client) ProcessTasks() { - ticker := time.NewTicker(5 * time.Second) - defer ticker.Stop() - - for range ticker.C { - tasks, err := c.getTasks() - if err != nil { - log.Printf("Failed to get tasks: %v", err) - continue - } - - for _, taskData := range tasks { - taskMap, ok := taskData["task"].(map[string]interface{}) - if !ok { - continue + if conn != nil { + // Send heartbeat via WebSocket + msg := map[string]interface{}{ + "type": "heartbeat", + "timestamp": time.Now().Unix(), } - - jobName, _ := taskData["job_name"].(string) - outputFormat, _ := taskData["output_format"].(string) - inputFilesRaw, _ := taskData["input_files"].([]interface{}) - - if len(inputFilesRaw) == 0 { - log.Printf("No input files for task %v", taskMap["id"]) - continue - } - - // Process the task - if err := c.processTask(taskMap, jobName, outputFormat, inputFilesRaw); err != nil { - taskID, _ := taskMap["id"].(float64) - log.Printf("Failed to process task %v: %v", taskID, err) - c.completeTask(int64(taskID), "", false, err.Error()) + if err := conn.WriteJSON(msg); err != nil { + log.Printf("Failed to send heartbeat: %v", err) } } } } -// getTasks fetches tasks from the manager -func (c *Client) getTasks() ([]map[string]interface{}, error) { - path := fmt.Sprintf("/api/runner/tasks?runner_id=%d", c.runnerID) - resp, err := c.doSignedRequest("GET", path, nil) - if err != nil { - return nil, err - } - defer 
resp.Body.Close() +// sendLog sends a log entry to the manager via WebSocket +func (c *Client) sendLog(taskID int64, logLevel types.LogLevel, message, stepName string) { + c.wsConnMu.RLock() + conn := c.wsConn + c.wsConnMu.RUnlock() - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("failed to get tasks: %s", string(body)) + if conn != nil { + msg := map[string]interface{}{ + "type": "log_entry", + "data": map[string]interface{}{ + "task_id": taskID, + "log_level": string(logLevel), + "message": message, + "step_name": stepName, + }, + "timestamp": time.Now().Unix(), + } + if err := conn.WriteJSON(msg); err != nil { + log.Printf("Failed to send log: %v", err) + } + } else { + log.Printf("WebSocket not connected, cannot send log") } +} - var tasks []map[string]interface{} - if err := json.NewDecoder(resp.Body).Decode(&tasks); err != nil { - return nil, err +// sendStepUpdate sends a step start/complete event to the manager +func (c *Client) sendStepUpdate(taskID int64, stepName string, status types.StepStatus, errorMsg string) { + // This would ideally be a separate endpoint, but for now we'll use logs + msg := fmt.Sprintf("Step %s: %s", stepName, status) + if errorMsg != "" { + msg += " - " + errorMsg } - - return tasks, nil + logLevel := types.LogLevelInfo + if status == types.StepStatusFailed { + logLevel = types.LogLevelError + } + c.sendLog(taskID, logLevel, msg, stepName) } // processTask processes a single task @@ -205,6 +352,7 @@ func (c *Client) processTask(task map[string]interface{}, jobName, outputFormat frameStart := int(task["frame_start"].(float64)) frameEnd := int(task["frame_end"].(float64)) + c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Starting task: job %d, frames %d-%d, format: %s", jobID, frameStart, frameEnd, outputFormat), "") log.Printf("Processing task %d: job %d, frames %d-%d, format: %s", taskID, jobID, frameStart, frameEnd, outputFormat) // Create work directory @@ -214,11 +362,14 @@ 
func (c *Client) processTask(task map[string]interface{}, jobName, outputFormat } defer os.RemoveAll(workDir) - // Download input files + // Step: download + c.sendStepUpdate(taskID, "download", types.StepStatusRunning, "") + c.sendLog(taskID, types.LogLevelInfo, "Downloading input files...", "download") blendFile := "" for _, filePath := range inputFiles { filePathStr := filePath.(string) if err := c.downloadFile(filePathStr, workDir); err != nil { + c.sendStepUpdate(taskID, "download", types.StepStatusFailed, err.Error()) return fmt.Errorf("failed to download file %s: %w", filePathStr, err) } if filepath.Ext(filePathStr) == ".blend" { @@ -227,8 +378,12 @@ func (c *Client) processTask(task map[string]interface{}, jobName, outputFormat } if blendFile == "" { - return fmt.Errorf("no .blend file found in input files") + err := fmt.Errorf("no .blend file found in input files") + c.sendStepUpdate(taskID, "download", types.StepStatusFailed, err.Error()) + return err } + c.sendStepUpdate(taskID, "download", types.StepStatusCompleted, "") + c.sendLog(taskID, types.LogLevelInfo, "Input files downloaded successfully", "download") // Render frames outputDir := filepath.Join(workDir, "output") @@ -244,30 +399,60 @@ func (c *Client) processTask(task map[string]interface{}, jobName, outputFormat outputPattern := filepath.Join(outputDir, fmt.Sprintf("frame_%%04d.%s", strings.ToLower(renderFormat))) + // Step: render_blender + c.sendStepUpdate(taskID, "render_blender", types.StepStatusRunning, "") + c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Starting Blender render for frame %d...", frameStart), "render_blender") + // Execute Blender cmd := exec.Command("blender", "-b", blendFile, "-o", outputPattern, "-f", fmt.Sprintf("%d", frameStart)) cmd.Dir = workDir output, err := cmd.CombinedOutput() if err != nil { - return fmt.Errorf("blender failed: %w\nOutput: %s", err, string(output)) + errMsg := fmt.Sprintf("blender failed: %v\nOutput: %s", err, string(output)) + 
c.sendLog(taskID, types.LogLevelError, errMsg, "render_blender") + c.sendStepUpdate(taskID, "render_blender", types.StepStatusFailed, errMsg) + return fmt.Errorf("%s", errMsg) } // Find rendered output file outputFile := filepath.Join(outputDir, fmt.Sprintf("frame_%04d.%s", frameStart, strings.ToLower(renderFormat))) if _, err := os.Stat(outputFile); os.IsNotExist(err) { - return fmt.Errorf("output file not found: %s", outputFile) + errMsg := fmt.Sprintf("output file not found: %s", outputFile) + c.sendLog(taskID, types.LogLevelError, errMsg, "render_blender") + c.sendStepUpdate(taskID, "render_blender", types.StepStatusFailed, errMsg) + return fmt.Errorf("%s", errMsg) } + c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Blender render completed for frame %d", frameStart), "render_blender") + c.sendStepUpdate(taskID, "render_blender", types.StepStatusCompleted, "") - // Upload frame file + // Step: upload or upload_frames + uploadStepName := "upload" + if outputFormat == "MP4" { + uploadStepName = "upload_frames" + } + c.sendStepUpdate(taskID, uploadStepName, types.StepStatusRunning, "") + c.sendLog(taskID, types.LogLevelInfo, "Uploading output file...", uploadStepName) + outputPath, err := c.uploadFile(jobID, outputFile) if err != nil { - return fmt.Errorf("failed to upload output: %w", err) + errMsg := fmt.Sprintf("failed to upload output: %v", err) + c.sendLog(taskID, types.LogLevelError, errMsg, uploadStepName) + c.sendStepUpdate(taskID, uploadStepName, types.StepStatusFailed, errMsg) + return fmt.Errorf("%s", errMsg) } + c.sendLog(taskID, types.LogLevelInfo, "Output file uploaded successfully", uploadStepName) + c.sendStepUpdate(taskID, uploadStepName, types.StepStatusCompleted, "") + // Step: complete + c.sendStepUpdate(taskID, "complete", types.StepStatusRunning, "") + c.sendLog(taskID, types.LogLevelInfo, "Task completed successfully", "complete") + // Mark task as complete if err := c.completeTask(taskID, outputPath, true, ""); err != nil { + c.sendStepUpdate(taskID, 
"complete", types.StepStatusFailed, err.Error()) return err } + c.sendStepUpdate(taskID, "complete", types.StepStatusCompleted, "") // For MP4 format, check if all frames are done and generate video if outputFormat == "MP4" { @@ -599,29 +784,33 @@ func (c *Client) uploadFile(jobID int64, filePath string) (string, error) { return result.FilePath, nil } -// completeTask marks a task as complete +// completeTask marks a task as complete via WebSocket (or HTTP fallback) func (c *Client) completeTask(taskID int64, outputPath string, success bool, errorMsg string) error { - req := map[string]interface{}{ - "output_path": outputPath, - "success": success, - } - if !success { - req["error"] = errorMsg - } - - body, _ := json.Marshal(req) - path := fmt.Sprintf("/api/runner/tasks/%d/complete?runner_id=%d", taskID, c.runnerID) - resp, err := c.doSignedRequest("POST", path, body) - if err != nil { - return err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return fmt.Errorf("failed to complete task: %s", string(body)) - } - - return nil + return c.sendTaskComplete(taskID, outputPath, success, errorMsg) +} + +// sendTaskComplete sends task completion via WebSocket +func (c *Client) sendTaskComplete(taskID int64, outputPath string, success bool, errorMsg string) error { + c.wsConnMu.RLock() + conn := c.wsConn + c.wsConnMu.RUnlock() + + if conn != nil { + msg := map[string]interface{}{ + "type": "task_complete", + "data": map[string]interface{}{ + "task_id": taskID, + "output_path": outputPath, + "success": success, + "error": errorMsg, + }, + "timestamp": time.Now().Unix(), + } + if err := conn.WriteJSON(msg); err != nil { + return fmt.Errorf("failed to send task completion: %w", err) + } + return nil + } + return fmt.Errorf("WebSocket not connected, cannot complete task") } diff --git a/pkg/types/types.go b/pkg/types/types.go index a66eccd..dbe69c0 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -25,18 
+25,20 @@ const ( // Job represents a render job type Job struct { - ID int64 `json:"id"` - UserID int64 `json:"user_id"` - Name string `json:"name"` - Status JobStatus `json:"status"` - Progress float64 `json:"progress"` // 0.0 to 100.0 - FrameStart int `json:"frame_start"` - FrameEnd int `json:"frame_end"` - OutputFormat string `json:"output_format"` // PNG, JPEG, EXR, etc. - CreatedAt time.Time `json:"created_at"` - StartedAt *time.Time `json:"started_at,omitempty"` - CompletedAt *time.Time `json:"completed_at,omitempty"` - ErrorMessage string `json:"error_message,omitempty"` + ID int64 `json:"id"` + UserID int64 `json:"user_id"` + Name string `json:"name"` + Status JobStatus `json:"status"` + Progress float64 `json:"progress"` // 0.0 to 100.0 + FrameStart int `json:"frame_start"` + FrameEnd int `json:"frame_end"` + OutputFormat string `json:"output_format"` // PNG, JPEG, EXR, etc. + AllowParallelRunners bool `json:"allow_parallel_runners"` // Allow multiple runners to work on this job + TimeoutSeconds int `json:"timeout_seconds"` // Job-level timeout (24 hours default) + CreatedAt time.Time `json:"created_at"` + StartedAt *time.Time `json:"started_at,omitempty"` + CompletedAt *time.Time `json:"completed_at,omitempty"` + ErrorMessage string `json:"error_message,omitempty"` } // RunnerStatus represents the status of a runner @@ -72,17 +74,21 @@ const ( // Task represents a render task assigned to a runner type Task struct { - ID int64 `json:"id"` - JobID int64 `json:"job_id"` - RunnerID *int64 `json:"runner_id,omitempty"` - FrameStart int `json:"frame_start"` - FrameEnd int `json:"frame_end"` - Status TaskStatus `json:"status"` - OutputPath string `json:"output_path,omitempty"` - CreatedAt time.Time `json:"created_at"` - StartedAt *time.Time `json:"started_at,omitempty"` - CompletedAt *time.Time `json:"completed_at,omitempty"` - ErrorMessage string `json:"error_message,omitempty"` + ID int64 `json:"id"` + JobID int64 `json:"job_id"` + RunnerID *int64 
`json:"runner_id,omitempty"` + FrameStart int `json:"frame_start"` + FrameEnd int `json:"frame_end"` + Status TaskStatus `json:"status"` + CurrentStep string `json:"current_step,omitempty"` + RetryCount int `json:"retry_count"` + MaxRetries int `json:"max_retries"` + TimeoutSeconds *int `json:"timeout_seconds,omitempty"` // Task timeout (5 min for frames, 24h for FFmpeg) + OutputPath string `json:"output_path,omitempty"` + CreatedAt time.Time `json:"created_at"` + StartedAt *time.Time `json:"started_at,omitempty"` + CompletedAt *time.Time `json:"completed_at,omitempty"` + ErrorMessage string `json:"error_message,omitempty"` } // JobFileType represents the type of file @@ -106,10 +112,11 @@ type JobFile struct { // CreateJobRequest represents a request to create a new job type CreateJobRequest struct { - Name string `json:"name"` - FrameStart int `json:"frame_start"` - FrameEnd int `json:"frame_end"` - OutputFormat string `json:"output_format"` + Name string `json:"name"` + FrameStart int `json:"frame_start"` + FrameEnd int `json:"frame_end"` + OutputFormat string `json:"output_format"` + AllowParallelRunners *bool `json:"allow_parallel_runners,omitempty"` // Optional, defaults to true } // UpdateJobProgressRequest represents a request to update job progress @@ -125,3 +132,70 @@ type RegisterRunnerRequest struct { Capabilities string `json:"capabilities"` } +// LogLevel represents the level of a log entry +type LogLevel string + +const ( + LogLevelInfo LogLevel = "INFO" + LogLevelWarn LogLevel = "WARN" + LogLevelError LogLevel = "ERROR" + LogLevelDebug LogLevel = "DEBUG" +) + +// TaskLog represents a log entry for a task +type TaskLog struct { + ID int64 `json:"id"` + TaskID int64 `json:"task_id"` + RunnerID *int64 `json:"runner_id,omitempty"` + LogLevel LogLevel `json:"log_level"` + Message string `json:"message"` + StepName string `json:"step_name,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +// StepStatus represents the status of a task step +type 
StepStatus string + +const ( + StepStatusPending StepStatus = "pending" + StepStatusRunning StepStatus = "running" + StepStatusCompleted StepStatus = "completed" + StepStatusFailed StepStatus = "failed" + StepStatusSkipped StepStatus = "skipped" +) + +// TaskStep represents an execution step within a task +type TaskStep struct { + ID int64 `json:"id"` + TaskID int64 `json:"task_id"` + StepName string `json:"step_name"` + Status StepStatus `json:"status"` + StartedAt *time.Time `json:"started_at,omitempty"` + CompletedAt *time.Time `json:"completed_at,omitempty"` + DurationMs *int `json:"duration_ms,omitempty"` + ErrorMessage string `json:"error_message,omitempty"` +} + +// TaskAnnotation represents an annotation (warning/error) for a task +type TaskAnnotation struct { + ID int64 `json:"id"` + TaskID int64 `json:"task_id"` + StepName string `json:"step_name,omitempty"` + Level LogLevel `json:"level"` + Message string `json:"message"` + Line *int `json:"line,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +// SendTaskLogRequest represents a request to send task logs +type SendTaskLogRequest struct { + Logs []TaskLogEntry `json:"logs"` +} + +// TaskLogEntry represents a single log entry in a batch request +type TaskLogEntry struct { + LogLevel LogLevel `json:"log_level"` + Message string `json:"message"` + StepName string `json:"step_name,omitempty"` +} + diff --git a/web/src/components/JobDetails.jsx b/web/src/components/JobDetails.jsx index 58d79f6..ac81405 100644 --- a/web/src/components/JobDetails.jsx +++ b/web/src/components/JobDetails.jsx @@ -1,4 +1,4 @@ -import { useState, useEffect } from 'react'; +import { useState, useEffect, useRef } from 'react'; import { jobs } from '../utils/api'; import VideoPlayer from './VideoPlayer'; @@ -7,13 +7,39 @@ export default function JobDetails({ job, onClose, onUpdate }) { const [files, setFiles] = useState([]); const [loading, setLoading] = useState(true); const [videoUrl, setVideoUrl] = useState(null); + const 
[selectedTaskId, setSelectedTaskId] = useState(null); + const [taskLogs, setTaskLogs] = useState([]); + const [taskSteps, setTaskSteps] = useState([]); + const [streaming, setStreaming] = useState(false); + const wsRef = useRef(null); useEffect(() => { loadDetails(); const interval = setInterval(loadDetails, 2000); - return () => clearInterval(interval); + return () => { + clearInterval(interval); + if (wsRef.current) { + wsRef.current.close(); + } + }; }, [job.id]); + useEffect(() => { + if (selectedTaskId && jobDetails.status === 'running') { + startLogStream(); + } else if (wsRef.current) { + wsRef.current.close(); + wsRef.current = null; + setStreaming(false); + } + return () => { + if (wsRef.current) { + wsRef.current.close(); + wsRef.current = null; + } + }; + }, [selectedTaskId, jobDetails.status]); + const loadDetails = async () => { try { const [details, fileList] = await Promise.all([ @@ -41,6 +67,90 @@ export default function JobDetails({ job, onClose, onUpdate }) { window.open(jobs.downloadFile(job.id, fileId), '_blank'); }; + const loadTaskLogs = async (taskId) => { + try { + const [logs, steps] = await Promise.all([ + jobs.getTaskLogs(job.id, taskId), + jobs.getTaskSteps(job.id, taskId), + ]); + setTaskLogs(logs); + setTaskSteps(steps); + } catch (error) { + console.error('Failed to load task logs:', error); + } + }; + + const startLogStream = () => { + if (!selectedTaskId || streaming) return; + + setStreaming(true); + const ws = jobs.streamTaskLogsWebSocket(job.id, selectedTaskId); + wsRef.current = ws; + + ws.onmessage = (event) => { + try { + const data = JSON.parse(event.data); + if (data.type === 'log' && data.data) { + setTaskLogs((prev) => [...prev, data.data]); + } else if (data.type === 'connected') { + // Connection established + } + } catch (error) { + console.error('Failed to parse log message:', error); + } + }; + + ws.onerror = (error) => { + console.error('WebSocket error:', error); + setStreaming(false); + }; + + ws.onclose = () => { 
+ setStreaming(false); + // Auto-reconnect if job is still running + if (jobDetails.status === 'running' && selectedTaskId) { + setTimeout(() => { + if (jobDetails.status === 'running') { + startLogStream(); + } + }, 2000); + } + }; + }; + + const handleTaskClick = async (taskId) => { + setSelectedTaskId(taskId); + await loadTaskLogs(taskId); + }; + + const getLogLevelColor = (level) => { + switch (level) { + case 'ERROR': + return 'text-red-600'; + case 'WARN': + return 'text-yellow-600'; + case 'DEBUG': + return 'text-gray-500'; + default: + return 'text-gray-900'; + } + }; + + const getStepStatusIcon = (status) => { + switch (status) { + case 'completed': + return '✓'; + case 'failed': + return '✗'; + case 'running': + return '⏳'; + case 'skipped': + return '⏸'; + default: + return '○'; + } + }; + const outputFiles = files.filter((f) => f.file_type === 'output'); const inputFiles = files.filter((f) => f.file_type === 'input'); @@ -156,6 +266,75 @@ export default function JobDetails({ job, onClose, onUpdate }) {
{jobDetails.error_message}
)} + +No logs yet...
+ ) : ( + taskLogs.map((log) => ( ++ Select a task to view logs and steps +
+ )} +