redo
@@ -1,19 +1,23 @@
package api

import (
	"database/sql"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"strconv"
	"sync"
	"time"

	"github.com/go-chi/chi/v5"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/go-chi/cors"
	"github.com/gorilla/websocket"
	"fuego/internal/auth"
	"fuego/internal/database"
	"fuego/internal/storage"
	"fuego/pkg/types"
)

// Server represents the API server
@@ -23,6 +27,13 @@ type Server struct {
	secrets *auth.Secrets
	storage *storage.Storage
	router  *chi.Mux

	// WebSocket connections
	wsUpgrader      websocket.Upgrader
	runnerConns     map[int64]*websocket.Conn
	runnerConnsMu   sync.RWMutex
	frontendConns   map[string]*websocket.Conn // key: "jobId:taskId"
	frontendConnsMu sync.RWMutex
}

// NewServer creates a new API server
@@ -38,10 +49,20 @@ func NewServer(db *database.DB, auth *auth.Auth, storage *storage.Storage) (*Ser
		secrets: secrets,
		storage: storage,
		router:  chi.NewRouter(),
		wsUpgrader: websocket.Upgrader{
			CheckOrigin: func(r *http.Request) bool {
				return true // Allow all origins for now
			},
			ReadBufferSize:  1024,
			WriteBufferSize: 1024,
		},
		runnerConns:   make(map[int64]*websocket.Conn),
		frontendConns: make(map[string]*websocket.Conn),
	}

	s.setupMiddleware()
	s.setupRoutes()
	s.StartBackgroundTasks()

	return s, nil
}
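Note: the Upgrader and connection maps configured here are consumed by the WebSocket handlers registered in the route hunks below (handleRunnerWebSocket, handleStreamTaskLogsWebSocket), which are not part of this diff. A minimal sketch of the usual gorilla/websocket upgrade-and-track pattern they imply (handler body and runner-ID lookup are hypothetical):

func (s *Server) handleRunnerWebSocket(w http.ResponseWriter, r *http.Request) {
	conn, err := s.wsUpgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Printf("WebSocket upgrade failed: %v", err)
		return
	}
	runnerID := int64(1) // hypothetical: resolved from the runner's credentials
	s.runnerConnsMu.Lock()
	s.runnerConns[runnerID] = conn
	s.runnerConnsMu.Unlock()
}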
@@ -87,6 +108,10 @@ func (s *Server) setupRoutes() {
		r.Get("/{id}/files", s.handleListJobFiles)
		r.Get("/{id}/files/{fileId}/download", s.handleDownloadJobFile)
		r.Get("/{id}/video", s.handleStreamVideo)
		r.Get("/{id}/tasks/{taskId}/logs", s.handleGetTaskLogs)
		r.Get("/{id}/tasks/{taskId}/logs/ws", s.handleStreamTaskLogsWebSocket)
		r.Get("/{id}/tasks/{taskId}/steps", s.handleGetTaskSteps)
		r.Post("/{id}/tasks/{taskId}/retry", s.handleRetryTask)
	})

	s.router.Route("/api/runners", func(r chi.Router) {
@@ -118,14 +143,14 @@ func (s *Server) setupRoutes() {
		// Registration doesn't require auth (uses token)
		r.Post("/register", s.handleRegisterRunner)

		// All other endpoints require runner authentication
		// WebSocket endpoint (auth handled in handler)
		r.Get("/ws", s.handleRunnerWebSocket)

		// File operations still use HTTP (WebSocket not suitable for large files)
		r.Group(func(r chi.Router) {
			r.Use(func(next http.Handler) http.Handler {
				return http.HandlerFunc(s.runnerAuthMiddleware(next.ServeHTTP))
			})
			r.Post("/heartbeat", s.handleRunnerHeartbeat)
			r.Get("/tasks", s.handleGetRunnerTasks)
			r.Post("/tasks/{id}/complete", s.handleCompleteTask)
			r.Post("/tasks/{id}/progress", s.handleUpdateTaskProgress)
			r.Get("/files/{jobId}/{fileName}", s.handleDownloadFileForRunner)
			r.Post("/files/{jobId}/upload", s.handleUploadFileFromRunner)
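Note: runnerAuthMiddleware itself is not shown in this diff. The r.Use adapter above only type-checks if it maps a plain handler func to a handler func, so its shape is presumably along these lines (the Authorization check is a hypothetical placeholder):

func (s *Server) runnerAuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Hypothetical gate; the real middleware presumably validates runner credentials.
		if r.Header.Get("Authorization") == "" {
			http.Error(w, "unauthorized", http.StatusUnauthorized)
			return
		}
		next(w, r)
	}
}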
@@ -282,3 +307,207 @@ func parseID(r *http.Request, param string) (int64, error) {
	return id, nil
}

// StartBackgroundTasks starts background goroutines for error recovery
func (s *Server) StartBackgroundTasks() {
	go s.recoverStuckTasks()
}

// recoverStuckTasks periodically checks for dead runners and stuck tasks
func (s *Server) recoverStuckTasks() {
	ticker := time.NewTicker(10 * time.Second)
	defer ticker.Stop()

	// Also distribute tasks every 5 seconds
	distributeTicker := time.NewTicker(5 * time.Second)
	defer distributeTicker.Stop()

	go func() {
		for range distributeTicker.C {
			s.distributeTasksToRunners()
		}
	}()

	for range ticker.C {
		func() {
			defer func() {
				if r := recover(); r != nil {
					log.Printf("Panic in recoverStuckTasks: %v", r)
				}
			}()

			// Find dead runners (no heartbeat for 90 seconds)
			rows, err := s.db.Query(
				`SELECT id FROM runners
				WHERE last_heartbeat < datetime('now', '-90 seconds')
				AND status = ?`,
				types.RunnerStatusOnline,
			)
			if err != nil {
				log.Printf("Failed to query dead runners: %v", err)
				return
			}
			defer rows.Close()

			var deadRunnerIDs []int64
			for rows.Next() {
				var runnerID int64
				if err := rows.Scan(&runnerID); err == nil {
					deadRunnerIDs = append(deadRunnerIDs, runnerID)
				}
			}
			rows.Close()

			if len(deadRunnerIDs) == 0 {
				// Check for task timeouts
				s.recoverTaskTimeouts()
				return
			}

			// Reset tasks assigned to dead runners
			for _, runnerID := range deadRunnerIDs {
				// Get tasks assigned to this runner
				taskRows, err := s.db.Query(
					`SELECT id, retry_count, max_retries FROM tasks
					WHERE runner_id = ? AND status = ?`,
					runnerID, types.TaskStatusRunning,
				)
				if err != nil {
					log.Printf("Failed to query tasks for runner %d: %v", runnerID, err)
					continue
				}

				var tasksToReset []struct {
					ID         int64
					RetryCount int
					MaxRetries int
				}

				for taskRows.Next() {
					var t struct {
						ID         int64
						RetryCount int
						MaxRetries int
					}
					if err := taskRows.Scan(&t.ID, &t.RetryCount, &t.MaxRetries); err == nil {
						tasksToReset = append(tasksToReset, t)
					}
				}
				taskRows.Close()

				// Reset or fail tasks
				for _, task := range tasksToReset {
					if task.RetryCount >= task.MaxRetries {
						// Mark as failed
						_, err = s.db.Exec(
							`UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL
							WHERE id = ?`,
							types.TaskStatusFailed, "Runner died, max retries exceeded", task.ID,
						)
						if err != nil {
							log.Printf("Failed to mark task %d as failed: %v", task.ID, err)
						}
					} else {
						// Reset to pending
						_, err = s.db.Exec(
							`UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL,
							retry_count = retry_count + 1 WHERE id = ?`,
							types.TaskStatusPending, task.ID,
						)
						if err != nil {
							log.Printf("Failed to reset task %d: %v", task.ID, err)
						} else {
							// Add log entry
							_, _ = s.db.Exec(
								`INSERT INTO task_logs (task_id, log_level, message, step_name, created_at)
								VALUES (?, ?, ?, ?, ?)`,
								task.ID, types.LogLevelWarn, fmt.Sprintf("Runner died, task reset (retry %d/%d)", task.RetryCount+1, task.MaxRetries),
								"", time.Now(),
							)
						}
					}
				}

				// Mark runner as offline
				_, _ = s.db.Exec(
					`UPDATE runners SET status = ? WHERE id = ?`,
					types.RunnerStatusOffline, runnerID,
				)
			}

			// Check for task timeouts
			s.recoverTaskTimeouts()

			// Distribute newly recovered tasks
			s.distributeTasksToRunners()
		}()
	}
}

// recoverTaskTimeouts handles tasks that have exceeded their timeout
func (s *Server) recoverTaskTimeouts() {
	// Find tasks running longer than their timeout
	rows, err := s.db.Query(
		`SELECT t.id, t.runner_id, t.retry_count, t.max_retries, t.timeout_seconds, t.started_at
		FROM tasks t
		WHERE t.status = ?
		AND t.started_at IS NOT NULL
		AND (t.timeout_seconds IS NULL OR
			datetime(t.started_at, '+' || t.timeout_seconds || ' seconds') < datetime('now'))`,
		types.TaskStatusRunning,
	)
	if err != nil {
		log.Printf("Failed to query timed out tasks: %v", err)
		return
	}
	defer rows.Close()

	for rows.Next() {
		var taskID int64
		var runnerID sql.NullInt64
		var retryCount, maxRetries int
		var timeoutSeconds sql.NullInt64
		var startedAt time.Time

		err := rows.Scan(&taskID, &runnerID, &retryCount, &maxRetries, &timeoutSeconds, &startedAt)
		if err != nil {
			continue
		}

		// Use default timeout if not set (5 minutes for frame tasks, 24 hours for FFmpeg)
		timeout := 300 // 5 minutes default
		if timeoutSeconds.Valid {
			timeout = int(timeoutSeconds.Int64)
		}

		// Check if actually timed out
		if time.Since(startedAt).Seconds() < float64(timeout) {
			continue
		}

		if retryCount >= maxRetries {
			// Mark as failed
			_, err = s.db.Exec(
				`UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL
				WHERE id = ?`,
				types.TaskStatusFailed, "Task timeout exceeded, max retries reached", taskID,
			)
		} else {
			// Reset to pending
			_, err = s.db.Exec(
				`UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL,
				retry_count = retry_count + 1 WHERE id = ?`,
				types.TaskStatusPending, taskID,
			)
			if err == nil {
				// Add log entry
				_, _ = s.db.Exec(
					`INSERT INTO task_logs (task_id, log_level, message, step_name, created_at)
					VALUES (?, ?, ?, ?, ?)`,
					taskID, types.LogLevelWarn, fmt.Sprintf("Task timeout exceeded, resetting (retry %d/%d)", retryCount+1, maxRetries),
					"", time.Now(),
				)
			}
		}
	}
}
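Note: the recovery loop above declares a runner dead after 90 seconds without a heartbeat. The heartbeat endpoint (handleRunnerHeartbeat) is not part of this diff; for healthy runners to stay untouched it presumably refreshes last_heartbeat on every call, roughly like this sketch (helper name hypothetical):

func (s *Server) touchRunnerHeartbeat(runnerID int64) error {
	// Keep last_heartbeat fresh so the 90-second dead-runner query skips this runner.
	_, err := s.db.Exec(
		`UPDATE runners SET last_heartbeat = datetime('now'), status = ? WHERE id = ?`,
		types.RunnerStatusOnline, runnerID,
	)
	return err
}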