1055 lines
31 KiB
Go
1055 lines
31 KiB
Go
package api
|
|
|
|
import (
	"context"
	"crypto/subtle"
	"database/sql"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"strconv"
	"time"

	"fuego/internal/auth"
	"fuego/pkg/types"

	"github.com/go-chi/chi/v5"
	"github.com/gorilla/websocket"
)
|
|
|
|
// handleListRunners lists all runners
|
|
func (s *Server) handleListRunners(w http.ResponseWriter, r *http.Request) {
|
|
_, err := getUserID(r)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusUnauthorized, err.Error())
|
|
return
|
|
}
|
|
|
|
rows, err := s.db.Query(
|
|
`SELECT id, name, hostname, ip_address, status, last_heartbeat, capabilities, created_at
|
|
FROM runners ORDER BY created_at DESC`,
|
|
)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query runners: %v", err))
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
runners := []types.Runner{}
|
|
for rows.Next() {
|
|
var runner types.Runner
|
|
err := rows.Scan(
|
|
&runner.ID, &runner.Name, &runner.Hostname, &runner.IPAddress,
|
|
&runner.Status, &runner.LastHeartbeat, &runner.Capabilities, &runner.CreatedAt,
|
|
)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan runner: %v", err))
|
|
return
|
|
}
|
|
runners = append(runners, runner)
|
|
}
|
|
|
|
s.respondJSON(w, http.StatusOK, runners)
|
|
}
|
|
|
|
// runnerAuthMiddleware verifies runner requests using HMAC signatures
|
|
func (s *Server) runnerAuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
// Get runner ID from query string
|
|
runnerIDStr := r.URL.Query().Get("runner_id")
|
|
if runnerIDStr == "" {
|
|
s.respondError(w, http.StatusBadRequest, "runner_id required in query string")
|
|
return
|
|
}
|
|
|
|
var runnerID int64
|
|
_, err := fmt.Sscanf(runnerIDStr, "%d", &runnerID)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusBadRequest, "invalid runner_id")
|
|
return
|
|
}
|
|
|
|
// Get runner secret
|
|
runnerSecret, err := s.secrets.GetRunnerSecret(runnerID)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusUnauthorized, "runner not found or not verified")
|
|
return
|
|
}
|
|
|
|
// Verify request signature
|
|
valid, err := auth.VerifyRequest(r, runnerSecret, 5*time.Minute)
|
|
if err != nil || !valid {
|
|
s.respondError(w, http.StatusUnauthorized, "invalid signature")
|
|
return
|
|
}
|
|
|
|
// Add runner ID to context
|
|
ctx := r.Context()
|
|
ctx = context.WithValue(ctx, "runner_id", runnerID)
|
|
next(w, r.WithContext(ctx))
|
|
}
|
|
}
|
|
|
|
// handleRegisterRunner registers a new runner
|
|
func (s *Server) handleRegisterRunner(w http.ResponseWriter, r *http.Request) {
|
|
var req struct {
|
|
types.RegisterRunnerRequest
|
|
RegistrationToken string `json:"registration_token"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
|
return
|
|
}
|
|
|
|
if req.Name == "" {
|
|
s.respondError(w, http.StatusBadRequest, "Runner name is required")
|
|
return
|
|
}
|
|
|
|
if req.RegistrationToken == "" {
|
|
s.respondError(w, http.StatusBadRequest, "Registration token is required")
|
|
return
|
|
}
|
|
|
|
// Validate registration token
|
|
valid, err := s.secrets.ValidateRegistrationToken(req.RegistrationToken)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to validate token: %v", err))
|
|
return
|
|
}
|
|
if !valid {
|
|
s.respondError(w, http.StatusUnauthorized, "Invalid or expired registration token")
|
|
return
|
|
}
|
|
|
|
// Get manager secret
|
|
managerSecret, err := s.secrets.GetManagerSecret()
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, "Failed to get manager secret")
|
|
return
|
|
}
|
|
|
|
// Generate runner secret
|
|
runnerSecret, err := s.secrets.GenerateRunnerSecret()
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, "Failed to generate runner secret")
|
|
return
|
|
}
|
|
|
|
// Register runner
|
|
result, err := s.db.Exec(
|
|
`INSERT INTO runners (name, hostname, ip_address, status, last_heartbeat, capabilities,
|
|
registration_token, runner_secret, manager_secret, verified)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
req.Name, req.Hostname, req.IPAddress, types.RunnerStatusOnline, time.Now(), req.Capabilities,
|
|
req.RegistrationToken, runnerSecret, managerSecret, true,
|
|
)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to register runner: %v", err))
|
|
return
|
|
}
|
|
|
|
runnerID, _ := result.LastInsertId()
|
|
|
|
// Return runner info with secrets
|
|
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
|
|
"id": runnerID,
|
|
"name": req.Name,
|
|
"hostname": req.Hostname,
|
|
"ip_address": req.IPAddress,
|
|
"status": types.RunnerStatusOnline,
|
|
"runner_secret": runnerSecret,
|
|
"manager_secret": managerSecret,
|
|
"verified": true,
|
|
})
|
|
}
|
|
|
|
// handleUpdateTaskProgress updates task progress
|
|
func (s *Server) handleUpdateTaskProgress(w http.ResponseWriter, r *http.Request) {
|
|
_, err := parseID(r, "id")
|
|
if err != nil {
|
|
s.respondError(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
|
|
var req struct {
|
|
Progress float64 `json:"progress"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
|
return
|
|
}
|
|
|
|
// This is mainly for logging/debugging, actual progress is calculated from completed tasks
|
|
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Progress updated"})
|
|
}
|
|
|
|
// handleUpdateTaskStep handles step start/complete events from runners
|
|
func (s *Server) handleUpdateTaskStep(w http.ResponseWriter, r *http.Request) {
|
|
// Get runner ID from context (set by runnerAuthMiddleware)
|
|
runnerID, ok := r.Context().Value("runner_id").(int64)
|
|
if !ok {
|
|
s.respondError(w, http.StatusUnauthorized, "runner_id not found in context")
|
|
return
|
|
}
|
|
|
|
taskID, err := parseID(r, "id")
|
|
if err != nil {
|
|
s.respondError(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
|
|
var req struct {
|
|
StepName string `json:"step_name"`
|
|
Status string `json:"status"` // "pending", "running", "completed", "failed", "skipped"
|
|
DurationMs *int `json:"duration_ms,omitempty"`
|
|
ErrorMessage string `json:"error_message,omitempty"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
|
return
|
|
}
|
|
|
|
// Verify task belongs to runner
|
|
var taskRunnerID sql.NullInt64
|
|
err = s.db.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", taskID).Scan(&taskRunnerID)
|
|
if err == sql.ErrNoRows {
|
|
s.respondError(w, http.StatusNotFound, "Task not found")
|
|
return
|
|
}
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify task: %v", err))
|
|
return
|
|
}
|
|
if !taskRunnerID.Valid || taskRunnerID.Int64 != runnerID {
|
|
s.respondError(w, http.StatusForbidden, "Task does not belong to this runner")
|
|
return
|
|
}
|
|
|
|
now := time.Now()
|
|
var stepID int64
|
|
|
|
// Check if step already exists
|
|
var existingStepID sql.NullInt64
|
|
err = s.db.QueryRow(
|
|
`SELECT id FROM task_steps WHERE task_id = ? AND step_name = ?`,
|
|
taskID, req.StepName,
|
|
).Scan(&existingStepID)
|
|
|
|
if err == sql.ErrNoRows || !existingStepID.Valid {
|
|
// Create new step
|
|
var startedAt *time.Time
|
|
var completedAt *time.Time
|
|
if req.Status == string(types.StepStatusRunning) || req.Status == string(types.StepStatusCompleted) || req.Status == string(types.StepStatusFailed) {
|
|
startedAt = &now
|
|
}
|
|
if req.Status == string(types.StepStatusCompleted) || req.Status == string(types.StepStatusFailed) {
|
|
completedAt = &now
|
|
}
|
|
|
|
result, err := s.db.Exec(
|
|
`INSERT INTO task_steps (task_id, step_name, status, started_at, completed_at, duration_ms, error_message)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
|
taskID, req.StepName, req.Status, startedAt, completedAt, req.DurationMs, req.ErrorMessage,
|
|
)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create step: %v", err))
|
|
return
|
|
}
|
|
stepID, _ = result.LastInsertId()
|
|
} else {
|
|
// Update existing step
|
|
stepID = existingStepID.Int64
|
|
var startedAt *time.Time
|
|
var completedAt *time.Time
|
|
|
|
// Get existing started_at if status is running/completed/failed
|
|
if req.Status == string(types.StepStatusRunning) || req.Status == string(types.StepStatusCompleted) || req.Status == string(types.StepStatusFailed) {
|
|
var existingStartedAt sql.NullTime
|
|
s.db.QueryRow(`SELECT started_at FROM task_steps WHERE id = ?`, stepID).Scan(&existingStartedAt)
|
|
if existingStartedAt.Valid {
|
|
startedAt = &existingStartedAt.Time
|
|
} else {
|
|
startedAt = &now
|
|
}
|
|
}
|
|
|
|
if req.Status == string(types.StepStatusCompleted) || req.Status == string(types.StepStatusFailed) {
|
|
completedAt = &now
|
|
}
|
|
|
|
_, err = s.db.Exec(
|
|
`UPDATE task_steps SET status = ?, started_at = ?, completed_at = ?, duration_ms = ?, error_message = ?
|
|
WHERE id = ?`,
|
|
req.Status, startedAt, completedAt, req.DurationMs, req.ErrorMessage, stepID,
|
|
)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to update step: %v", err))
|
|
return
|
|
}
|
|
}
|
|
|
|
s.respondJSON(w, http.StatusOK, map[string]interface{}{
|
|
"step_id": stepID,
|
|
"message": "Step updated successfully",
|
|
})
|
|
}
|
|
|
|
// handleDownloadFileForRunner allows runners to download job files
|
|
func (s *Server) handleDownloadFileForRunner(w http.ResponseWriter, r *http.Request) {
|
|
jobID, err := parseID(r, "jobId")
|
|
if err != nil {
|
|
s.respondError(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
|
|
fileName := chi.URLParam(r, "fileName")
|
|
|
|
// Find the file in the database
|
|
var filePath string
|
|
err = s.db.QueryRow(
|
|
`SELECT file_path FROM job_files WHERE job_id = ? AND file_name = ?`,
|
|
jobID, fileName,
|
|
).Scan(&filePath)
|
|
if err == sql.ErrNoRows {
|
|
s.respondError(w, http.StatusNotFound, "File not found")
|
|
return
|
|
}
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query file: %v", err))
|
|
return
|
|
}
|
|
|
|
// Open and serve file
|
|
file, err := s.storage.GetFile(filePath)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusNotFound, "File not found on disk")
|
|
return
|
|
}
|
|
defer file.Close()
|
|
|
|
w.Header().Set("Content-Type", "application/octet-stream")
|
|
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", fileName))
|
|
io.Copy(w, file)
|
|
}
|
|
|
|
// handleUploadFileFromRunner allows runners to upload output files
|
|
func (s *Server) handleUploadFileFromRunner(w http.ResponseWriter, r *http.Request) {
|
|
jobID, err := parseID(r, "jobId")
|
|
if err != nil {
|
|
s.respondError(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
|
|
err = r.ParseMultipartForm(100 << 20) // 100 MB
|
|
if err != nil {
|
|
s.respondError(w, http.StatusBadRequest, "Failed to parse form")
|
|
return
|
|
}
|
|
|
|
file, header, err := r.FormFile("file")
|
|
if err != nil {
|
|
s.respondError(w, http.StatusBadRequest, "No file provided")
|
|
return
|
|
}
|
|
defer file.Close()
|
|
|
|
// Save file
|
|
filePath, err := s.storage.SaveOutput(jobID, header.Filename, file)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to save file: %v", err))
|
|
return
|
|
}
|
|
|
|
// Record in database
|
|
_, err = s.db.Exec(
|
|
`INSERT INTO job_files (job_id, file_type, file_path, file_name, file_size)
|
|
VALUES (?, ?, ?, ?, ?)`,
|
|
jobID, types.JobFileTypeOutput, filePath, header.Filename, header.Size,
|
|
)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to record file: %v", err))
|
|
return
|
|
}
|
|
|
|
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
|
|
"file_path": filePath,
|
|
"file_name": header.Filename,
|
|
})
|
|
}
|
|
|
|
// generateMP4Video generates MP4 video from PNG frames for a completed job
|
|
func (s *Server) generateMP4Video(jobID int64) {
|
|
// This would be called by a runner or external process
|
|
// For now, we'll create a special task that runners can pick up
|
|
// In a production system, you might want to use a job queue or have a dedicated video processor
|
|
|
|
// Get all PNG output files for this job
|
|
rows, err := s.db.Query(
|
|
`SELECT file_path, file_name FROM job_files
|
|
WHERE job_id = ? AND file_type = ? AND file_name LIKE '%.png'
|
|
ORDER BY file_name`,
|
|
jobID, types.JobFileTypeOutput,
|
|
)
|
|
if err != nil {
|
|
log.Printf("Failed to query PNG files for job %d: %v", jobID, err)
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
var pngFiles []string
|
|
for rows.Next() {
|
|
var filePath, fileName string
|
|
if err := rows.Scan(&filePath, &fileName); err == nil {
|
|
pngFiles = append(pngFiles, filePath)
|
|
}
|
|
}
|
|
|
|
if len(pngFiles) == 0 {
|
|
log.Printf("No PNG files found for job %d", jobID)
|
|
return
|
|
}
|
|
|
|
// Note: Video generation will be handled by runners when they complete tasks
|
|
// Runners can check job status and generate MP4 when all frames are complete
|
|
log.Printf("Job %d completed with %d PNG frames - ready for MP4 generation", jobID, len(pngFiles))
|
|
}
|
|
|
|
// handleGetJobStatusForRunner allows runners to check job status
|
|
func (s *Server) handleGetJobStatusForRunner(w http.ResponseWriter, r *http.Request) {
|
|
jobID, err := parseID(r, "jobId")
|
|
if err != nil {
|
|
s.respondError(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
|
|
var job types.Job
|
|
var startedAt, completedAt sql.NullTime
|
|
|
|
err = s.db.QueryRow(
|
|
`SELECT id, user_id, name, status, progress, frame_start, frame_end, output_format,
|
|
allow_parallel_runners, created_at, started_at, completed_at, error_message
|
|
FROM jobs WHERE id = ?`,
|
|
jobID,
|
|
).Scan(
|
|
&job.ID, &job.UserID, &job.Name, &job.Status, &job.Progress,
|
|
&job.FrameStart, &job.FrameEnd, &job.OutputFormat, &job.AllowParallelRunners,
|
|
&job.CreatedAt, &startedAt, &completedAt, &job.ErrorMessage,
|
|
)
|
|
|
|
if err == sql.ErrNoRows {
|
|
s.respondError(w, http.StatusNotFound, "Job not found")
|
|
return
|
|
}
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query job: %v", err))
|
|
return
|
|
}
|
|
|
|
if startedAt.Valid {
|
|
job.StartedAt = &startedAt.Time
|
|
}
|
|
if completedAt.Valid {
|
|
job.CompletedAt = &completedAt.Time
|
|
}
|
|
|
|
s.respondJSON(w, http.StatusOK, job)
|
|
}
|
|
|
|
// handleGetJobFilesForRunner allows runners to get job files
|
|
func (s *Server) handleGetJobFilesForRunner(w http.ResponseWriter, r *http.Request) {
|
|
jobID, err := parseID(r, "jobId")
|
|
if err != nil {
|
|
s.respondError(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
|
|
rows, err := s.db.Query(
|
|
`SELECT id, job_id, file_type, file_path, file_name, file_size, created_at
|
|
FROM job_files WHERE job_id = ? ORDER BY file_name`,
|
|
jobID,
|
|
)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query files: %v", err))
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
files := []types.JobFile{}
|
|
for rows.Next() {
|
|
var file types.JobFile
|
|
err := rows.Scan(
|
|
&file.ID, &file.JobID, &file.FileType, &file.FilePath,
|
|
&file.FileName, &file.FileSize, &file.CreatedAt,
|
|
)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan file: %v", err))
|
|
return
|
|
}
|
|
files = append(files, file)
|
|
}
|
|
|
|
s.respondJSON(w, http.StatusOK, files)
|
|
}
|
|
|
|
// WSMessage is the envelope of every message exchanged with a runner over the
// WebSocket connection.
type WSMessage struct {
	// Type selects how Data is interpreted. Messages read from runners in
	// this file use "heartbeat", "log_entry", "task_update" and
	// "task_complete"; the server sends "task_assignment".
	Type string `json:"type"`
	// Data is the still-encoded JSON payload; it is decoded according to Type.
	Data json.RawMessage `json:"data"`
	// Timestamp is set to Unix seconds by the server when it sends a message.
	Timestamp int64 `json:"timestamp"`
}
|
|
|
|
// WSTaskAssignment is the payload of a "task_assignment" message sent to a
// runner. It is filled in by assignTaskToRunner from the tasks, jobs and
// job_files tables.
type WSTaskAssignment struct {
	TaskID       int64  `json:"task_id"`
	JobID        int64  `json:"job_id"`
	JobName      string `json:"job_name"`      // jobs.name
	OutputFormat string `json:"output_format"` // jobs.output_format
	FrameStart   int    `json:"frame_start"`   // tasks.frame_start
	FrameEnd     int    `json:"frame_end"`     // tasks.frame_end
	TaskType     string `json:"task_type"`     // tasks.task_type
	// InputFiles holds the file_path of each input file of the job.
	InputFiles []string `json:"input_files"`
}
|
|
|
|
// WSLogEntry is the payload of a "log_entry" message received from a runner.
// handleWebSocketLog stores it in task_logs and forwards it to frontend
// clients.
type WSLogEntry struct {
	TaskID   int64  `json:"task_id"`
	LogLevel string `json:"log_level"`
	Message  string `json:"message"`
	// StepName is optional and omitted from the JSON when empty.
	StepName string `json:"step_name,omitempty"`
}
|
|
|
|
// WSTaskUpdate is the payload of the "task_update" and "task_complete"
// messages received from a runner.
type WSTaskUpdate struct {
	TaskID int64 `json:"task_id"`
	// Status is free-form text; handleWebSocketTaskUpdate only logs it.
	Status string `json:"status"`
	// OutputPath is written to tasks.output_path on completion.
	OutputPath string `json:"output_path,omitempty"`
	// Success decides whether a completed task is stored as completed (true)
	// or failed (false).
	Success bool `json:"success"`
	// Error is written to tasks.error_message on completion.
	Error string `json:"error,omitempty"`
}
|
|
|
|
// handleRunnerWebSocket handles WebSocket connections from runners
|
|
func (s *Server) handleRunnerWebSocket(w http.ResponseWriter, r *http.Request) {
|
|
// Get runner ID and signature from query params
|
|
runnerIDStr := r.URL.Query().Get("runner_id")
|
|
signature := r.URL.Query().Get("signature")
|
|
timestampStr := r.URL.Query().Get("timestamp")
|
|
|
|
if runnerIDStr == "" || signature == "" || timestampStr == "" {
|
|
s.respondError(w, http.StatusBadRequest, "runner_id, signature, and timestamp required")
|
|
return
|
|
}
|
|
|
|
var runnerID int64
|
|
_, err := fmt.Sscanf(runnerIDStr, "%d", &runnerID)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusBadRequest, "invalid runner_id")
|
|
return
|
|
}
|
|
|
|
// Get runner secret
|
|
runnerSecret, err := s.secrets.GetRunnerSecret(runnerID)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusUnauthorized, "runner not found or not verified")
|
|
return
|
|
}
|
|
|
|
// Verify signature
|
|
var timestamp int64
|
|
_, err = fmt.Sscanf(timestampStr, "%d", ×tamp)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusBadRequest, "invalid timestamp")
|
|
return
|
|
}
|
|
|
|
// Verify signature manually (similar to HTTP auth)
|
|
timestampTime := time.Unix(timestamp, 0)
|
|
|
|
// Check timestamp is not too old
|
|
if time.Since(timestampTime) > 5*time.Minute {
|
|
s.respondError(w, http.StatusUnauthorized, "timestamp too old")
|
|
return
|
|
}
|
|
|
|
// Check timestamp is not in the future (allow 1 minute clock skew)
|
|
if timestampTime.After(time.Now().Add(1 * time.Minute)) {
|
|
s.respondError(w, http.StatusUnauthorized, "timestamp in future")
|
|
return
|
|
}
|
|
|
|
// Build the message that should be signed
|
|
path := r.URL.Path
|
|
expectedSig := auth.SignRequest("GET", path, "", runnerSecret, timestampTime)
|
|
|
|
// Compare signatures (constant-time)
|
|
if subtle.ConstantTimeCompare([]byte(signature), []byte(expectedSig)) != 1 {
|
|
s.respondError(w, http.StatusUnauthorized, "invalid signature")
|
|
return
|
|
}
|
|
|
|
// Upgrade to WebSocket
|
|
conn, err := s.wsUpgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
log.Printf("Failed to upgrade WebSocket: %v", err)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
// Register connection
|
|
s.runnerConnsMu.Lock()
|
|
// Remove old connection if exists
|
|
if oldConn, exists := s.runnerConns[runnerID]; exists {
|
|
oldConn.Close()
|
|
}
|
|
s.runnerConns[runnerID] = conn
|
|
s.runnerConnsMu.Unlock()
|
|
|
|
// Update runner status to online
|
|
_, _ = s.db.Exec(
|
|
`UPDATE runners SET status = ?, last_heartbeat = ? WHERE id = ?`,
|
|
types.RunnerStatusOnline, time.Now(), runnerID,
|
|
)
|
|
|
|
// Cleanup on disconnect
|
|
defer func() {
|
|
s.runnerConnsMu.Lock()
|
|
delete(s.runnerConns, runnerID)
|
|
s.runnerConnsMu.Unlock()
|
|
_, _ = s.db.Exec(
|
|
`UPDATE runners SET status = ? WHERE id = ?`,
|
|
types.RunnerStatusOffline, runnerID,
|
|
)
|
|
}()
|
|
|
|
// Set ping handler to update heartbeat
|
|
conn.SetPongHandler(func(string) error {
|
|
_, _ = s.db.Exec(
|
|
`UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`,
|
|
time.Now(), types.RunnerStatusOnline, runnerID,
|
|
)
|
|
return nil
|
|
})
|
|
|
|
// Send ping every 30 seconds
|
|
go func() {
|
|
ticker := time.NewTicker(30 * time.Second)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
s.runnerConnsMu.RLock()
|
|
conn, exists := s.runnerConns[runnerID]
|
|
s.runnerConnsMu.RUnlock()
|
|
if !exists {
|
|
return
|
|
}
|
|
if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Handle incoming messages
|
|
for {
|
|
var msg WSMessage
|
|
err := conn.ReadJSON(&msg)
|
|
if err != nil {
|
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
|
log.Printf("WebSocket error for runner %d: %v", runnerID, err)
|
|
}
|
|
break
|
|
}
|
|
|
|
switch msg.Type {
|
|
case "heartbeat":
|
|
// Update heartbeat
|
|
_, _ = s.db.Exec(
|
|
`UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`,
|
|
time.Now(), types.RunnerStatusOnline, runnerID,
|
|
)
|
|
|
|
case "log_entry":
|
|
var logEntry WSLogEntry
|
|
if err := json.Unmarshal(msg.Data, &logEntry); err == nil {
|
|
s.handleWebSocketLog(runnerID, logEntry)
|
|
}
|
|
|
|
case "task_update":
|
|
var taskUpdate WSTaskUpdate
|
|
if err := json.Unmarshal(msg.Data, &taskUpdate); err == nil {
|
|
s.handleWebSocketTaskUpdate(runnerID, taskUpdate)
|
|
}
|
|
|
|
case "task_complete":
|
|
var taskUpdate WSTaskUpdate
|
|
if err := json.Unmarshal(msg.Data, &taskUpdate); err == nil {
|
|
s.handleWebSocketTaskComplete(runnerID, taskUpdate)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// handleWebSocketLog handles log entries from WebSocket
|
|
func (s *Server) handleWebSocketLog(runnerID int64, logEntry WSLogEntry) {
|
|
// Store log in database
|
|
_, err := s.db.Exec(
|
|
`INSERT INTO task_logs (task_id, runner_id, log_level, message, step_name, created_at)
|
|
VALUES (?, ?, ?, ?, ?, ?)`,
|
|
logEntry.TaskID, runnerID, logEntry.LogLevel, logEntry.Message, logEntry.StepName, time.Now(),
|
|
)
|
|
if err != nil {
|
|
log.Printf("Failed to store log: %v", err)
|
|
return
|
|
}
|
|
|
|
// Broadcast to frontend clients
|
|
s.broadcastLogToFrontend(logEntry.TaskID, logEntry)
|
|
}
|
|
|
|
// handleWebSocketTaskUpdate handles task status updates from WebSocket
|
|
func (s *Server) handleWebSocketTaskUpdate(runnerID int64, taskUpdate WSTaskUpdate) {
|
|
// This can be used for progress updates
|
|
// For now, we'll just log it
|
|
log.Printf("Task %d update from runner %d: %s", taskUpdate.TaskID, runnerID, taskUpdate.Status)
|
|
}
|
|
|
|
// handleWebSocketTaskComplete handles task completion from WebSocket
|
|
func (s *Server) handleWebSocketTaskComplete(runnerID int64, taskUpdate WSTaskUpdate) {
|
|
// Verify task belongs to runner
|
|
var taskRunnerID sql.NullInt64
|
|
err := s.db.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", taskUpdate.TaskID).Scan(&taskRunnerID)
|
|
if err != nil || !taskRunnerID.Valid || taskRunnerID.Int64 != runnerID {
|
|
log.Printf("Task %d does not belong to runner %d", taskUpdate.TaskID, runnerID)
|
|
return
|
|
}
|
|
|
|
status := types.TaskStatusCompleted
|
|
if !taskUpdate.Success {
|
|
status = types.TaskStatusFailed
|
|
}
|
|
|
|
now := time.Now()
|
|
_, err = s.db.Exec(
|
|
`UPDATE tasks SET status = ?, output_path = ?, completed_at = ?, error_message = ? WHERE id = ?`,
|
|
status, taskUpdate.OutputPath, now, taskUpdate.Error, taskUpdate.TaskID,
|
|
)
|
|
if err != nil {
|
|
log.Printf("Failed to update task: %v", err)
|
|
return
|
|
}
|
|
|
|
// Update job progress
|
|
var jobID int64
|
|
var frameStart, frameEnd int
|
|
err = s.db.QueryRow(
|
|
`SELECT job_id, frame_start, frame_end FROM tasks WHERE id = ?`,
|
|
taskUpdate.TaskID,
|
|
).Scan(&jobID, &frameStart, &frameEnd)
|
|
if err == nil {
|
|
// Count total tasks excluding failed ones (failed tasks are retried, so we count them)
|
|
// We exclude tasks that are in a terminal failed state with max retries exceeded
|
|
var totalTasks, completedTasks int
|
|
s.db.QueryRow(
|
|
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status IN (?, ?, ?, ?)`,
|
|
jobID, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted, types.TaskStatusFailed,
|
|
).Scan(&totalTasks)
|
|
s.db.QueryRow(
|
|
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
|
|
jobID, types.TaskStatusCompleted,
|
|
).Scan(&completedTasks)
|
|
|
|
// Handle edge cases: division by zero and all tasks cancelled
|
|
var progress float64
|
|
if totalTasks == 0 {
|
|
// All tasks cancelled or no tasks, set progress to 0
|
|
progress = 0.0
|
|
} else {
|
|
progress = float64(completedTasks) / float64(totalTasks) * 100.0
|
|
}
|
|
|
|
var jobStatus string
|
|
var outputFormat string
|
|
s.db.QueryRow(`SELECT output_format FROM jobs WHERE id = ?`, jobID).Scan(&outputFormat)
|
|
|
|
// Check if all non-cancelled tasks are completed
|
|
var pendingOrRunningTasks int
|
|
s.db.QueryRow(
|
|
`SELECT COUNT(*) FROM tasks
|
|
WHERE job_id = ? AND status IN (?, ?)`,
|
|
jobID, types.TaskStatusPending, types.TaskStatusRunning,
|
|
).Scan(&pendingOrRunningTasks)
|
|
|
|
if pendingOrRunningTasks == 0 && totalTasks > 0 {
|
|
// All tasks are either completed or failed/cancelled
|
|
jobStatus = string(types.JobStatusCompleted)
|
|
s.db.Exec(
|
|
`UPDATE jobs SET status = ?, progress = ?, completed_at = ? WHERE id = ?`,
|
|
jobStatus, progress, now, jobID,
|
|
)
|
|
|
|
if outputFormat == "MP4" {
|
|
// Create a video generation task instead of calling generateMP4Video directly
|
|
// This prevents race conditions when multiple runners complete frames simultaneously
|
|
videoTaskTimeout := 86400 // 24 hours for video generation
|
|
_, err := s.db.Exec(
|
|
`INSERT INTO tasks (job_id, frame_start, frame_end, task_type, status, timeout_seconds, max_retries)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
|
jobID, 0, 0, types.TaskTypeVideoGeneration, types.TaskStatusPending, videoTaskTimeout, 1,
|
|
)
|
|
if err != nil {
|
|
log.Printf("Failed to create video generation task for job %d: %v", jobID, err)
|
|
} else {
|
|
// Try to distribute the task immediately
|
|
go s.distributeTasksToRunners()
|
|
}
|
|
}
|
|
} else {
|
|
jobStatus = string(types.JobStatusRunning)
|
|
var startedAt sql.NullTime
|
|
s.db.QueryRow(`SELECT started_at FROM jobs WHERE id = ?`, jobID).Scan(&startedAt)
|
|
if !startedAt.Valid {
|
|
s.db.Exec(`UPDATE jobs SET started_at = ? WHERE id = ?`, now, jobID)
|
|
}
|
|
s.db.Exec(
|
|
`UPDATE jobs SET status = ?, progress = ? WHERE id = ?`,
|
|
jobStatus, progress, jobID,
|
|
)
|
|
}
|
|
}
|
|
}
|
|
|
|
// broadcastLogToFrontend broadcasts log to connected frontend clients
|
|
func (s *Server) broadcastLogToFrontend(taskID int64, logEntry WSLogEntry) {
|
|
// Get job_id from task
|
|
var jobID int64
|
|
err := s.db.QueryRow("SELECT job_id FROM tasks WHERE id = ?", taskID).Scan(&jobID)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
key := fmt.Sprintf("%d:%d", jobID, taskID)
|
|
s.frontendConnsMu.RLock()
|
|
conn, exists := s.frontendConns[key]
|
|
s.frontendConnsMu.RUnlock()
|
|
|
|
if exists && conn != nil {
|
|
// Get full log entry from database for consistency
|
|
var log types.TaskLog
|
|
var runnerID sql.NullInt64
|
|
err := s.db.QueryRow(
|
|
`SELECT id, task_id, runner_id, log_level, message, step_name, created_at
|
|
FROM task_logs WHERE task_id = ? AND message = ? ORDER BY id DESC LIMIT 1`,
|
|
taskID, logEntry.Message,
|
|
).Scan(&log.ID, &log.TaskID, &runnerID, &log.LogLevel, &log.Message, &log.StepName, &log.CreatedAt)
|
|
if err == nil {
|
|
if runnerID.Valid {
|
|
log.RunnerID = &runnerID.Int64
|
|
}
|
|
msg := map[string]interface{}{
|
|
"type": "log",
|
|
"data": log,
|
|
"timestamp": time.Now().Unix(),
|
|
}
|
|
conn.WriteJSON(msg)
|
|
}
|
|
}
|
|
}
|
|
|
|
// distributeTasksToRunners pushes available tasks to connected runners
|
|
func (s *Server) distributeTasksToRunners() {
|
|
// Get all pending tasks
|
|
rows, err := s.db.Query(
|
|
`SELECT t.id, t.job_id, t.frame_start, t.frame_end, t.task_type, j.allow_parallel_runners, j.status as job_status
|
|
FROM tasks t
|
|
JOIN jobs j ON t.job_id = j.id
|
|
WHERE t.status = ? AND j.status != ?
|
|
ORDER BY t.created_at ASC
|
|
LIMIT 100`,
|
|
types.TaskStatusPending, types.JobStatusCancelled,
|
|
)
|
|
if err != nil {
|
|
log.Printf("Failed to query pending tasks: %v", err)
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
var pendingTasks []struct {
|
|
TaskID int64
|
|
JobID int64
|
|
FrameStart int
|
|
FrameEnd int
|
|
TaskType string
|
|
AllowParallelRunners bool
|
|
}
|
|
|
|
for rows.Next() {
|
|
var t struct {
|
|
TaskID int64
|
|
JobID int64
|
|
FrameStart int
|
|
FrameEnd int
|
|
TaskType string
|
|
AllowParallelRunners bool
|
|
}
|
|
var allowParallel int
|
|
var jobStatus string
|
|
err := rows.Scan(&t.TaskID, &t.JobID, &t.FrameStart, &t.FrameEnd, &t.TaskType, &allowParallel, &jobStatus)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
t.AllowParallelRunners = allowParallel == 1
|
|
pendingTasks = append(pendingTasks, t)
|
|
}
|
|
|
|
if len(pendingTasks) == 0 {
|
|
return
|
|
}
|
|
|
|
// Get connected runners
|
|
s.runnerConnsMu.RLock()
|
|
connectedRunners := make([]int64, 0, len(s.runnerConns))
|
|
for runnerID := range s.runnerConns {
|
|
connectedRunners = append(connectedRunners, runnerID)
|
|
}
|
|
s.runnerConnsMu.RUnlock()
|
|
|
|
if len(connectedRunners) == 0 {
|
|
return
|
|
}
|
|
|
|
// Distribute tasks to runners
|
|
for _, task := range pendingTasks {
|
|
// Find available runner
|
|
var selectedRunnerID int64
|
|
for _, runnerID := range connectedRunners {
|
|
// Check if runner is busy (has running tasks)
|
|
var runningCount int
|
|
s.db.QueryRow(
|
|
`SELECT COUNT(*) FROM tasks WHERE runner_id = ? AND status = ?`,
|
|
runnerID, types.TaskStatusRunning,
|
|
).Scan(&runningCount)
|
|
|
|
if runningCount > 0 {
|
|
continue // Runner is busy
|
|
}
|
|
|
|
// For non-parallel jobs, check if runner already has tasks from this job
|
|
if !task.AllowParallelRunners {
|
|
var jobTaskCount int
|
|
s.db.QueryRow(
|
|
`SELECT COUNT(*) FROM tasks
|
|
WHERE job_id = ? AND runner_id = ? AND status IN (?, ?)`,
|
|
task.JobID, runnerID, types.TaskStatusPending, types.TaskStatusRunning,
|
|
).Scan(&jobTaskCount)
|
|
if jobTaskCount > 0 {
|
|
continue // Another runner is working on this job
|
|
}
|
|
}
|
|
|
|
selectedRunnerID = runnerID
|
|
break
|
|
}
|
|
|
|
if selectedRunnerID == 0 {
|
|
continue // No available runner
|
|
}
|
|
|
|
// Atomically assign task to runner using UPDATE with WHERE runner_id IS NULL
|
|
// This prevents race conditions when multiple goroutines try to assign the same task
|
|
now := time.Now()
|
|
result, err := s.db.Exec(
|
|
`UPDATE tasks SET runner_id = ?, status = ?, started_at = ?
|
|
WHERE id = ? AND runner_id IS NULL AND status = ?`,
|
|
selectedRunnerID, types.TaskStatusRunning, now, task.TaskID, types.TaskStatusPending,
|
|
)
|
|
if err != nil {
|
|
log.Printf("Failed to atomically assign task %d: %v", task.TaskID, err)
|
|
continue
|
|
}
|
|
|
|
// Check if the update actually affected a row (task was successfully assigned)
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
log.Printf("Failed to get rows affected for task %d: %v", task.TaskID, err)
|
|
continue
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
// Task was already assigned by another goroutine, skip
|
|
continue
|
|
}
|
|
|
|
// Task was successfully assigned, send via WebSocket
|
|
if err := s.assignTaskToRunner(selectedRunnerID, task.TaskID); err != nil {
|
|
log.Printf("Failed to send task %d to runner %d: %v", task.TaskID, selectedRunnerID, err)
|
|
// Rollback the assignment if WebSocket send fails
|
|
s.db.Exec(
|
|
`UPDATE tasks SET runner_id = NULL, status = ?, started_at = NULL
|
|
WHERE id = ?`,
|
|
types.TaskStatusPending, task.TaskID,
|
|
)
|
|
}
|
|
}
|
|
}
|
|
|
|
// assignTaskToRunner sends a task to a runner via WebSocket
|
|
func (s *Server) assignTaskToRunner(runnerID int64, taskID int64) error {
|
|
s.runnerConnsMu.RLock()
|
|
conn, exists := s.runnerConns[runnerID]
|
|
s.runnerConnsMu.RUnlock()
|
|
|
|
if !exists {
|
|
return fmt.Errorf("runner %d not connected", runnerID)
|
|
}
|
|
|
|
// Get task details
|
|
var task WSTaskAssignment
|
|
var jobName, outputFormat, taskType string
|
|
err := s.db.QueryRow(
|
|
`SELECT t.job_id, t.frame_start, t.frame_end, t.task_type, j.name, j.output_format
|
|
FROM tasks t JOIN jobs j ON t.job_id = j.id WHERE t.id = ?`,
|
|
taskID,
|
|
).Scan(&task.JobID, &task.FrameStart, &task.FrameEnd, &taskType, &jobName, &outputFormat)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
task.TaskID = taskID
|
|
task.JobName = jobName
|
|
task.OutputFormat = outputFormat
|
|
task.TaskType = taskType
|
|
|
|
// Get input files
|
|
rows, err := s.db.Query(
|
|
`SELECT file_path FROM job_files WHERE job_id = ? AND file_type = ?`,
|
|
task.JobID, types.JobFileTypeInput,
|
|
)
|
|
if err == nil {
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
var filePath string
|
|
if err := rows.Scan(&filePath); err == nil {
|
|
task.InputFiles = append(task.InputFiles, filePath)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Note: Task is already assigned in database by the atomic update in distributeTasksToRunners
|
|
// We just need to verify it's still assigned to this runner
|
|
var assignedRunnerID sql.NullInt64
|
|
err = s.db.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", taskID).Scan(&assignedRunnerID)
|
|
if err != nil {
|
|
return fmt.Errorf("task not found: %w", err)
|
|
}
|
|
if !assignedRunnerID.Valid || assignedRunnerID.Int64 != runnerID {
|
|
return fmt.Errorf("task %d is not assigned to runner %d", taskID, runnerID)
|
|
}
|
|
|
|
// Send task via WebSocket
|
|
msg := WSMessage{
|
|
Type: "task_assignment",
|
|
Timestamp: time.Now().Unix(),
|
|
}
|
|
msg.Data, _ = json.Marshal(task)
|
|
return conn.WriteJSON(msg)
|
|
}
|