package api import ( "database/sql" "encoding/json" "fmt" "log" "net/http" "strconv" "sync" "time" authpkg "fuego/internal/auth" "fuego/internal/database" "fuego/internal/storage" "fuego/pkg/types" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/cors" "github.com/gorilla/websocket" ) // Server represents the API server type Server struct { db *database.DB auth *authpkg.Auth secrets *authpkg.Secrets storage *storage.Storage router *chi.Mux // WebSocket connections wsUpgrader websocket.Upgrader runnerConns map[int64]*websocket.Conn runnerConnsMu sync.RWMutex frontendConns map[string]*websocket.Conn // key: "jobId:taskId" frontendConnsMu sync.RWMutex } // NewServer creates a new API server func NewServer(db *database.DB, auth *authpkg.Auth, storage *storage.Storage) (*Server, error) { secrets, err := authpkg.NewSecrets(db.DB) if err != nil { return nil, fmt.Errorf("failed to initialize secrets: %w", err) } s := &Server{ db: db, auth: auth, secrets: secrets, storage: storage, router: chi.NewRouter(), wsUpgrader: websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true // Allow all origins for now }, ReadBufferSize: 1024, WriteBufferSize: 1024, }, runnerConns: make(map[int64]*websocket.Conn), frontendConns: make(map[string]*websocket.Conn), } s.setupMiddleware() s.setupRoutes() s.StartBackgroundTasks() return s, nil } // setupMiddleware configures middleware func (s *Server) setupMiddleware() { s.router.Use(middleware.Logger) s.router.Use(middleware.Recoverer) s.router.Use(middleware.Timeout(60 * time.Second)) s.router.Use(cors.Handler(cors.Options{ AllowedOrigins: []string{"*"}, AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, AllowedHeaders: []string{"Accept", "Authorization", "Content-Type"}, ExposedHeaders: []string{"Link"}, AllowCredentials: true, MaxAge: 300, })) } // setupRoutes configures routes func (s *Server) setupRoutes() { // Public routes s.router.Route("/api/auth", func(r chi.Router) { r.Get("/google/login", s.handleGoogleLogin) r.Get("/google/callback", s.handleGoogleCallback) r.Get("/discord/login", s.handleDiscordLogin) r.Get("/discord/callback", s.handleDiscordCallback) r.Post("/logout", s.handleLogout) r.Get("/me", s.handleGetMe) }) // Protected routes s.router.Route("/api/jobs", func(r chi.Router) { r.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(s.auth.Middleware(next.ServeHTTP)) }) r.Post("/", s.handleCreateJob) r.Get("/", s.handleListJobs) r.Get("/{id}", s.handleGetJob) r.Delete("/{id}", s.handleCancelJob) r.Post("/{id}/upload", s.handleUploadJobFile) r.Get("/{id}/files", s.handleListJobFiles) r.Get("/{id}/files/{fileId}/download", s.handleDownloadJobFile) r.Get("/{id}/video", s.handleStreamVideo) r.Get("/{id}/metadata", s.handleGetJobMetadata) r.Get("/{id}/tasks", s.handleListJobTasks) r.Get("/{id}/tasks/{taskId}/logs", s.handleGetTaskLogs) r.Get("/{id}/tasks/{taskId}/logs/ws", s.handleStreamTaskLogsWebSocket) r.Get("/{id}/tasks/{taskId}/steps", s.handleGetTaskSteps) r.Post("/{id}/tasks/{taskId}/retry", s.handleRetryTask) }) s.router.Route("/api/runners", func(r chi.Router) { r.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(s.auth.Middleware(next.ServeHTTP)) }) r.Get("/", s.handleListRunners) }) // Admin routes s.router.Route("/api/admin", func(r chi.Router) { r.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(s.auth.AdminMiddleware(next.ServeHTTP)) }) r.Route("/runners", func(r chi.Router) { r.Route("/tokens", func(r chi.Router) { r.Post("/", s.handleGenerateRegistrationToken) r.Get("/", s.handleListRegistrationTokens) r.Delete("/{id}", s.handleRevokeRegistrationToken) }) r.Get("/", s.handleListRunnersAdmin) r.Post("/{id}/verify", s.handleVerifyRunner) r.Delete("/{id}", s.handleDeleteRunner) }) }) // Runner API s.router.Route("/api/runner", func(r chi.Router) { // Registration doesn't require auth (uses token) r.Post("/register", s.handleRegisterRunner) // WebSocket endpoint (auth handled in handler) r.Get("/ws", s.handleRunnerWebSocket) // File operations still use HTTP (WebSocket not suitable for large files) r.Group(func(r chi.Router) { r.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(s.runnerAuthMiddleware(next.ServeHTTP)) }) r.Post("/tasks/{id}/progress", s.handleUpdateTaskProgress) r.Post("/tasks/{id}/steps", s.handleUpdateTaskStep) r.Get("/files/{jobId}/{fileName}", s.handleDownloadFileForRunner) r.Post("/files/{jobId}/upload", s.handleUploadFileFromRunner) r.Get("/jobs/{jobId}/status", s.handleGetJobStatusForRunner) r.Get("/jobs/{jobId}/files", s.handleGetJobFilesForRunner) r.Post("/jobs/{jobId}/metadata", s.handleSubmitMetadata) }) }) // Serve static files (built React app) s.router.Handle("/*", http.FileServer(http.Dir("./web/dist"))) } // ServeHTTP implements http.Handler func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.router.ServeHTTP(w, r) } // JSON response helpers func (s *Server) respondJSON(w http.ResponseWriter, status int, data interface{}) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) if err := json.NewEncoder(w).Encode(data); err != nil { log.Printf("Failed to encode JSON response: %v", err) } } func (s *Server) respondError(w http.ResponseWriter, status int, message string) { s.respondJSON(w, status, map[string]string{"error": message}) } // Auth handlers func (s *Server) handleGoogleLogin(w http.ResponseWriter, r *http.Request) { url, err := s.auth.GoogleLoginURL() if err != nil { s.respondError(w, http.StatusInternalServerError, err.Error()) return } http.Redirect(w, r, url, http.StatusFound) } func (s *Server) handleGoogleCallback(w http.ResponseWriter, r *http.Request) { code := r.URL.Query().Get("code") if code == "" { s.respondError(w, http.StatusBadRequest, "Missing code parameter") return } session, err := s.auth.GoogleCallback(r.Context(), code) if err != nil { s.respondError(w, http.StatusInternalServerError, err.Error()) return } sessionID := s.auth.CreateSession(session) http.SetCookie(w, &http.Cookie{ Name: "session_id", Value: sessionID, Path: "/", MaxAge: 86400, HttpOnly: true, SameSite: http.SameSiteLaxMode, }) http.Redirect(w, r, "/", http.StatusFound) } func (s *Server) handleDiscordLogin(w http.ResponseWriter, r *http.Request) { url, err := s.auth.DiscordLoginURL() if err != nil { s.respondError(w, http.StatusInternalServerError, err.Error()) return } http.Redirect(w, r, url, http.StatusFound) } func (s *Server) handleDiscordCallback(w http.ResponseWriter, r *http.Request) { code := r.URL.Query().Get("code") if code == "" { s.respondError(w, http.StatusBadRequest, "Missing code parameter") return } session, err := s.auth.DiscordCallback(r.Context(), code) if err != nil { s.respondError(w, http.StatusInternalServerError, err.Error()) return } sessionID := s.auth.CreateSession(session) http.SetCookie(w, &http.Cookie{ Name: "session_id", Value: sessionID, Path: "/", MaxAge: 86400, HttpOnly: true, SameSite: http.SameSiteLaxMode, }) http.Redirect(w, r, "/", http.StatusFound) } func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) { cookie, err := r.Cookie("session_id") if err == nil { s.auth.DeleteSession(cookie.Value) } http.SetCookie(w, &http.Cookie{ Name: "session_id", Value: "", Path: "/", MaxAge: -1, HttpOnly: true, }) s.respondJSON(w, http.StatusOK, map[string]string{"message": "Logged out"}) } func (s *Server) handleGetMe(w http.ResponseWriter, r *http.Request) { cookie, err := r.Cookie("session_id") if err != nil { s.respondError(w, http.StatusUnauthorized, "Not authenticated") return } session, ok := s.auth.GetSession(cookie.Value) if !ok { s.respondError(w, http.StatusUnauthorized, "Invalid session") return } s.respondJSON(w, http.StatusOK, map[string]interface{}{ "id": session.UserID, "email": session.Email, "name": session.Name, "is_admin": session.IsAdmin, }) } // Helper to get user ID from context func getUserID(r *http.Request) (int64, error) { userID, ok := authpkg.GetUserID(r.Context()) if !ok { return 0, fmt.Errorf("user ID not found in context") } return userID, nil } // Helper to parse ID from URL func parseID(r *http.Request, param string) (int64, error) { idStr := chi.URLParam(r, param) id, err := strconv.ParseInt(idStr, 10, 64) if err != nil { return 0, fmt.Errorf("invalid ID: %s", idStr) } return id, nil } // StartBackgroundTasks starts background goroutines for error recovery func (s *Server) StartBackgroundTasks() { go s.recoverStuckTasks() } // recoverStuckTasks periodically checks for dead runners and stuck tasks func (s *Server) recoverStuckTasks() { ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() // Also distribute tasks every 5 seconds distributeTicker := time.NewTicker(5 * time.Second) defer distributeTicker.Stop() go func() { for range distributeTicker.C { s.distributeTasksToRunners() } }() for range ticker.C { func() { defer func() { if r := recover(); r != nil { log.Printf("Panic in recoverStuckTasks: %v", r) } }() // Find dead runners (no heartbeat for 90 seconds) rows, err := s.db.Query( `SELECT id FROM runners WHERE last_heartbeat < datetime('now', '-90 seconds') AND status = ?`, types.RunnerStatusOnline, ) if err != nil { log.Printf("Failed to query dead runners: %v", err) return } defer rows.Close() var deadRunnerIDs []int64 for rows.Next() { var runnerID int64 if err := rows.Scan(&runnerID); err == nil { deadRunnerIDs = append(deadRunnerIDs, runnerID) } } rows.Close() if len(deadRunnerIDs) == 0 { // Check for task timeouts s.recoverTaskTimeouts() return } // Reset tasks assigned to dead runners for _, runnerID := range deadRunnerIDs { // Get tasks assigned to this runner taskRows, err := s.db.Query( `SELECT id, retry_count, max_retries FROM tasks WHERE runner_id = ? AND status = ?`, runnerID, types.TaskStatusRunning, ) if err != nil { log.Printf("Failed to query tasks for runner %d: %v", runnerID, err) continue } var tasksToReset []struct { ID int64 RetryCount int MaxRetries int } for taskRows.Next() { var t struct { ID int64 RetryCount int MaxRetries int } if err := taskRows.Scan(&t.ID, &t.RetryCount, &t.MaxRetries); err == nil { tasksToReset = append(tasksToReset, t) } } taskRows.Close() // Reset or fail tasks for _, task := range tasksToReset { if task.RetryCount >= task.MaxRetries { // Mark as failed _, err = s.db.Exec( `UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL WHERE id = ?`, types.TaskStatusFailed, "Runner died, max retries exceeded", task.ID, ) if err != nil { log.Printf("Failed to mark task %d as failed: %v", task.ID, err) } } else { // Reset to pending _, err = s.db.Exec( `UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL, retry_count = retry_count + 1 WHERE id = ?`, types.TaskStatusPending, task.ID, ) if err != nil { log.Printf("Failed to reset task %d: %v", task.ID, err) } else { // Add log entry _, _ = s.db.Exec( `INSERT INTO task_logs (task_id, log_level, message, step_name, created_at) VALUES (?, ?, ?, ?, ?)`, task.ID, types.LogLevelWarn, fmt.Sprintf("Runner died, task reset (retry %d/%d)", task.RetryCount+1, task.MaxRetries), "", time.Now(), ) } } } // Mark runner as offline _, _ = s.db.Exec( `UPDATE runners SET status = ? WHERE id = ?`, types.RunnerStatusOffline, runnerID, ) } // Check for task timeouts s.recoverTaskTimeouts() // Distribute newly recovered tasks s.distributeTasksToRunners() }() } } // recoverTaskTimeouts handles tasks that have exceeded their timeout func (s *Server) recoverTaskTimeouts() { // Find tasks running longer than their timeout rows, err := s.db.Query( `SELECT t.id, t.runner_id, t.retry_count, t.max_retries, t.timeout_seconds, t.started_at FROM tasks t WHERE t.status = ? AND t.started_at IS NOT NULL AND (t.timeout_seconds IS NULL OR datetime(t.started_at, '+' || t.timeout_seconds || ' seconds') < datetime('now'))`, types.TaskStatusRunning, ) if err != nil { log.Printf("Failed to query timed out tasks: %v", err) return } defer rows.Close() for rows.Next() { var taskID int64 var runnerID sql.NullInt64 var retryCount, maxRetries int var timeoutSeconds sql.NullInt64 var startedAt time.Time err := rows.Scan(&taskID, &runnerID, &retryCount, &maxRetries, &timeoutSeconds, &startedAt) if err != nil { continue } // Use default timeout if not set (5 minutes for frame tasks, 24 hours for FFmpeg) timeout := 300 // 5 minutes default if timeoutSeconds.Valid { timeout = int(timeoutSeconds.Int64) } // Check if actually timed out if time.Since(startedAt).Seconds() < float64(timeout) { continue } if retryCount >= maxRetries { // Mark as failed _, err = s.db.Exec( `UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL WHERE id = ?`, types.TaskStatusFailed, "Task timeout exceeded, max retries reached", taskID, ) } else { // Reset to pending _, err = s.db.Exec( `UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL, retry_count = retry_count + 1 WHERE id = ?`, types.TaskStatusPending, taskID, ) if err == nil { // Add log entry _, _ = s.db.Exec( `INSERT INTO task_logs (task_id, log_level, message, step_name, created_at) VALUES (?, ?, ?, ?, ?)`, taskID, types.LogLevelWarn, fmt.Sprintf("Task timeout exceeded, resetting (retry %d/%d)", retryCount+1, maxRetries), "", time.Now(), ) } } } }