This commit is contained in:
2025-11-22 05:40:31 -06:00
parent 87cb54a17d
commit fb2e318eaa
12 changed files with 1891 additions and 353 deletions

View File

@@ -2,18 +2,20 @@ package api
import (
"context"
"crypto/subtle"
"database/sql"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/go-chi/chi/v5"
"fuego/internal/auth"
"fuego/pkg/types"
"github.com/go-chi/chi/v5"
"github.com/gorilla/websocket"
)
// handleListRunners lists all runners
@@ -154,215 +156,15 @@ func (s *Server) handleRegisterRunner(w http.ResponseWriter, r *http.Request) {
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
"id": runnerID,
"name": req.Name,
"hostname": req.Hostname,
"ip_address": req.IPAddress,
"status": types.RunnerStatusOnline,
"runner_secret": runnerSecret,
"hostname": req.Hostname,
"ip_address": req.IPAddress,
"status": types.RunnerStatusOnline,
"runner_secret": runnerSecret,
"manager_secret": managerSecret,
"verified": true,
"verified": true,
})
}
// handleRunnerHeartbeat stamps the calling runner's last_heartbeat with the
// current time and sets its status to online.
//
// The runner ID is taken from the request context under the "runner_id" key;
// when it is absent or not an int64 the handler answers 400. A database
// failure answers 500, success answers 200 with a confirmation message.
func (s *Server) handleRunnerHeartbeat(w http.ResponseWriter, r *http.Request) {
	id, found := r.Context().Value("runner_id").(int64)
	if !found {
		s.respondError(w, http.StatusBadRequest, "runner_id not found in context")
		return
	}

	const touchRunner = `UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`
	if _, execErr := s.db.Exec(touchRunner, time.Now(), types.RunnerStatusOnline, id); execErr != nil {
		reason := fmt.Sprintf("Failed to update heartbeat: %v", execErr)
		s.respondError(w, http.StatusInternalServerError, reason)
		return
	}

	reply := map[string]string{"message": "Heartbeat updated"}
	s.respondJSON(w, http.StatusOK, reply)
}
// handleGetRunnerTasks returns up to 10 pending tasks (oldest first) whose job
// is not cancelled and assigns each of them to the requesting runner.
//
// The runner ID is taken from the request context under the "runner_id" key;
// when it is absent the handler answers 400. Each element of the JSON response
// holds the task as it was read (before the assignment is written), the job
// name, the job output format and the job's input file paths.
//
// A failure part-way through the batch answers 500 even though earlier tasks
// of the same batch may already have been assigned.
func (s *Server) handleGetRunnerTasks(w http.ResponseWriter, r *http.Request) {
	runnerID, ok := r.Context().Value("runner_id").(int64)
	if !ok {
		s.respondError(w, http.StatusBadRequest, "runner_id not found in context")
		return
	}

	// Get pending tasks
	rows, err := s.db.Query(
		`SELECT t.id, t.job_id, t.runner_id, t.frame_start, t.frame_end, t.status, t.output_path,
			t.created_at, t.started_at, t.completed_at, t.error_message,
			j.name as job_name, j.output_format
		FROM tasks t
		JOIN jobs j ON t.job_id = j.id
		WHERE t.status = ? AND j.status != ?
		ORDER BY t.created_at ASC
		LIMIT 10`,
		types.TaskStatusPending, types.JobStatusCancelled,
	)
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query tasks: %v", err))
		return
	}
	defer rows.Close()

	tasks := []map[string]interface{}{}
	for rows.Next() {
		var task types.Task
		// Deliberately NOT named runnerID: a variable of that name here would
		// shadow the requesting runner's ID and the UPDATE below would write
		// the scanned (normally NULL) value instead of assigning the task.
		var currentRunner sql.NullInt64
		var startedAt, completedAt sql.NullTime
		var jobName, outputFormat string

		err := rows.Scan(
			&task.ID, &task.JobID, &currentRunner, &task.FrameStart, &task.FrameEnd,
			&task.Status, &task.OutputPath, &task.CreatedAt,
			&startedAt, &completedAt, &task.ErrorMessage,
			&jobName, &outputFormat,
		)
		if err != nil {
			s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan task: %v", err))
			return
		}

		if currentRunner.Valid {
			task.RunnerID = &currentRunner.Int64
		}
		if startedAt.Valid {
			task.StartedAt = &startedAt.Time
		}
		if completedAt.Valid {
			task.CompletedAt = &completedAt.Time
		}

		// Get input files for the job. This stays best-effort (the task is
		// still handed out), but failures are logged instead of vanishing.
		var inputFiles []string
		fileRows, err := s.db.Query(
			`SELECT file_path FROM job_files WHERE job_id = ? AND file_type = ?`,
			task.JobID, types.JobFileTypeInput,
		)
		if err != nil {
			log.Printf("Failed to query input files for job %d: %v", task.JobID, err)
		} else {
			for fileRows.Next() {
				var filePath string
				if err := fileRows.Scan(&filePath); err != nil {
					log.Printf("Failed to scan input file for job %d: %v", task.JobID, err)
					continue
				}
				inputFiles = append(inputFiles, filePath)
			}
			if err := fileRows.Err(); err != nil {
				log.Printf("Failed to read input files for job %d: %v", task.JobID, err)
			}
			fileRows.Close()
		}

		tasks = append(tasks, map[string]interface{}{
			"task":          task,
			"job_name":      jobName,
			"output_format": outputFormat,
			"input_files":   inputFiles,
		})

		// Assign task to the requesting runner
		_, err = s.db.Exec(
			`UPDATE tasks SET runner_id = ?, status = ? WHERE id = ?`,
			runnerID, types.TaskStatusRunning, task.ID,
		)
		if err != nil {
			s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to assign task: %v", err))
			return
		}
	}
	// rows.Next returns false on both exhaustion and failure; only Err tells
	// the two apart.
	if err := rows.Err(); err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query tasks: %v", err))
		return
	}

	s.respondJSON(w, http.StatusOK, tasks)
}
// handleCompleteTask records the outcome a runner reports for a task and then
// refreshes the progress of the job that owns it.
//
// The task ID is obtained through parseID(r, "id"). The JSON body carries
// output_path, success and an optional error text; success=false stores the
// task as failed, otherwise as completed. An unparsable ID or body answers
// 400, a failed task update answers 500.
//
// Once the task row is written the handler always answers 200: a failure
// while refreshing the job is logged rather than reported to the runner.
func (s *Server) handleCompleteTask(w http.ResponseWriter, r *http.Request) {
	taskID, err := parseID(r, "id")
	if err != nil {
		s.respondError(w, http.StatusBadRequest, err.Error())
		return
	}

	var req struct {
		OutputPath string `json:"output_path"`
		Success    bool   `json:"success"`
		Error      string `json:"error,omitempty"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		s.respondError(w, http.StatusBadRequest, "Invalid request body")
		return
	}

	status := types.TaskStatusCompleted
	if !req.Success {
		status = types.TaskStatusFailed
	}

	now := time.Now()
	_, err = s.db.Exec(
		`UPDATE tasks SET status = ?, output_path = ?, completed_at = ?, error_message = ? WHERE id = ?`,
		status, req.OutputPath, now, req.Error, taskID,
	)
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to update task: %v", err))
		return
	}

	// Update job progress (best effort; the task itself is already stored).
	if err := s.updateJobAfterTaskCompletion(taskID, now); err != nil {
		log.Printf("Failed to update job progress for task %d: %v", taskID, err)
	}

	s.respondJSON(w, http.StatusOK, map[string]string{"message": "Task completed"})
}

// updateJobAfterTaskCompletion recomputes the completion percentage of the job
// owning taskID and stores it. When every task of the job is completed the job
// is marked completed with completed_at = now and, for the "MP4" output
// format, generateMP4Video is started in a goroutine. Otherwise the job is
// marked running and started_at is set to now if it was still NULL.
//
// Every query error is returned. Previously they were ignored, so a failed
// COUNT left both counters at zero, which produced a NaN progress value and
// made the "all tasks completed" comparison succeed.
func (s *Server) updateJobAfterTaskCompletion(taskID int64, now time.Time) error {
	var jobID int64
	if err := s.db.QueryRow(`SELECT job_id FROM tasks WHERE id = ?`, taskID).Scan(&jobID); err != nil {
		return fmt.Errorf("looking up job of task %d: %w", taskID, err)
	}

	// Count completed tasks
	var totalTasks, completedTasks int
	if err := s.db.QueryRow(
		`SELECT COUNT(*) FROM tasks WHERE job_id = ?`,
		jobID,
	).Scan(&totalTasks); err != nil {
		return fmt.Errorf("counting tasks of job %d: %w", jobID, err)
	}
	if err := s.db.QueryRow(
		`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
		jobID, types.TaskStatusCompleted,
	).Scan(&completedTasks); err != nil {
		return fmt.Errorf("counting completed tasks of job %d: %w", jobID, err)
	}
	if totalTasks == 0 {
		// Guards the division below; 0/0 would yield NaN.
		return fmt.Errorf("job %d has no tasks", jobID)
	}

	progress := float64(completedTasks) / float64(totalTasks) * 100.0

	if completedTasks != totalTasks {
		// COALESCE keeps an existing started_at and fills it only when NULL,
		// replacing the former read-then-write pair with one statement.
		if _, err := s.db.Exec(
			`UPDATE jobs SET started_at = COALESCE(started_at, ?), status = ?, progress = ? WHERE id = ?`,
			now, string(types.JobStatusRunning), progress, jobID,
		); err != nil {
			return fmt.Errorf("marking job %d running: %w", jobID, err)
		}
		return nil
	}

	if _, err := s.db.Exec(
		`UPDATE jobs SET status = ?, progress = ?, completed_at = ? WHERE id = ?`,
		string(types.JobStatusCompleted), progress, now, jobID,
	); err != nil {
		return fmt.Errorf("marking job %d completed: %w", jobID, err)
	}

	// For MP4 jobs, create a video generation task
	var outputFormat string
	if err := s.db.QueryRow(`SELECT output_format FROM jobs WHERE id = ?`, jobID).Scan(&outputFormat); err != nil {
		return fmt.Errorf("reading output format of job %d: %w", jobID, err)
	}
	if outputFormat == "MP4" {
		go s.generateMP4Video(jobID)
	}
	return nil
}
// handleUpdateTaskProgress updates task progress
func (s *Server) handleUpdateTaskProgress(w http.ResponseWriter, r *http.Request) {
_, err := parseID(r, "id")
@@ -471,7 +273,7 @@ func (s *Server) generateMP4Video(jobID int64) {
// This would be called by a runner or external process
// For now, we'll create a special task that runners can pick up
// In a production system, you might want to use a job queue or have a dedicated video processor
// Get all PNG output files for this job
rows, err := s.db.Query(
`SELECT file_path, file_name FROM job_files
@@ -516,12 +318,12 @@ func (s *Server) handleGetJobStatusForRunner(w http.ResponseWriter, r *http.Requ
err = s.db.QueryRow(
`SELECT id, user_id, name, status, progress, frame_start, frame_end, output_format,
created_at, started_at, completed_at, error_message
allow_parallel_runners, created_at, started_at, completed_at, error_message
FROM jobs WHERE id = ?`,
jobID,
).Scan(
&job.ID, &job.UserID, &job.Name, &job.Status, &job.Progress,
&job.FrameStart, &job.FrameEnd, &job.OutputFormat,
&job.FrameStart, &job.FrameEnd, &job.OutputFormat, &job.AllowParallelRunners,
&job.CreatedAt, &startedAt, &completedAt, &job.ErrorMessage,
)
@@ -580,3 +382,500 @@ func (s *Server) handleGetJobFilesForRunner(w http.ResponseWriter, r *http.Reque
s.respondJSON(w, http.StatusOK, files)
}
// WebSocket message types

// WSMessage is the JSON envelope of a WebSocket message: a type tag, a
// type-specific payload kept as raw JSON, and a timestamp.
type WSMessage struct {
	// Type selects how Data is decoded. The runner handler in this file
	// switches on "heartbeat", "log_entry", "task_update" and "task_complete";
	// the server itself sends "task_assignment".
	Type string `json:"type"`
	// Data is the undecoded payload; it is unmarshalled only after Type is known.
	Data json.RawMessage `json:"data"`
	// Timestamp is filled with time.Now().Unix() (seconds) on messages the
	// server sends. NOTE(review): the unit used by runners is not visible here.
	Timestamp int64 `json:"timestamp"`
}
// WSTaskAssignment is the payload marshalled into the Data field of a
// "task_assignment" message when a task is pushed to a runner.
type WSTaskAssignment struct {
	// TaskID is the ID of the assigned task.
	TaskID int64 `json:"task_id"`
	// JobID is the ID of the job the task belongs to.
	JobID int64 `json:"job_id"`
	// JobName is the name of that job.
	JobName string `json:"job_name"`
	// OutputFormat is the job's output_format value.
	OutputFormat string `json:"output_format"`
	// FrameStart is the task's frame_start value.
	FrameStart int `json:"frame_start"`
	// FrameEnd is the task's frame_end value.
	// NOTE(review): whether the range is inclusive is not visible here.
	FrameEnd int `json:"frame_end"`
	// InputFiles lists the file paths of the job's input files.
	InputFiles []string `json:"input_files"`
}
// WSLogEntry is the payload of a "log_entry" message sent by a runner. Its
// fields are stored as one row of the task_logs table.
type WSLogEntry struct {
	// TaskID is the task the log line belongs to.
	TaskID int64 `json:"task_id"`
	// LogLevel is stored verbatim in the log_level column.
	LogLevel string `json:"log_level"`
	// Message is the log text.
	Message string `json:"message"`
	// StepName is stored in the step_name column; omitted from JSON when empty.
	StepName string `json:"step_name,omitempty"`
}
// WSTaskUpdate is the payload of "task_update" and "task_complete" messages
// sent by a runner.
type WSTaskUpdate struct {
	// TaskID is the task being reported on.
	TaskID int64 `json:"task_id"`
	// Status is a status text supplied by the runner; the "task_update"
	// handler in this file only logs it.
	Status string `json:"status"`
	// OutputPath is written to the task's output_path on completion.
	OutputPath string `json:"output_path,omitempty"`
	// Success decides whether a completed task is stored as completed (true)
	// or failed (false).
	Success bool `json:"success"`
	// Error is written to the task's error_message on completion.
	Error string `json:"error,omitempty"`
}
// handleRunnerWebSocket authenticates a runner, upgrades the request to a
// WebSocket and serves that connection until a read fails.
//
// Authentication uses the query parameters runner_id, signature and timestamp
// (Unix seconds). The timestamp must be at most 5 minutes old and at most
// 1 minute in the future, and the signature must equal
// auth.SignRequest("GET", <request path>, "", <runner secret>, <timestamp>).
// Missing or malformed parameters answer 400; an unknown runner, a stale or
// future timestamp, or a wrong signature answer 401.
//
// After the upgrade the connection is stored in s.runnerConns (closing any
// previous connection of the same runner), the runner is marked online, a
// ping is sent every 30 seconds and incoming "heartbeat", "log_entry",
// "task_update" and "task_complete" messages are dispatched. On exit the
// runner is unregistered and marked offline, but only if this connection is
// still the registered one.
func (s *Server) handleRunnerWebSocket(w http.ResponseWriter, r *http.Request) {
	// Get runner ID and signature from query params
	runnerIDStr := r.URL.Query().Get("runner_id")
	signature := r.URL.Query().Get("signature")
	timestampStr := r.URL.Query().Get("timestamp")

	if runnerIDStr == "" || signature == "" || timestampStr == "" {
		s.respondError(w, http.StatusBadRequest, "runner_id, signature, and timestamp required")
		return
	}

	var runnerID int64
	_, err := fmt.Sscanf(runnerIDStr, "%d", &runnerID)
	if err != nil {
		s.respondError(w, http.StatusBadRequest, "invalid runner_id")
		return
	}

	// Get runner secret
	runnerSecret, err := s.secrets.GetRunnerSecret(runnerID)
	if err != nil {
		s.respondError(w, http.StatusUnauthorized, "runner not found or not verified")
		return
	}

	// Verify signature
	var timestamp int64
	_, err = fmt.Sscanf(timestampStr, "%d", &timestamp)
	if err != nil {
		s.respondError(w, http.StatusBadRequest, "invalid timestamp")
		return
	}

	// Verify signature manually (similar to HTTP auth)
	timestampTime := time.Unix(timestamp, 0)

	// Check timestamp is not too old
	if time.Since(timestampTime) > 5*time.Minute {
		s.respondError(w, http.StatusUnauthorized, "timestamp too old")
		return
	}

	// Check timestamp is not in the future (allow 1 minute clock skew)
	if timestampTime.After(time.Now().Add(1 * time.Minute)) {
		s.respondError(w, http.StatusUnauthorized, "timestamp in future")
		return
	}

	// Build the message that should be signed
	path := r.URL.Path
	expectedSig := auth.SignRequest("GET", path, "", runnerSecret, timestampTime)

	// Compare signatures (constant-time)
	if subtle.ConstantTimeCompare([]byte(signature), []byte(expectedSig)) != 1 {
		s.respondError(w, http.StatusUnauthorized, "invalid signature")
		return
	}

	// Upgrade to WebSocket
	conn, err := s.wsUpgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Printf("Failed to upgrade WebSocket: %v", err)
		return
	}
	defer conn.Close()

	// Register connection
	s.runnerConnsMu.Lock()
	// Remove old connection if exists
	if oldConn, exists := s.runnerConns[runnerID]; exists {
		oldConn.Close()
	}
	s.runnerConns[runnerID] = conn
	s.runnerConnsMu.Unlock()

	// Update runner status to online
	_, _ = s.db.Exec(
		`UPDATE runners SET status = ?, last_heartbeat = ? WHERE id = ?`,
		types.RunnerStatusOnline, time.Now(), runnerID,
	)

	// done is closed when this handler returns and stops the ping goroutine.
	done := make(chan struct{})

	// Cleanup on disconnect
	defer func() {
		close(done)

		// A reconnecting runner replaces the map entry and closes this
		// connection. In that case the entry belongs to the newer handler and
		// the runner is still online, so neither may be touched here.
		s.runnerConnsMu.Lock()
		stillRegistered := s.runnerConns[runnerID] == conn
		if stillRegistered {
			delete(s.runnerConns, runnerID)
		}
		s.runnerConnsMu.Unlock()
		if !stillRegistered {
			return
		}

		_, _ = s.db.Exec(
			`UPDATE runners SET status = ? WHERE id = ?`,
			types.RunnerStatusOffline, runnerID,
		)
	}()

	// Set pong handler to update heartbeat
	conn.SetPongHandler(func(string) error {
		_, _ = s.db.Exec(
			`UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`,
			time.Now(), types.RunnerStatusOnline, runnerID,
		)
		return nil
	})

	// Send ping every 30 seconds on THIS connection (not on whatever the map
	// currently holds) until the handler returns.
	go func() {
		ticker := time.NewTicker(30 * time.Second)
		defer ticker.Stop()

		for {
			select {
			case <-done:
				return
			case <-ticker.C:
				deadline := time.Now().Add(10 * time.Second)
				if err := conn.WriteControl(websocket.PingMessage, []byte{}, deadline); err != nil {
					// The peer is unreachable; closing unblocks the read
					// loop below so the handler can clean up.
					conn.Close()
					return
				}
			}
		}
	}()

	// Handle incoming messages
	for {
		var msg WSMessage
		err := conn.ReadJSON(&msg)
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				log.Printf("WebSocket error for runner %d: %v", runnerID, err)
			}
			break
		}

		switch msg.Type {
		case "heartbeat":
			// Update heartbeat
			_, _ = s.db.Exec(
				`UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`,
				time.Now(), types.RunnerStatusOnline, runnerID,
			)

		case "log_entry":
			var logEntry WSLogEntry
			if err := json.Unmarshal(msg.Data, &logEntry); err != nil {
				log.Printf("Invalid %s payload from runner %d: %v", msg.Type, runnerID, err)
				continue
			}
			s.handleWebSocketLog(runnerID, logEntry)

		case "task_update":
			var taskUpdate WSTaskUpdate
			if err := json.Unmarshal(msg.Data, &taskUpdate); err != nil {
				log.Printf("Invalid %s payload from runner %d: %v", msg.Type, runnerID, err)
				continue
			}
			s.handleWebSocketTaskUpdate(runnerID, taskUpdate)

		case "task_complete":
			var taskUpdate WSTaskUpdate
			if err := json.Unmarshal(msg.Data, &taskUpdate); err != nil {
				log.Printf("Invalid %s payload from runner %d: %v", msg.Type, runnerID, err)
				continue
			}
			s.handleWebSocketTaskComplete(runnerID, taskUpdate)
		}
	}
}
// handleWebSocketLog persists a log line received from a runner in task_logs
// (stamped with the current time) and then forwards it to the frontend via
// broadcastLogToFrontend. When the insert fails the failure is logged and
// nothing is forwarded.
func (s *Server) handleWebSocketLog(runnerID int64, logEntry WSLogEntry) {
	receivedAt := time.Now()

	if _, insertErr := s.db.Exec(
		`INSERT INTO task_logs (task_id, runner_id, log_level, message, step_name, created_at)
		VALUES (?, ?, ?, ?, ?, ?)`,
		logEntry.TaskID, runnerID, logEntry.LogLevel, logEntry.Message, logEntry.StepName, receivedAt,
	); insertErr != nil {
		log.Printf("Failed to store log: %v", insertErr)
		return
	}

	// Only entries that reached the database are pushed to frontend clients.
	s.broadcastLogToFrontend(logEntry.TaskID, logEntry)
}
// handleWebSocketTaskUpdate receives a task status update from a runner.
// Nothing is persisted at the moment: the update is only written to the log.
func (s *Server) handleWebSocketTaskUpdate(runnerID int64, taskUpdate WSTaskUpdate) {
	taskID, reported := taskUpdate.TaskID, taskUpdate.Status
	log.Printf("Task %d update from runner %d: %s", taskID, runnerID, reported)
}
// handleWebSocketTaskComplete records the outcome a runner reports for a task
// over WebSocket and then refreshes the progress of the owning job.
//
// The report is ignored (and logged) unless the task's runner_id equals
// runnerID. Success=false stores the task as failed, otherwise as completed.
// Failures are logged; nothing is returned to the caller.
func (s *Server) handleWebSocketTaskComplete(runnerID int64, taskUpdate WSTaskUpdate) {
	// Verify task belongs to runner
	var taskRunnerID sql.NullInt64
	err := s.db.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", taskUpdate.TaskID).Scan(&taskRunnerID)
	if err != nil {
		log.Printf("Failed to look up runner of task %d: %v", taskUpdate.TaskID, err)
		return
	}
	if !taskRunnerID.Valid || taskRunnerID.Int64 != runnerID {
		log.Printf("Task %d does not belong to runner %d", taskUpdate.TaskID, runnerID)
		return
	}

	status := types.TaskStatusCompleted
	if !taskUpdate.Success {
		status = types.TaskStatusFailed
	}

	now := time.Now()
	_, err = s.db.Exec(
		`UPDATE tasks SET status = ?, output_path = ?, completed_at = ?, error_message = ? WHERE id = ?`,
		status, taskUpdate.OutputPath, now, taskUpdate.Error, taskUpdate.TaskID,
	)
	if err != nil {
		log.Printf("Failed to update task: %v", err)
		return
	}

	// Update job progress
	if err := s.refreshJobProgressForTask(taskUpdate.TaskID, now); err != nil {
		log.Printf("Failed to update job progress for task %d: %v", taskUpdate.TaskID, err)
	}
}

// refreshJobProgressForTask recomputes the completion percentage of the job
// owning taskID and stores it. When every task of the job is completed the job
// is marked completed with completed_at = now and, for the "MP4" output
// format, generateMP4Video is started in a goroutine. Otherwise the job is
// marked running and started_at is set to now if it was still NULL.
//
// Every query error is returned. Previously they were ignored, so a failed
// COUNT left both counters at zero, which produced a NaN progress value and
// made the "all tasks completed" comparison succeed.
func (s *Server) refreshJobProgressForTask(taskID int64, now time.Time) error {
	var jobID int64
	if err := s.db.QueryRow(`SELECT job_id FROM tasks WHERE id = ?`, taskID).Scan(&jobID); err != nil {
		return fmt.Errorf("looking up job of task %d: %w", taskID, err)
	}

	var totalTasks, completedTasks int
	if err := s.db.QueryRow(`SELECT COUNT(*) FROM tasks WHERE job_id = ?`, jobID).Scan(&totalTasks); err != nil {
		return fmt.Errorf("counting tasks of job %d: %w", jobID, err)
	}
	if err := s.db.QueryRow(
		`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
		jobID, types.TaskStatusCompleted,
	).Scan(&completedTasks); err != nil {
		return fmt.Errorf("counting completed tasks of job %d: %w", jobID, err)
	}
	if totalTasks == 0 {
		// Guards the division below; 0/0 would yield NaN.
		return fmt.Errorf("job %d has no tasks", jobID)
	}

	progress := float64(completedTasks) / float64(totalTasks) * 100.0

	if completedTasks != totalTasks {
		// COALESCE keeps an existing started_at and fills it only when NULL,
		// replacing the former read-then-write pair with one statement.
		if _, err := s.db.Exec(
			`UPDATE jobs SET started_at = COALESCE(started_at, ?), status = ?, progress = ? WHERE id = ?`,
			now, string(types.JobStatusRunning), progress, jobID,
		); err != nil {
			return fmt.Errorf("marking job %d running: %w", jobID, err)
		}
		return nil
	}

	if _, err := s.db.Exec(
		`UPDATE jobs SET status = ?, progress = ?, completed_at = ? WHERE id = ?`,
		string(types.JobStatusCompleted), progress, now, jobID,
	); err != nil {
		return fmt.Errorf("marking job %d completed: %w", jobID, err)
	}

	var outputFormat string
	if err := s.db.QueryRow(`SELECT output_format FROM jobs WHERE id = ?`, jobID).Scan(&outputFormat); err != nil {
		return fmt.Errorf("reading output format of job %d: %w", jobID, err)
	}
	if outputFormat == "MP4" {
		go s.generateMP4Video(jobID)
	}
	return nil
}
// broadcastLogToFrontend forwards a stored log line to the frontend connection
// registered under the key "<jobID>:<taskID>", if there is one.
//
// The job ID is looked up from the task. The row that is sent is re-read from
// task_logs (latest row of the task with the same message) so that the
// frontend receives the stored ID and timestamp. The message has type "log"
// and a Unix-seconds timestamp. Failures are logged; nothing is returned.
//
// NOTE(review): the connection is written to without a per-connection lock;
// confirm that no other goroutine writes to the same frontend connection.
func (s *Server) broadcastLogToFrontend(taskID int64, logEntry WSLogEntry) {
	// Get job_id from task
	var jobID int64
	if err := s.db.QueryRow("SELECT job_id FROM tasks WHERE id = ?", taskID).Scan(&jobID); err != nil {
		log.Printf("Failed to look up job of task %d: %v", taskID, err)
		return
	}

	key := fmt.Sprintf("%d:%d", jobID, taskID)
	s.frontendConnsMu.RLock()
	conn, exists := s.frontendConns[key]
	s.frontendConnsMu.RUnlock()

	if !exists || conn == nil {
		return
	}

	// Get full log entry from database for consistency. The variable is named
	// entry (not log) so that the log package stays usable in this function.
	var entry types.TaskLog
	var runnerID sql.NullInt64
	err := s.db.QueryRow(
		`SELECT id, task_id, runner_id, log_level, message, step_name, created_at
		FROM task_logs WHERE task_id = ? AND message = ? ORDER BY id DESC LIMIT 1`,
		taskID, logEntry.Message,
	).Scan(&entry.ID, &entry.TaskID, &runnerID, &entry.LogLevel, &entry.Message, &entry.StepName, &entry.CreatedAt)
	if err != nil {
		log.Printf("Failed to load log entry of task %d: %v", taskID, err)
		return
	}
	if runnerID.Valid {
		entry.RunnerID = &runnerID.Int64
	}

	msg := map[string]interface{}{
		"type":      "log",
		"data":      entry,
		"timestamp": time.Now().Unix(),
	}
	if err := conn.WriteJSON(msg); err != nil {
		log.Printf("Failed to send log of task %d to frontend: %v", taskID, err)
	}
}
// distributeTasksToRunners pushes pending tasks to connected runners.
//
// It loads up to 100 pending tasks (oldest first) of jobs that are not
// cancelled, takes a snapshot of the currently connected runner IDs and, for
// every task that still has no runner, picks the first runner that
//   - has no running task, and
//   - for jobs that do not allow parallel runners, would not be working on the
//     job alongside a different runner.
//
// The task is then handed over through assignTaskToRunner. All failures are
// logged; a runner whose state cannot be determined is treated as unavailable.
func (s *Server) distributeTasksToRunners() {
	// Get all pending tasks
	rows, err := s.db.Query(
		`SELECT t.id, t.job_id, j.allow_parallel_runners
		FROM tasks t
		JOIN jobs j ON t.job_id = j.id
		WHERE t.status = ? AND j.status != ?
		ORDER BY t.created_at ASC
		LIMIT 100`,
		types.TaskStatusPending, types.JobStatusCancelled,
	)
	if err != nil {
		log.Printf("Failed to query pending tasks: %v", err)
		return
	}

	type pendingTask struct {
		TaskID               int64
		JobID                int64
		AllowParallelRunners bool
	}
	var pendingTasks []pendingTask
	for rows.Next() {
		var t pendingTask
		var allowParallel int
		if err := rows.Scan(&t.TaskID, &t.JobID, &allowParallel); err != nil {
			log.Printf("Failed to scan pending task: %v", err)
			continue
		}
		t.AllowParallelRunners = allowParallel == 1
		pendingTasks = append(pendingTasks, t)
	}
	iterErr := rows.Err()
	// Release the result set before issuing the statements below instead of
	// holding it open until the function returns.
	rows.Close()
	if iterErr != nil {
		log.Printf("Failed to read pending tasks: %v", iterErr)
		return
	}

	if len(pendingTasks) == 0 {
		return
	}

	// Get connected runners
	s.runnerConnsMu.RLock()
	connectedRunners := make([]int64, 0, len(s.runnerConns))
	for runnerID := range s.runnerConns {
		connectedRunners = append(connectedRunners, runnerID)
	}
	s.runnerConnsMu.RUnlock()

	if len(connectedRunners) == 0 {
		return
	}

	// Distribute tasks to runners
	for _, task := range pendingTasks {
		// Check if task is already assigned
		var assignedRunnerID sql.NullInt64
		err := s.db.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", task.TaskID).Scan(&assignedRunnerID)
		if err != nil {
			log.Printf("Failed to check assignment of task %d: %v", task.TaskID, err)
			continue
		}
		if assignedRunnerID.Valid {
			continue // Already assigned
		}

		// Find available runner
		var selectedRunnerID int64
		found := false
		for _, runnerID := range connectedRunners {
			// Check if runner is busy (has running tasks)
			var runningCount int
			if err := s.db.QueryRow(
				`SELECT COUNT(*) FROM tasks WHERE runner_id = ? AND status = ?`,
				runnerID, types.TaskStatusRunning,
			).Scan(&runningCount); err != nil {
				log.Printf("Failed to count running tasks of runner %d: %v", runnerID, err)
				continue
			}
			if runningCount > 0 {
				continue // Runner is busy
			}

			// For non-parallel jobs, skip this runner while a DIFFERENT
			// runner holds tasks of the job. The previous query filtered on
			// the candidate runner itself, which could never block a second
			// runner and left the flag without effect.
			if !task.AllowParallelRunners {
				var otherRunnerTasks int
				if err := s.db.QueryRow(
					`SELECT COUNT(*) FROM tasks
					WHERE job_id = ? AND runner_id IS NOT NULL AND runner_id != ? AND status IN (?, ?)`,
					task.JobID, runnerID, types.TaskStatusPending, types.TaskStatusRunning,
				).Scan(&otherRunnerTasks); err != nil {
					log.Printf("Failed to check runners of job %d: %v", task.JobID, err)
					continue
				}
				if otherRunnerTasks > 0 {
					continue // Another runner is working on this job
				}
			}

			selectedRunnerID = runnerID
			found = true
			break
		}

		if !found {
			continue // No available runner
		}

		// Assign task to runner
		if err := s.assignTaskToRunner(selectedRunnerID, task.TaskID); err != nil {
			log.Printf("Failed to assign task %d to runner %d: %v", task.TaskID, selectedRunnerID, err)
		}
	}
}
// assignTaskToRunner marks a task as running on the given runner and sends it
// to that runner as a "task_assignment" WebSocket message.
//
// The message payload is a WSTaskAssignment holding the task's frame range,
// the job's name and output format, and the job's input file paths.
//
// It returns an error when the runner has no registered connection, when the
// task, job or input files cannot be read, when the payload cannot be encoded,
// when the database update fails, or when the message cannot be written. In
// the last case the task is put back to pending (runner and start time
// cleared) so that it is not left marked as running on a runner that never
// received it.
//
// NOTE(review): the connection is written to without a per-connection lock;
// confirm that callers never send to the same runner concurrently.
func (s *Server) assignTaskToRunner(runnerID int64, taskID int64) error {
	s.runnerConnsMu.RLock()
	conn, exists := s.runnerConns[runnerID]
	s.runnerConnsMu.RUnlock()

	if !exists {
		return fmt.Errorf("runner %d not connected", runnerID)
	}

	// Get task details
	task := WSTaskAssignment{TaskID: taskID}
	err := s.db.QueryRow(
		`SELECT t.job_id, t.frame_start, t.frame_end, j.name, j.output_format
		FROM tasks t JOIN jobs j ON t.job_id = j.id WHERE t.id = ?`,
		taskID,
	).Scan(&task.JobID, &task.FrameStart, &task.FrameEnd, &task.JobName, &task.OutputFormat)
	if err != nil {
		return fmt.Errorf("loading task %d: %w", taskID, err)
	}

	// Get input files. An assignment without its inputs is useless to the
	// runner, so a failure here aborts before anything is written.
	rows, err := s.db.Query(
		`SELECT file_path FROM job_files WHERE job_id = ? AND file_type = ?`,
		task.JobID, types.JobFileTypeInput,
	)
	if err != nil {
		return fmt.Errorf("querying input files of job %d: %w", task.JobID, err)
	}
	for rows.Next() {
		var filePath string
		if err := rows.Scan(&filePath); err != nil {
			rows.Close()
			return fmt.Errorf("scanning input file of job %d: %w", task.JobID, err)
		}
		task.InputFiles = append(task.InputFiles, filePath)
	}
	iterErr := rows.Err()
	rows.Close() // released before the UPDATE below
	if iterErr != nil {
		return fmt.Errorf("reading input files of job %d: %w", task.JobID, iterErr)
	}

	// Encode before touching the database so that an encoding failure leaves
	// the task untouched.
	payload, err := json.Marshal(task)
	if err != nil {
		return fmt.Errorf("encoding assignment of task %d: %w", taskID, err)
	}

	// Assign task to runner in database
	now := time.Now()
	_, err = s.db.Exec(
		`UPDATE tasks SET runner_id = ?, status = ?, started_at = ? WHERE id = ?`,
		runnerID, types.TaskStatusRunning, now, taskID,
	)
	if err != nil {
		return fmt.Errorf("marking task %d as running: %w", taskID, err)
	}

	// Send task via WebSocket
	msg := WSMessage{
		Type:      "task_assignment",
		Data:      payload,
		Timestamp: time.Now().Unix(),
	}
	if err := conn.WriteJSON(msg); err != nil {
		// Undo the assignment, but only if the row still shows exactly what
		// was written above.
		_, revertErr := s.db.Exec(
			`UPDATE tasks SET runner_id = NULL, status = ?, started_at = NULL
			WHERE id = ? AND runner_id = ? AND status = ?`,
			types.TaskStatusPending, taskID, runnerID, types.TaskStatusRunning,
		)
		if revertErr != nil {
			return fmt.Errorf("sending task %d to runner %d: %w (reverting assignment also failed: %v)",
				taskID, runnerID, err, revertErr)
		}
		return fmt.Errorf("sending task %d to runner %d: %w", taskID, runnerID, err)
	}
	return nil
}