redo
This commit is contained in:
@@ -5,10 +5,14 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"fuego/pkg/types"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
// handleCreateJob creates a new job
|
||||
@@ -39,10 +43,19 @@ func (s *Server) handleCreateJob(w http.ResponseWriter, r *http.Request) {
|
||||
req.OutputFormat = "PNG"
|
||||
}
|
||||
|
||||
// Default allow_parallel_runners to true if not provided
|
||||
allowParallelRunners := true
|
||||
if req.AllowParallelRunners != nil {
|
||||
allowParallelRunners = *req.AllowParallelRunners
|
||||
}
|
||||
|
||||
// Set job timeout to 24 hours (86400 seconds)
|
||||
jobTimeout := 86400
|
||||
|
||||
result, err := s.db.Exec(
|
||||
`INSERT INTO jobs (user_id, name, status, progress, frame_start, frame_end, output_format)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||
userID, req.Name, types.JobStatusPending, 0.0, req.FrameStart, req.FrameEnd, req.OutputFormat,
|
||||
`INSERT INTO jobs (user_id, name, status, progress, frame_start, frame_end, output_format, allow_parallel_runners, timeout_seconds)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
userID, req.Name, types.JobStatusPending, 0.0, req.FrameStart, req.FrameEnd, req.OutputFormat, allowParallelRunners, jobTimeout,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create job: %v", err))
|
||||
@@ -51,11 +64,21 @@ func (s *Server) handleCreateJob(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
jobID, _ := result.LastInsertId()
|
||||
|
||||
// Determine task timeout based on output format
|
||||
// 5 minutes (300 seconds) for frame tasks, 24 hours (86400 seconds) for FFmpeg video generation
|
||||
taskTimeout := 300 // Default: 5 minutes for frame rendering
|
||||
if req.OutputFormat == "MP4" {
|
||||
// For MP4, we'll create frame tasks with 5 min timeout
|
||||
// Video generation tasks will be created later with 24h timeout
|
||||
taskTimeout = 300
|
||||
}
|
||||
|
||||
// Create tasks for the job (one task per frame for simplicity, could be batched)
|
||||
for frame := req.FrameStart; frame <= req.FrameEnd; frame++ {
|
||||
_, err = s.db.Exec(
|
||||
`INSERT INTO tasks (job_id, frame_start, frame_end, status) VALUES (?, ?, ?, ?)`,
|
||||
jobID, frame, frame, types.TaskStatusPending,
|
||||
`INSERT INTO tasks (job_id, frame_start, frame_end, status, timeout_seconds, max_retries)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`,
|
||||
jobID, frame, frame, types.TaskStatusPending, taskTimeout, 3,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create tasks: %v", err))
|
||||
@@ -64,17 +87,22 @@ func (s *Server) handleCreateJob(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
job := types.Job{
|
||||
ID: jobID,
|
||||
UserID: userID,
|
||||
Name: req.Name,
|
||||
Status: types.JobStatusPending,
|
||||
Progress: 0.0,
|
||||
FrameStart: req.FrameStart,
|
||||
FrameEnd: req.FrameEnd,
|
||||
OutputFormat: req.OutputFormat,
|
||||
CreatedAt: time.Now(),
|
||||
ID: jobID,
|
||||
UserID: userID,
|
||||
Name: req.Name,
|
||||
Status: types.JobStatusPending,
|
||||
Progress: 0.0,
|
||||
FrameStart: req.FrameStart,
|
||||
FrameEnd: req.FrameEnd,
|
||||
OutputFormat: req.OutputFormat,
|
||||
AllowParallelRunners: allowParallelRunners,
|
||||
TimeoutSeconds: jobTimeout,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Immediately try to distribute tasks to connected runners
|
||||
go s.distributeTasksToRunners()
|
||||
|
||||
s.respondJSON(w, http.StatusCreated, job)
|
||||
}
|
||||
|
||||
@@ -88,7 +116,7 @@ func (s *Server) handleListJobs(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, user_id, name, status, progress, frame_start, frame_end, output_format,
|
||||
created_at, started_at, completed_at, error_message
|
||||
allow_parallel_runners, timeout_seconds, created_at, started_at, completed_at, error_message
|
||||
FROM jobs WHERE user_id = ? ORDER BY created_at DESC`,
|
||||
userID,
|
||||
)
|
||||
@@ -105,7 +133,7 @@ func (s *Server) handleListJobs(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
err := rows.Scan(
|
||||
&job.ID, &job.UserID, &job.Name, &job.Status, &job.Progress,
|
||||
&job.FrameStart, &job.FrameEnd, &job.OutputFormat,
|
||||
&job.FrameStart, &job.FrameEnd, &job.OutputFormat, &job.AllowParallelRunners, &job.TimeoutSeconds,
|
||||
&job.CreatedAt, &startedAt, &completedAt, &job.ErrorMessage,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -145,12 +173,12 @@ func (s *Server) handleGetJob(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
err = s.db.QueryRow(
|
||||
`SELECT id, user_id, name, status, progress, frame_start, frame_end, output_format,
|
||||
created_at, started_at, completed_at, error_message
|
||||
allow_parallel_runners, timeout_seconds, created_at, started_at, completed_at, error_message
|
||||
FROM jobs WHERE id = ? AND user_id = ?`,
|
||||
jobID, userID,
|
||||
).Scan(
|
||||
&job.ID, &job.UserID, &job.Name, &job.Status, &job.Progress,
|
||||
&job.FrameStart, &job.FrameEnd, &job.OutputFormat,
|
||||
&job.FrameStart, &job.FrameEnd, &job.OutputFormat, &job.AllowParallelRunners, &job.TimeoutSeconds,
|
||||
&job.CreatedAt, &startedAt, &completedAt, &job.ErrorMessage,
|
||||
)
|
||||
|
||||
@@ -280,7 +308,7 @@ func (s *Server) handleUploadJobFile(w http.ResponseWriter, r *http.Request) {
|
||||
fileID, _ := result.LastInsertId()
|
||||
|
||||
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
|
||||
"id": fileID,
|
||||
"id": fileID,
|
||||
"file_name": header.Filename,
|
||||
"file_path": filePath,
|
||||
"file_size": header.Size,
|
||||
@@ -496,3 +524,454 @@ func (s *Server) handleStreamVideo(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// handleGetTaskLogs retrieves logs for a specific task
|
||||
func (s *Server) handleGetTaskLogs(w http.ResponseWriter, r *http.Request) {
|
||||
userID, err := getUserID(r)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jobID, err := parseID(r, "id")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
taskIDStr := chi.URLParam(r, "taskId")
|
||||
taskID, err := strconv.ParseInt(taskIDStr, 10, 64)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Invalid task ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify job belongs to user
|
||||
var jobUserID int64
|
||||
err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID)
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "Job not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify job: %v", err))
|
||||
return
|
||||
}
|
||||
if jobUserID != userID {
|
||||
s.respondError(w, http.StatusForbidden, "Access denied")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify task belongs to job
|
||||
var taskJobID int64
|
||||
err = s.db.QueryRow("SELECT job_id FROM tasks WHERE id = ?", taskID).Scan(&taskJobID)
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "Task not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify task: %v", err))
|
||||
return
|
||||
}
|
||||
if taskJobID != jobID {
|
||||
s.respondError(w, http.StatusBadRequest, "Task does not belong to this job")
|
||||
return
|
||||
}
|
||||
|
||||
// Get query parameters for filtering
|
||||
stepName := r.URL.Query().Get("step_name")
|
||||
logLevel := r.URL.Query().Get("log_level")
|
||||
limitStr := r.URL.Query().Get("limit")
|
||||
limit := 1000 // default
|
||||
if limitStr != "" {
|
||||
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
|
||||
limit = l
|
||||
}
|
||||
}
|
||||
|
||||
// Build query
|
||||
query := `SELECT id, task_id, runner_id, log_level, message, step_name, created_at
|
||||
FROM task_logs WHERE task_id = ?`
|
||||
args := []interface{}{taskID}
|
||||
if stepName != "" {
|
||||
query += " AND step_name = ?"
|
||||
args = append(args, stepName)
|
||||
}
|
||||
if logLevel != "" {
|
||||
query += " AND log_level = ?"
|
||||
args = append(args, logLevel)
|
||||
}
|
||||
query += " ORDER BY created_at ASC LIMIT ?"
|
||||
args = append(args, limit)
|
||||
|
||||
rows, err := s.db.Query(query, args...)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query logs: %v", err))
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
logs := []types.TaskLog{}
|
||||
for rows.Next() {
|
||||
var log types.TaskLog
|
||||
var runnerID sql.NullInt64
|
||||
err := rows.Scan(
|
||||
&log.ID, &log.TaskID, &runnerID, &log.LogLevel, &log.Message,
|
||||
&log.StepName, &log.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan log: %v", err))
|
||||
return
|
||||
}
|
||||
if runnerID.Valid {
|
||||
log.RunnerID = &runnerID.Int64
|
||||
}
|
||||
logs = append(logs, log)
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, logs)
|
||||
}
|
||||
|
||||
// handleGetTaskSteps retrieves step timeline for a specific task
|
||||
func (s *Server) handleGetTaskSteps(w http.ResponseWriter, r *http.Request) {
|
||||
userID, err := getUserID(r)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jobID, err := parseID(r, "id")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
taskIDStr := chi.URLParam(r, "taskId")
|
||||
taskID, err := strconv.ParseInt(taskIDStr, 10, 64)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Invalid task ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify job belongs to user
|
||||
var jobUserID int64
|
||||
err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID)
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "Job not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify job: %v", err))
|
||||
return
|
||||
}
|
||||
if jobUserID != userID {
|
||||
s.respondError(w, http.StatusForbidden, "Access denied")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify task belongs to job
|
||||
var taskJobID int64
|
||||
err = s.db.QueryRow("SELECT job_id FROM tasks WHERE id = ?", taskID).Scan(&taskJobID)
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "Task not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify task: %v", err))
|
||||
return
|
||||
}
|
||||
if taskJobID != jobID {
|
||||
s.respondError(w, http.StatusBadRequest, "Task does not belong to this job")
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, task_id, step_name, status, started_at, completed_at, duration_ms, error_message
|
||||
FROM task_steps WHERE task_id = ? ORDER BY started_at ASC`,
|
||||
taskID,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query steps: %v", err))
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
steps := []types.TaskStep{}
|
||||
for rows.Next() {
|
||||
var step types.TaskStep
|
||||
var startedAt, completedAt sql.NullTime
|
||||
var durationMs sql.NullInt64
|
||||
err := rows.Scan(
|
||||
&step.ID, &step.TaskID, &step.StepName, &step.Status,
|
||||
&startedAt, &completedAt, &durationMs, &step.ErrorMessage,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan step: %v", err))
|
||||
return
|
||||
}
|
||||
if startedAt.Valid {
|
||||
step.StartedAt = &startedAt.Time
|
||||
}
|
||||
if completedAt.Valid {
|
||||
step.CompletedAt = &completedAt.Time
|
||||
}
|
||||
if durationMs.Valid {
|
||||
duration := int(durationMs.Int64)
|
||||
step.DurationMs = &duration
|
||||
}
|
||||
steps = append(steps, step)
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, steps)
|
||||
}
|
||||
|
||||
// handleRetryTask retries a failed task
|
||||
func (s *Server) handleRetryTask(w http.ResponseWriter, r *http.Request) {
|
||||
userID, err := getUserID(r)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jobID, err := parseID(r, "id")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
taskIDStr := chi.URLParam(r, "taskId")
|
||||
taskID, err := strconv.ParseInt(taskIDStr, 10, 64)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Invalid task ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify job belongs to user
|
||||
var jobUserID int64
|
||||
err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID)
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "Job not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify job: %v", err))
|
||||
return
|
||||
}
|
||||
if jobUserID != userID {
|
||||
s.respondError(w, http.StatusForbidden, "Access denied")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify task belongs to job and is in a retryable state
|
||||
var taskJobID int64
|
||||
var taskStatus string
|
||||
var retryCount, maxRetries int
|
||||
err = s.db.QueryRow(
|
||||
"SELECT job_id, status, retry_count, max_retries FROM tasks WHERE id = ?",
|
||||
taskID,
|
||||
).Scan(&taskJobID, &taskStatus, &retryCount, &maxRetries)
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "Task not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify task: %v", err))
|
||||
return
|
||||
}
|
||||
if taskJobID != jobID {
|
||||
s.respondError(w, http.StatusBadRequest, "Task does not belong to this job")
|
||||
return
|
||||
}
|
||||
|
||||
if taskStatus != string(types.TaskStatusFailed) {
|
||||
s.respondError(w, http.StatusBadRequest, "Task is not in failed state")
|
||||
return
|
||||
}
|
||||
|
||||
if retryCount >= maxRetries {
|
||||
s.respondError(w, http.StatusBadRequest, "Maximum retries exceeded")
|
||||
return
|
||||
}
|
||||
|
||||
// Reset task to pending
|
||||
_, err = s.db.Exec(
|
||||
`UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL,
|
||||
error_message = NULL, started_at = NULL, completed_at = NULL
|
||||
WHERE id = ?`,
|
||||
types.TaskStatusPending, taskID,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to retry task: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Task queued for retry"})
|
||||
}
|
||||
|
||||
// handleStreamTaskLogsWebSocket streams task logs via WebSocket
|
||||
// Note: This is called after auth middleware, so userID is already verified
|
||||
func (s *Server) handleStreamTaskLogsWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
userID, err := getUserID(r)
|
||||
if err != nil {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
jobID, err := parseID(r, "id")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
taskIDStr := chi.URLParam(r, "taskId")
|
||||
taskID, err := strconv.ParseInt(taskIDStr, 10, 64)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Invalid task ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify job belongs to user
|
||||
var jobUserID int64
|
||||
err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID)
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "Job not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify job: %v", err))
|
||||
return
|
||||
}
|
||||
if jobUserID != userID {
|
||||
s.respondError(w, http.StatusForbidden, "Access denied")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify task belongs to job
|
||||
var taskJobID int64
|
||||
err = s.db.QueryRow("SELECT job_id FROM tasks WHERE id = ?", taskID).Scan(&taskJobID)
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "Task not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify task: %v", err))
|
||||
return
|
||||
}
|
||||
if taskJobID != jobID {
|
||||
s.respondError(w, http.StatusBadRequest, "Task does not belong to this job")
|
||||
return
|
||||
}
|
||||
|
||||
// Upgrade to WebSocket
|
||||
conn, err := s.wsUpgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.Printf("Failed to upgrade WebSocket: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
key := fmt.Sprintf("%d:%d", jobID, taskID)
|
||||
s.frontendConnsMu.Lock()
|
||||
s.frontendConns[key] = conn
|
||||
s.frontendConnsMu.Unlock()
|
||||
|
||||
defer func() {
|
||||
s.frontendConnsMu.Lock()
|
||||
delete(s.frontendConns, key)
|
||||
s.frontendConnsMu.Unlock()
|
||||
}()
|
||||
|
||||
// Send initial connection message
|
||||
conn.WriteJSON(map[string]interface{}{
|
||||
"type": "connected",
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
|
||||
// Get last log ID to start streaming from
|
||||
lastIDStr := r.URL.Query().Get("last_id")
|
||||
lastID := int64(0)
|
||||
if lastIDStr != "" {
|
||||
if id, err := strconv.ParseInt(lastIDStr, 10, 64); err == nil {
|
||||
lastID = id
|
||||
}
|
||||
}
|
||||
|
||||
// Send existing logs
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, task_id, runner_id, log_level, message, step_name, created_at
|
||||
FROM task_logs WHERE task_id = ? AND id > ? ORDER BY created_at ASC LIMIT 100`,
|
||||
taskID, lastID,
|
||||
)
|
||||
if err == nil {
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var log types.TaskLog
|
||||
var runnerID sql.NullInt64
|
||||
err := rows.Scan(
|
||||
&log.ID, &log.TaskID, &runnerID, &log.LogLevel, &log.Message,
|
||||
&log.StepName, &log.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if runnerID.Valid {
|
||||
log.RunnerID = &runnerID.Int64
|
||||
}
|
||||
if log.ID > lastID {
|
||||
lastID = log.ID
|
||||
}
|
||||
|
||||
conn.WriteJSON(map[string]interface{}{
|
||||
"type": "log",
|
||||
"data": log,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Poll for new logs and send them
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
ctx := r.Context()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, task_id, runner_id, log_level, message, step_name, created_at
|
||||
FROM task_logs WHERE task_id = ? AND id > ? ORDER BY created_at ASC LIMIT 100`,
|
||||
taskID, lastID,
|
||||
)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var log types.TaskLog
|
||||
var runnerID sql.NullInt64
|
||||
err := rows.Scan(
|
||||
&log.ID, &log.TaskID, &runnerID, &log.LogLevel, &log.Message,
|
||||
&log.StepName, &log.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
rows.Close()
|
||||
continue
|
||||
}
|
||||
if runnerID.Valid {
|
||||
log.RunnerID = &runnerID.Int64
|
||||
}
|
||||
if log.ID > lastID {
|
||||
lastID = log.ID
|
||||
}
|
||||
|
||||
conn.WriteJSON(map[string]interface{}{
|
||||
"type": "log",
|
||||
"data": log,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
}
|
||||
rows.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,18 +2,20 @@ package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"fuego/internal/auth"
|
||||
"fuego/pkg/types"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// handleListRunners lists all runners
|
||||
@@ -154,215 +156,15 @@ func (s *Server) handleRegisterRunner(w http.ResponseWriter, r *http.Request) {
|
||||
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
|
||||
"id": runnerID,
|
||||
"name": req.Name,
|
||||
"hostname": req.Hostname,
|
||||
"ip_address": req.IPAddress,
|
||||
"status": types.RunnerStatusOnline,
|
||||
"runner_secret": runnerSecret,
|
||||
"hostname": req.Hostname,
|
||||
"ip_address": req.IPAddress,
|
||||
"status": types.RunnerStatusOnline,
|
||||
"runner_secret": runnerSecret,
|
||||
"manager_secret": managerSecret,
|
||||
"verified": true,
|
||||
"verified": true,
|
||||
})
|
||||
}
|
||||
|
||||
// handleRunnerHeartbeat updates runner heartbeat
|
||||
func (s *Server) handleRunnerHeartbeat(w http.ResponseWriter, r *http.Request) {
|
||||
runnerID, ok := r.Context().Value("runner_id").(int64)
|
||||
if !ok {
|
||||
s.respondError(w, http.StatusBadRequest, "runner_id not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
_, err := s.db.Exec(
|
||||
`UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`,
|
||||
time.Now(), types.RunnerStatusOnline, runnerID,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to update heartbeat: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Heartbeat updated"})
|
||||
}
|
||||
|
||||
// handleGetRunnerTasks gets pending tasks for a runner
|
||||
func (s *Server) handleGetRunnerTasks(w http.ResponseWriter, r *http.Request) {
|
||||
runnerID, ok := r.Context().Value("runner_id").(int64)
|
||||
if !ok {
|
||||
s.respondError(w, http.StatusBadRequest, "runner_id not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
// Get pending tasks
|
||||
rows, err := s.db.Query(
|
||||
`SELECT t.id, t.job_id, t.runner_id, t.frame_start, t.frame_end, t.status, t.output_path,
|
||||
t.created_at, t.started_at, t.completed_at, t.error_message,
|
||||
j.name as job_name, j.output_format
|
||||
FROM tasks t
|
||||
JOIN jobs j ON t.job_id = j.id
|
||||
WHERE t.status = ? AND j.status != ?
|
||||
ORDER BY t.created_at ASC
|
||||
LIMIT 10`,
|
||||
types.TaskStatusPending, types.JobStatusCancelled,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query tasks: %v", err))
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
tasks := []map[string]interface{}{}
|
||||
for rows.Next() {
|
||||
var task types.Task
|
||||
var runnerID sql.NullInt64
|
||||
var startedAt, completedAt sql.NullTime
|
||||
var jobName, outputFormat string
|
||||
|
||||
err := rows.Scan(
|
||||
&task.ID, &task.JobID, &runnerID, &task.FrameStart, &task.FrameEnd,
|
||||
&task.Status, &task.OutputPath, &task.CreatedAt,
|
||||
&startedAt, &completedAt, &task.ErrorMessage,
|
||||
&jobName, &outputFormat,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan task: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if runnerID.Valid {
|
||||
task.RunnerID = &runnerID.Int64
|
||||
}
|
||||
if startedAt.Valid {
|
||||
task.StartedAt = &startedAt.Time
|
||||
}
|
||||
if completedAt.Valid {
|
||||
task.CompletedAt = &completedAt.Time
|
||||
}
|
||||
|
||||
// Get input files for the job
|
||||
var inputFiles []string
|
||||
fileRows, err := s.db.Query(
|
||||
`SELECT file_path FROM job_files WHERE job_id = ? AND file_type = ?`,
|
||||
task.JobID, types.JobFileTypeInput,
|
||||
)
|
||||
if err == nil {
|
||||
for fileRows.Next() {
|
||||
var filePath string
|
||||
if err := fileRows.Scan(&filePath); err == nil {
|
||||
inputFiles = append(inputFiles, filePath)
|
||||
}
|
||||
}
|
||||
fileRows.Close()
|
||||
}
|
||||
|
||||
tasks = append(tasks, map[string]interface{}{
|
||||
"task": task,
|
||||
"job_name": jobName,
|
||||
"output_format": outputFormat,
|
||||
"input_files": inputFiles,
|
||||
})
|
||||
|
||||
// Assign task to runner
|
||||
_, err = s.db.Exec(
|
||||
`UPDATE tasks SET runner_id = ?, status = ? WHERE id = ?`,
|
||||
runnerID, types.TaskStatusRunning, task.ID,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to assign task: %v", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, tasks)
|
||||
}
|
||||
|
||||
// handleCompleteTask marks a task as completed
|
||||
func (s *Server) handleCompleteTask(w http.ResponseWriter, r *http.Request) {
|
||||
taskID, err := parseID(r, "id")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
OutputPath string `json:"output_path"`
|
||||
Success bool `json:"success"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
status := types.TaskStatusCompleted
|
||||
if !req.Success {
|
||||
status = types.TaskStatusFailed
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
_, err = s.db.Exec(
|
||||
`UPDATE tasks SET status = ?, output_path = ?, completed_at = ?, error_message = ? WHERE id = ?`,
|
||||
status, req.OutputPath, now, req.Error, taskID,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to update task: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Update job progress
|
||||
var jobID int64
|
||||
var frameStart, frameEnd int
|
||||
err = s.db.QueryRow(
|
||||
`SELECT job_id, frame_start, frame_end FROM tasks WHERE id = ?`,
|
||||
taskID,
|
||||
).Scan(&jobID, &frameStart, &frameEnd)
|
||||
if err == nil {
|
||||
// Count completed tasks
|
||||
var totalTasks, completedTasks int
|
||||
s.db.QueryRow(
|
||||
`SELECT COUNT(*) FROM tasks WHERE job_id = ?`,
|
||||
jobID,
|
||||
).Scan(&totalTasks)
|
||||
s.db.QueryRow(
|
||||
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
|
||||
jobID, types.TaskStatusCompleted,
|
||||
).Scan(&completedTasks)
|
||||
|
||||
progress := float64(completedTasks) / float64(totalTasks) * 100.0
|
||||
|
||||
// Update job status
|
||||
var jobStatus string
|
||||
var outputFormat string
|
||||
s.db.QueryRow(`SELECT output_format FROM jobs WHERE id = ?`, jobID).Scan(&outputFormat)
|
||||
|
||||
if completedTasks == totalTasks {
|
||||
jobStatus = string(types.JobStatusCompleted)
|
||||
now := time.Now()
|
||||
s.db.Exec(
|
||||
`UPDATE jobs SET status = ?, progress = ?, completed_at = ? WHERE id = ?`,
|
||||
jobStatus, progress, now, jobID,
|
||||
)
|
||||
|
||||
// For MP4 jobs, create a video generation task
|
||||
if outputFormat == "MP4" {
|
||||
go s.generateMP4Video(jobID)
|
||||
}
|
||||
} else {
|
||||
jobStatus = string(types.JobStatusRunning)
|
||||
var startedAt sql.NullTime
|
||||
s.db.QueryRow(`SELECT started_at FROM jobs WHERE id = ?`, jobID).Scan(&startedAt)
|
||||
if !startedAt.Valid {
|
||||
now := time.Now()
|
||||
s.db.Exec(`UPDATE jobs SET started_at = ? WHERE id = ?`, now, jobID)
|
||||
}
|
||||
s.db.Exec(
|
||||
`UPDATE jobs SET status = ?, progress = ? WHERE id = ?`,
|
||||
jobStatus, progress, jobID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Task completed"})
|
||||
}
|
||||
|
||||
// handleUpdateTaskProgress updates task progress
|
||||
func (s *Server) handleUpdateTaskProgress(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := parseID(r, "id")
|
||||
@@ -471,7 +273,7 @@ func (s *Server) generateMP4Video(jobID int64) {
|
||||
// This would be called by a runner or external process
|
||||
// For now, we'll create a special task that runners can pick up
|
||||
// In a production system, you might want to use a job queue or have a dedicated video processor
|
||||
|
||||
|
||||
// Get all PNG output files for this job
|
||||
rows, err := s.db.Query(
|
||||
`SELECT file_path, file_name FROM job_files
|
||||
@@ -516,12 +318,12 @@ func (s *Server) handleGetJobStatusForRunner(w http.ResponseWriter, r *http.Requ
|
||||
|
||||
err = s.db.QueryRow(
|
||||
`SELECT id, user_id, name, status, progress, frame_start, frame_end, output_format,
|
||||
created_at, started_at, completed_at, error_message
|
||||
allow_parallel_runners, created_at, started_at, completed_at, error_message
|
||||
FROM jobs WHERE id = ?`,
|
||||
jobID,
|
||||
).Scan(
|
||||
&job.ID, &job.UserID, &job.Name, &job.Status, &job.Progress,
|
||||
&job.FrameStart, &job.FrameEnd, &job.OutputFormat,
|
||||
&job.FrameStart, &job.FrameEnd, &job.OutputFormat, &job.AllowParallelRunners,
|
||||
&job.CreatedAt, &startedAt, &completedAt, &job.ErrorMessage,
|
||||
)
|
||||
|
||||
@@ -580,3 +382,500 @@ func (s *Server) handleGetJobFilesForRunner(w http.ResponseWriter, r *http.Reque
|
||||
s.respondJSON(w, http.StatusOK, files)
|
||||
}
|
||||
|
||||
// WebSocket message types
|
||||
type WSMessage struct {
|
||||
Type string `json:"type"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
type WSTaskAssignment struct {
|
||||
TaskID int64 `json:"task_id"`
|
||||
JobID int64 `json:"job_id"`
|
||||
JobName string `json:"job_name"`
|
||||
OutputFormat string `json:"output_format"`
|
||||
FrameStart int `json:"frame_start"`
|
||||
FrameEnd int `json:"frame_end"`
|
||||
InputFiles []string `json:"input_files"`
|
||||
}
|
||||
|
||||
type WSLogEntry struct {
|
||||
TaskID int64 `json:"task_id"`
|
||||
LogLevel string `json:"log_level"`
|
||||
Message string `json:"message"`
|
||||
StepName string `json:"step_name,omitempty"`
|
||||
}
|
||||
|
||||
type WSTaskUpdate struct {
|
||||
TaskID int64 `json:"task_id"`
|
||||
Status string `json:"status"`
|
||||
OutputPath string `json:"output_path,omitempty"`
|
||||
Success bool `json:"success"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// handleRunnerWebSocket handles WebSocket connections from runners
|
||||
func (s *Server) handleRunnerWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
// Get runner ID and signature from query params
|
||||
runnerIDStr := r.URL.Query().Get("runner_id")
|
||||
signature := r.URL.Query().Get("signature")
|
||||
timestampStr := r.URL.Query().Get("timestamp")
|
||||
|
||||
if runnerIDStr == "" || signature == "" || timestampStr == "" {
|
||||
s.respondError(w, http.StatusBadRequest, "runner_id, signature, and timestamp required")
|
||||
return
|
||||
}
|
||||
|
||||
var runnerID int64
|
||||
_, err := fmt.Sscanf(runnerIDStr, "%d", &runnerID)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "invalid runner_id")
|
||||
return
|
||||
}
|
||||
|
||||
// Get runner secret
|
||||
runnerSecret, err := s.secrets.GetRunnerSecret(runnerID)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusUnauthorized, "runner not found or not verified")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify signature
|
||||
var timestamp int64
|
||||
_, err = fmt.Sscanf(timestampStr, "%d", ×tamp)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "invalid timestamp")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify signature manually (similar to HTTP auth)
|
||||
timestampTime := time.Unix(timestamp, 0)
|
||||
|
||||
// Check timestamp is not too old
|
||||
if time.Since(timestampTime) > 5*time.Minute {
|
||||
s.respondError(w, http.StatusUnauthorized, "timestamp too old")
|
||||
return
|
||||
}
|
||||
|
||||
// Check timestamp is not in the future (allow 1 minute clock skew)
|
||||
if timestampTime.After(time.Now().Add(1 * time.Minute)) {
|
||||
s.respondError(w, http.StatusUnauthorized, "timestamp in future")
|
||||
return
|
||||
}
|
||||
|
||||
// Build the message that should be signed
|
||||
path := r.URL.Path
|
||||
expectedSig := auth.SignRequest("GET", path, "", runnerSecret, timestampTime)
|
||||
|
||||
// Compare signatures (constant-time)
|
||||
if subtle.ConstantTimeCompare([]byte(signature), []byte(expectedSig)) != 1 {
|
||||
s.respondError(w, http.StatusUnauthorized, "invalid signature")
|
||||
return
|
||||
}
|
||||
|
||||
// Upgrade to WebSocket
|
||||
conn, err := s.wsUpgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.Printf("Failed to upgrade WebSocket: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Register connection
|
||||
s.runnerConnsMu.Lock()
|
||||
// Remove old connection if exists
|
||||
if oldConn, exists := s.runnerConns[runnerID]; exists {
|
||||
oldConn.Close()
|
||||
}
|
||||
s.runnerConns[runnerID] = conn
|
||||
s.runnerConnsMu.Unlock()
|
||||
|
||||
// Update runner status to online
|
||||
_, _ = s.db.Exec(
|
||||
`UPDATE runners SET status = ?, last_heartbeat = ? WHERE id = ?`,
|
||||
types.RunnerStatusOnline, time.Now(), runnerID,
|
||||
)
|
||||
|
||||
// Cleanup on disconnect
|
||||
defer func() {
|
||||
s.runnerConnsMu.Lock()
|
||||
delete(s.runnerConns, runnerID)
|
||||
s.runnerConnsMu.Unlock()
|
||||
_, _ = s.db.Exec(
|
||||
`UPDATE runners SET status = ? WHERE id = ?`,
|
||||
types.RunnerStatusOffline, runnerID,
|
||||
)
|
||||
}()
|
||||
|
||||
// Set ping handler to update heartbeat
|
||||
conn.SetPongHandler(func(string) error {
|
||||
_, _ = s.db.Exec(
|
||||
`UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`,
|
||||
time.Now(), types.RunnerStatusOnline, runnerID,
|
||||
)
|
||||
return nil
|
||||
})
|
||||
|
||||
// Send ping every 30 seconds
|
||||
go func() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.runnerConnsMu.RLock()
|
||||
conn, exists := s.runnerConns[runnerID]
|
||||
s.runnerConnsMu.RUnlock()
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Handle incoming messages
|
||||
for {
|
||||
var msg WSMessage
|
||||
err := conn.ReadJSON(&msg)
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
log.Printf("WebSocket error for runner %d: %v", runnerID, err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
switch msg.Type {
|
||||
case "heartbeat":
|
||||
// Update heartbeat
|
||||
_, _ = s.db.Exec(
|
||||
`UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`,
|
||||
time.Now(), types.RunnerStatusOnline, runnerID,
|
||||
)
|
||||
|
||||
case "log_entry":
|
||||
var logEntry WSLogEntry
|
||||
if err := json.Unmarshal(msg.Data, &logEntry); err == nil {
|
||||
s.handleWebSocketLog(runnerID, logEntry)
|
||||
}
|
||||
|
||||
case "task_update":
|
||||
var taskUpdate WSTaskUpdate
|
||||
if err := json.Unmarshal(msg.Data, &taskUpdate); err == nil {
|
||||
s.handleWebSocketTaskUpdate(runnerID, taskUpdate)
|
||||
}
|
||||
|
||||
case "task_complete":
|
||||
var taskUpdate WSTaskUpdate
|
||||
if err := json.Unmarshal(msg.Data, &taskUpdate); err == nil {
|
||||
s.handleWebSocketTaskComplete(runnerID, taskUpdate)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleWebSocketLog handles log entries from WebSocket
|
||||
func (s *Server) handleWebSocketLog(runnerID int64, logEntry WSLogEntry) {
|
||||
// Store log in database
|
||||
_, err := s.db.Exec(
|
||||
`INSERT INTO task_logs (task_id, runner_id, log_level, message, step_name, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`,
|
||||
logEntry.TaskID, runnerID, logEntry.LogLevel, logEntry.Message, logEntry.StepName, time.Now(),
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Failed to store log: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Broadcast to frontend clients
|
||||
s.broadcastLogToFrontend(logEntry.TaskID, logEntry)
|
||||
}
|
||||
|
||||
// handleWebSocketTaskUpdate handles task status updates from WebSocket.
// Intermediate progress updates are currently only logged; no state is
// persisted for them.
func (s *Server) handleWebSocketTaskUpdate(runnerID int64, taskUpdate WSTaskUpdate) {
	// This can be used for progress updates
	// For now, we'll just log it
	log.Printf("Task %d update from runner %d: %s", taskUpdate.TaskID, runnerID, taskUpdate.Status)
}
|
||||
|
||||
// handleWebSocketTaskComplete finalizes a task a runner reports as finished:
// it verifies ownership, records the outcome on the task row, and rolls the
// result up into the parent job's progress and status.
func (s *Server) handleWebSocketTaskComplete(runnerID int64, taskUpdate WSTaskUpdate) {
	// Verify the task is actually assigned to this runner, so one runner
	// cannot complete (or fail) another runner's task.
	var taskRunnerID sql.NullInt64
	err := s.db.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", taskUpdate.TaskID).Scan(&taskRunnerID)
	if err != nil || !taskRunnerID.Valid || taskRunnerID.Int64 != runnerID {
		log.Printf("Task %d does not belong to runner %d", taskUpdate.TaskID, runnerID)
		return
	}

	// Map the runner-reported success flag onto a task status.
	status := types.TaskStatusCompleted
	if !taskUpdate.Success {
		status = types.TaskStatusFailed
	}

	now := time.Now()
	_, err = s.db.Exec(
		`UPDATE tasks SET status = ?, output_path = ?, completed_at = ?, error_message = ? WHERE id = ?`,
		status, taskUpdate.OutputPath, now, taskUpdate.Error, taskUpdate.TaskID,
	)
	if err != nil {
		log.Printf("Failed to update task: %v", err)
		return
	}

	// Update job progress. frameStart/frameEnd are scanned to satisfy the
	// SELECT but are not used below.
	var jobID int64
	var frameStart, frameEnd int
	err = s.db.QueryRow(
		`SELECT job_id, frame_start, frame_end FROM tasks WHERE id = ?`,
		taskUpdate.TaskID,
	).Scan(&jobID, &frameStart, &frameEnd)
	if err == nil {
		// Scan errors on the COUNT queries are ignored; a failure leaves the
		// zero value, which at worst skews progress for this tick.
		var totalTasks, completedTasks int
		s.db.QueryRow(`SELECT COUNT(*) FROM tasks WHERE job_id = ?`, jobID).Scan(&totalTasks)
		s.db.QueryRow(
			`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
			jobID, types.TaskStatusCompleted,
		).Scan(&completedTasks)

		// totalTasks is at least 1 here (the task just updated belongs to
		// this job), so the division is safe.
		progress := float64(completedTasks) / float64(totalTasks) * 100.0

		var jobStatus string
		var outputFormat string
		s.db.QueryRow(`SELECT output_format FROM jobs WHERE id = ?`, jobID).Scan(&outputFormat)

		if completedTasks == totalTasks {
			// All tasks completed: mark the job done. NOTE(review): failed
			// tasks never satisfy this equality, so a job containing a
			// permanently failed task stays "running" forever — confirm that
			// is the intended behavior.
			jobStatus = string(types.JobStatusCompleted)
			s.db.Exec(
				`UPDATE jobs SET status = ?, progress = ?, completed_at = ? WHERE id = ?`,
				jobStatus, progress, now, jobID,
			)

			// MP4 jobs need a follow-up video-encode pass once every frame
			// exists; run it asynchronously.
			if outputFormat == "MP4" {
				go s.generateMP4Video(jobID)
			}
		} else {
			jobStatus = string(types.JobStatusRunning)
			// Stamp started_at the first time any task finishes, if the job
			// never recorded a start time.
			var startedAt sql.NullTime
			s.db.QueryRow(`SELECT started_at FROM jobs WHERE id = ?`, jobID).Scan(&startedAt)
			if !startedAt.Valid {
				s.db.Exec(`UPDATE jobs SET started_at = ? WHERE id = ?`, now, jobID)
			}
			s.db.Exec(
				`UPDATE jobs SET status = ?, progress = ? WHERE id = ?`,
				jobStatus, progress, jobID,
			)
		}
	}
}
|
||||
|
||||
// broadcastLogToFrontend broadcasts log to connected frontend clients
|
||||
func (s *Server) broadcastLogToFrontend(taskID int64, logEntry WSLogEntry) {
|
||||
// Get job_id from task
|
||||
var jobID int64
|
||||
err := s.db.QueryRow("SELECT job_id FROM tasks WHERE id = ?", taskID).Scan(&jobID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%d:%d", jobID, taskID)
|
||||
s.frontendConnsMu.RLock()
|
||||
conn, exists := s.frontendConns[key]
|
||||
s.frontendConnsMu.RUnlock()
|
||||
|
||||
if exists && conn != nil {
|
||||
// Get full log entry from database for consistency
|
||||
var log types.TaskLog
|
||||
var runnerID sql.NullInt64
|
||||
err := s.db.QueryRow(
|
||||
`SELECT id, task_id, runner_id, log_level, message, step_name, created_at
|
||||
FROM task_logs WHERE task_id = ? AND message = ? ORDER BY id DESC LIMIT 1`,
|
||||
taskID, logEntry.Message,
|
||||
).Scan(&log.ID, &log.TaskID, &runnerID, &log.LogLevel, &log.Message, &log.StepName, &log.CreatedAt)
|
||||
if err == nil {
|
||||
if runnerID.Valid {
|
||||
log.RunnerID = &runnerID.Int64
|
||||
}
|
||||
msg := map[string]interface{}{
|
||||
"type": "log",
|
||||
"data": log,
|
||||
"timestamp": time.Now().Unix(),
|
||||
}
|
||||
conn.WriteJSON(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// distributeTasksToRunners pushes available tasks to connected runners.
// It is invoked periodically from the background scheduler; placement is
// best-effort and anything it cannot place waits for the next pass.
func (s *Server) distributeTasksToRunners() {
	// Fetch up to 100 of the oldest pending tasks whose jobs are not
	// cancelled, along with each job's parallelism policy.
	rows, err := s.db.Query(
		`SELECT t.id, t.job_id, t.frame_start, t.frame_end, j.allow_parallel_runners, j.status as job_status
		 FROM tasks t
		 JOIN jobs j ON t.job_id = j.id
		 WHERE t.status = ? AND j.status != ?
		 ORDER BY t.created_at ASC
		 LIMIT 100`,
		types.TaskStatusPending, types.JobStatusCancelled,
	)
	if err != nil {
		log.Printf("Failed to query pending tasks: %v", err)
		return
	}
	defer rows.Close()

	// Materialize the candidates before issuing further queries so the
	// result set is not held open during the availability probes below.
	var pendingTasks []struct {
		TaskID               int64
		JobID                int64
		FrameStart           int
		FrameEnd             int
		AllowParallelRunners bool
	}

	for rows.Next() {
		var t struct {
			TaskID               int64
			JobID                int64
			FrameStart           int
			FrameEnd             int
			AllowParallelRunners bool
		}
		// allow_parallel_runners is stored as an integer flag; job_status is
		// scanned only to satisfy the SELECT column list.
		var allowParallel int
		var jobStatus string
		err := rows.Scan(&t.TaskID, &t.JobID, &t.FrameStart, &t.FrameEnd, &allowParallel, &jobStatus)
		if err != nil {
			continue
		}
		t.AllowParallelRunners = allowParallel == 1
		pendingTasks = append(pendingTasks, t)
	}

	if len(pendingTasks) == 0 {
		return
	}

	// Snapshot the IDs of currently connected runners.
	s.runnerConnsMu.RLock()
	connectedRunners := make([]int64, 0, len(s.runnerConns))
	for runnerID := range s.runnerConns {
		connectedRunners = append(connectedRunners, runnerID)
	}
	s.runnerConnsMu.RUnlock()

	if len(connectedRunners) == 0 {
		return
	}

	// Greedily match each task to the first idle, eligible runner.
	for _, task := range pendingTasks {
		// Skip tasks assigned since the candidate query ran.
		var assignedRunnerID sql.NullInt64
		err := s.db.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", task.TaskID).Scan(&assignedRunnerID)
		if err == nil && assignedRunnerID.Valid {
			continue // Already assigned
		}

		// Find an available runner. NOTE(review): 0 is the "none found"
		// sentinel below, which assumes runner IDs are never 0 — confirm.
		var selectedRunnerID int64
		for _, runnerID := range connectedRunners {
			// Runners handle one task at a time: skip any with a running task.
			var runningCount int
			s.db.QueryRow(
				`SELECT COUNT(*) FROM tasks WHERE runner_id = ? AND status = ?`,
				runnerID, types.TaskStatusRunning,
			).Scan(&runningCount)

			if runningCount > 0 {
				continue // Runner is busy
			}

			// For non-parallel jobs, skip a runner that has pending/running
			// tasks from this job. NOTE(review): the query counts tasks on
			// *this* candidate runner, while the original comment said
			// "another runner is working on this job" — the intent looks
			// inverted; confirm which behavior is wanted.
			if !task.AllowParallelRunners {
				var jobTaskCount int
				s.db.QueryRow(
					`SELECT COUNT(*) FROM tasks
					 WHERE job_id = ? AND runner_id = ? AND status IN (?, ?)`,
					task.JobID, runnerID, types.TaskStatusPending, types.TaskStatusRunning,
				).Scan(&jobTaskCount)
				if jobTaskCount > 0 {
					continue // Another runner is working on this job
				}
			}

			selectedRunnerID = runnerID
			break
		}

		if selectedRunnerID == 0 {
			continue // No available runner
		}

		// Hand the task over; failures are logged and the task remains
		// pending for the next scheduling pass.
		if err := s.assignTaskToRunner(selectedRunnerID, task.TaskID); err != nil {
			log.Printf("Failed to assign task %d to runner %d: %v", task.TaskID, selectedRunnerID, err)
		}
	}
}
|
||||
|
||||
// assignTaskToRunner sends a task to a runner via WebSocket
|
||||
func (s *Server) assignTaskToRunner(runnerID int64, taskID int64) error {
|
||||
s.runnerConnsMu.RLock()
|
||||
conn, exists := s.runnerConns[runnerID]
|
||||
s.runnerConnsMu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return fmt.Errorf("runner %d not connected", runnerID)
|
||||
}
|
||||
|
||||
// Get task details
|
||||
var task WSTaskAssignment
|
||||
var jobName, outputFormat string
|
||||
err := s.db.QueryRow(
|
||||
`SELECT t.job_id, t.frame_start, t.frame_end, j.name, j.output_format
|
||||
FROM tasks t JOIN jobs j ON t.job_id = j.id WHERE t.id = ?`,
|
||||
taskID,
|
||||
).Scan(&task.JobID, &task.FrameStart, &task.FrameEnd, &jobName, &outputFormat)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
task.TaskID = taskID
|
||||
task.JobID = task.JobID
|
||||
task.JobName = jobName
|
||||
task.OutputFormat = outputFormat
|
||||
|
||||
// Get input files
|
||||
rows, err := s.db.Query(
|
||||
`SELECT file_path FROM job_files WHERE job_id = ? AND file_type = ?`,
|
||||
task.JobID, types.JobFileTypeInput,
|
||||
)
|
||||
if err == nil {
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var filePath string
|
||||
if err := rows.Scan(&filePath); err == nil {
|
||||
task.InputFiles = append(task.InputFiles, filePath)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Assign task to runner in database
|
||||
now := time.Now()
|
||||
_, err = s.db.Exec(
|
||||
`UPDATE tasks SET runner_id = ?, status = ?, started_at = ? WHERE id = ?`,
|
||||
runnerID, types.TaskStatusRunning, now, taskID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send task via WebSocket
|
||||
msg := WSMessage{
|
||||
Type: "task_assignment",
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
msg.Data, _ = json.Marshal(task)
|
||||
return conn.WriteJSON(msg)
|
||||
}
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/go-chi/cors"
|
||||
"github.com/gorilla/websocket"
|
||||
"fuego/internal/auth"
|
||||
"fuego/internal/database"
|
||||
"fuego/internal/storage"
|
||||
"fuego/pkg/types"
|
||||
)
|
||||
|
||||
// Server represents the API server
|
||||
@@ -23,6 +27,13 @@ type Server struct {
|
||||
secrets *auth.Secrets
|
||||
storage *storage.Storage
|
||||
router *chi.Mux
|
||||
|
||||
// WebSocket connections
|
||||
wsUpgrader websocket.Upgrader
|
||||
runnerConns map[int64]*websocket.Conn
|
||||
runnerConnsMu sync.RWMutex
|
||||
frontendConns map[string]*websocket.Conn // key: "jobId:taskId"
|
||||
frontendConnsMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewServer creates a new API server
|
||||
@@ -38,10 +49,20 @@ func NewServer(db *database.DB, auth *auth.Auth, storage *storage.Storage) (*Ser
|
||||
secrets: secrets,
|
||||
storage: storage,
|
||||
router: chi.NewRouter(),
|
||||
wsUpgrader: websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true // Allow all origins for now
|
||||
},
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
},
|
||||
runnerConns: make(map[int64]*websocket.Conn),
|
||||
frontendConns: make(map[string]*websocket.Conn),
|
||||
}
|
||||
|
||||
s.setupMiddleware()
|
||||
s.setupRoutes()
|
||||
s.StartBackgroundTasks()
|
||||
|
||||
return s, nil
|
||||
}
|
||||
@@ -87,6 +108,10 @@ func (s *Server) setupRoutes() {
|
||||
r.Get("/{id}/files", s.handleListJobFiles)
|
||||
r.Get("/{id}/files/{fileId}/download", s.handleDownloadJobFile)
|
||||
r.Get("/{id}/video", s.handleStreamVideo)
|
||||
r.Get("/{id}/tasks/{taskId}/logs", s.handleGetTaskLogs)
|
||||
r.Get("/{id}/tasks/{taskId}/logs/ws", s.handleStreamTaskLogsWebSocket)
|
||||
r.Get("/{id}/tasks/{taskId}/steps", s.handleGetTaskSteps)
|
||||
r.Post("/{id}/tasks/{taskId}/retry", s.handleRetryTask)
|
||||
})
|
||||
|
||||
s.router.Route("/api/runners", func(r chi.Router) {
|
||||
@@ -118,14 +143,14 @@ func (s *Server) setupRoutes() {
|
||||
// Registration doesn't require auth (uses token)
|
||||
r.Post("/register", s.handleRegisterRunner)
|
||||
|
||||
// All other endpoints require runner authentication
|
||||
// WebSocket endpoint (auth handled in handler)
|
||||
r.Get("/ws", s.handleRunnerWebSocket)
|
||||
|
||||
// File operations still use HTTP (WebSocket not suitable for large files)
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(s.runnerAuthMiddleware(next.ServeHTTP))
|
||||
})
|
||||
r.Post("/heartbeat", s.handleRunnerHeartbeat)
|
||||
r.Get("/tasks", s.handleGetRunnerTasks)
|
||||
r.Post("/tasks/{id}/complete", s.handleCompleteTask)
|
||||
r.Post("/tasks/{id}/progress", s.handleUpdateTaskProgress)
|
||||
r.Get("/files/{jobId}/{fileName}", s.handleDownloadFileForRunner)
|
||||
r.Post("/files/{jobId}/upload", s.handleUploadFileFromRunner)
|
||||
@@ -282,3 +307,207 @@ func parseID(r *http.Request, param string) (int64, error) {
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// StartBackgroundTasks starts background goroutines for error recovery.
// It launches a single long-lived goroutine (recoverStuckTasks) which in
// turn drives both stuck-task recovery and periodic task distribution.
// There is no shutdown mechanism; the goroutine lives for the process.
func (s *Server) StartBackgroundTasks() {
	go s.recoverStuckTasks()
}
|
||||
|
||||
// recoverStuckTasks periodically checks for dead runners and stuck tasks.
// It runs forever and is intended to be launched exactly once as a
// goroutine from StartBackgroundTasks.
func (s *Server) recoverStuckTasks() {
	ticker := time.NewTicker(10 * time.Second)
	defer ticker.Stop()

	// Also distribute tasks every 5 seconds, independent of recovery.
	distributeTicker := time.NewTicker(5 * time.Second)
	defer distributeTicker.Stop()

	go func() {
		for range distributeTicker.C {
			s.distributeTasksToRunners()
		}
	}()

	for range ticker.C {
		// Each tick runs inside a closure so a panic in one pass is
		// contained and the loop keeps running.
		func() {
			defer func() {
				if r := recover(); r != nil {
					log.Printf("Panic in recoverStuckTasks: %v", r)
				}
			}()

			// Find dead runners: still marked online but no heartbeat for 90
			// seconds (SQLite datetime arithmetic). NOTE(review): this
			// compares last_heartbeat (written as a Go time.Time) against
			// datetime() output — confirm the stored text format sorts
			// correctly against SQLite's format.
			rows, err := s.db.Query(
				`SELECT id FROM runners
				 WHERE last_heartbeat < datetime('now', '-90 seconds')
				 AND status = ?`,
				types.RunnerStatusOnline,
			)
			if err != nil {
				log.Printf("Failed to query dead runners: %v", err)
				return
			}
			defer rows.Close()

			var deadRunnerIDs []int64
			for rows.Next() {
				var runnerID int64
				if err := rows.Scan(&runnerID); err == nil {
					deadRunnerIDs = append(deadRunnerIDs, runnerID)
				}
			}
			// Close explicitly before issuing more queries below; the
			// deferred Close above then becomes a harmless no-op.
			rows.Close()

			if len(deadRunnerIDs) == 0 {
				// No dead runners; still check for task timeouts.
				s.recoverTaskTimeouts()
				return
			}

			// Reset tasks assigned to dead runners.
			for _, runnerID := range deadRunnerIDs {
				// Get the running tasks assigned to this runner.
				taskRows, err := s.db.Query(
					`SELECT id, retry_count, max_retries FROM tasks
					 WHERE runner_id = ? AND status = ?`,
					runnerID, types.TaskStatusRunning,
				)
				if err != nil {
					log.Printf("Failed to query tasks for runner %d: %v", runnerID, err)
					continue
				}

				var tasksToReset []struct {
					ID         int64
					RetryCount int
					MaxRetries int
				}

				for taskRows.Next() {
					var t struct {
						ID         int64
						RetryCount int
						MaxRetries int
					}
					if err := taskRows.Scan(&t.ID, &t.RetryCount, &t.MaxRetries); err == nil {
						tasksToReset = append(tasksToReset, t)
					}
				}
				taskRows.Close()

				// Reset or fail each task depending on remaining retries.
				for _, task := range tasksToReset {
					if task.RetryCount >= task.MaxRetries {
						// Out of retries: mark as permanently failed.
						_, err = s.db.Exec(
							`UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL
							 WHERE id = ?`,
							types.TaskStatusFailed, "Runner died, max retries exceeded", task.ID,
						)
						if err != nil {
							log.Printf("Failed to mark task %d as failed: %v", task.ID, err)
						}
					} else {
						// Retries remain: requeue and bump the retry counter.
						_, err = s.db.Exec(
							`UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL,
							 retry_count = retry_count + 1 WHERE id = ?`,
							types.TaskStatusPending, task.ID,
						)
						if err != nil {
							log.Printf("Failed to reset task %d: %v", task.ID, err)
						} else {
							// Leave an audit trail in the task's log.
							_, _ = s.db.Exec(
								`INSERT INTO task_logs (task_id, log_level, message, step_name, created_at)
								 VALUES (?, ?, ?, ?, ?)`,
								task.ID, types.LogLevelWarn, fmt.Sprintf("Runner died, task reset (retry %d/%d)", task.RetryCount+1, task.MaxRetries),
								"", time.Now(),
							)
						}
					}
				}

				// Mark the dead runner as offline.
				_, _ = s.db.Exec(
					`UPDATE runners SET status = ? WHERE id = ?`,
					types.RunnerStatusOffline, runnerID,
				)
			}

			// Check for task timeouts.
			s.recoverTaskTimeouts()

			// Distribute newly recovered tasks right away.
			s.distributeTasksToRunners()
		}()
	}
}
|
||||
|
||||
// recoverTaskTimeouts handles tasks that have exceeded their timeout
|
||||
func (s *Server) recoverTaskTimeouts() {
|
||||
// Find tasks running longer than their timeout
|
||||
rows, err := s.db.Query(
|
||||
`SELECT t.id, t.runner_id, t.retry_count, t.max_retries, t.timeout_seconds, t.started_at
|
||||
FROM tasks t
|
||||
WHERE t.status = ?
|
||||
AND t.started_at IS NOT NULL
|
||||
AND (t.timeout_seconds IS NULL OR
|
||||
datetime(t.started_at, '+' || t.timeout_seconds || ' seconds') < datetime('now'))`,
|
||||
types.TaskStatusRunning,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Failed to query timed out tasks: %v", err)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var taskID int64
|
||||
var runnerID sql.NullInt64
|
||||
var retryCount, maxRetries int
|
||||
var timeoutSeconds sql.NullInt64
|
||||
var startedAt time.Time
|
||||
|
||||
err := rows.Scan(&taskID, &runnerID, &retryCount, &maxRetries, &timeoutSeconds, &startedAt)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Use default timeout if not set (5 minutes for frame tasks, 24 hours for FFmpeg)
|
||||
timeout := 300 // 5 minutes default
|
||||
if timeoutSeconds.Valid {
|
||||
timeout = int(timeoutSeconds.Int64)
|
||||
}
|
||||
|
||||
// Check if actually timed out
|
||||
if time.Since(startedAt).Seconds() < float64(timeout) {
|
||||
continue
|
||||
}
|
||||
|
||||
if retryCount >= maxRetries {
|
||||
// Mark as failed
|
||||
_, err = s.db.Exec(
|
||||
`UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL
|
||||
WHERE id = ?`,
|
||||
types.TaskStatusFailed, "Task timeout exceeded, max retries reached", taskID,
|
||||
)
|
||||
} else {
|
||||
// Reset to pending
|
||||
_, err = s.db.Exec(
|
||||
`UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL,
|
||||
retry_count = retry_count + 1 WHERE id = ?`,
|
||||
types.TaskStatusPending, taskID,
|
||||
)
|
||||
if err == nil {
|
||||
// Add log entry
|
||||
_, _ = s.db.Exec(
|
||||
`INSERT INTO task_logs (task_id, log_level, message, step_name, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
taskID, types.LogLevelWarn, fmt.Sprintf("Task timeout exceeded, resetting (retry %d/%d)", retryCount+1, maxRetries),
|
||||
"", time.Now(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user