659 lines
19 KiB
Go
659 lines
19 KiB
Go
package api
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
authpkg "jiggablend/internal/auth"
|
|
"jiggablend/internal/database"
|
|
"jiggablend/internal/storage"
|
|
"jiggablend/pkg/types"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/go-chi/chi/v5/middleware"
|
|
"github.com/go-chi/cors"
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// Server represents the API server
|
|
type Server struct {
|
|
db *database.DB
|
|
auth *authpkg.Auth
|
|
secrets *authpkg.Secrets
|
|
storage *storage.Storage
|
|
router *chi.Mux
|
|
|
|
// WebSocket connections
|
|
wsUpgrader websocket.Upgrader
|
|
runnerConns map[int64]*websocket.Conn
|
|
runnerConnsMu sync.RWMutex
|
|
frontendConns map[string]*websocket.Conn // key: "jobId:taskId"
|
|
frontendConnsMu sync.RWMutex
|
|
// Mutexes for each frontend connection to serialize writes
|
|
frontendConnsWriteMu map[string]*sync.Mutex // key: "jobId:taskId"
|
|
frontendConnsWriteMuMu sync.RWMutex
|
|
// Throttling for progress updates (per job)
|
|
progressUpdateTimes map[int64]time.Time // key: jobID
|
|
progressUpdateTimesMu sync.RWMutex
|
|
}
|
|
|
|
// NewServer creates a new API server
|
|
func NewServer(db *database.DB, auth *authpkg.Auth, storage *storage.Storage) (*Server, error) {
|
|
secrets, err := authpkg.NewSecrets(db.DB)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to initialize secrets: %w", err)
|
|
}
|
|
|
|
s := &Server{
|
|
db: db,
|
|
auth: auth,
|
|
secrets: secrets,
|
|
storage: storage,
|
|
router: chi.NewRouter(),
|
|
wsUpgrader: websocket.Upgrader{
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
return true // Allow all origins for now
|
|
},
|
|
ReadBufferSize: 1024,
|
|
WriteBufferSize: 1024,
|
|
},
|
|
runnerConns: make(map[int64]*websocket.Conn),
|
|
frontendConns: make(map[string]*websocket.Conn),
|
|
frontendConnsWriteMu: make(map[string]*sync.Mutex),
|
|
progressUpdateTimes: make(map[int64]time.Time),
|
|
}
|
|
|
|
s.setupMiddleware()
|
|
s.setupRoutes()
|
|
s.StartBackgroundTasks()
|
|
|
|
return s, nil
|
|
}
|
|
|
|
// setupMiddleware configures middleware
|
|
func (s *Server) setupMiddleware() {
|
|
s.router.Use(middleware.Logger)
|
|
s.router.Use(middleware.Recoverer)
|
|
// Note: Timeout middleware is NOT applied globally to avoid conflicts with WebSocket connections
|
|
// WebSocket connections are long-lived and should not have HTTP timeouts
|
|
|
|
s.router.Use(cors.Handler(cors.Options{
|
|
AllowedOrigins: []string{"*"},
|
|
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
|
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "Range"},
|
|
ExposedHeaders: []string{"Link", "Content-Range", "Accept-Ranges", "Content-Length"},
|
|
AllowCredentials: true,
|
|
MaxAge: 300,
|
|
}))
|
|
}
|
|
|
|
// setupRoutes configures routes
|
|
func (s *Server) setupRoutes() {
|
|
// Public routes
|
|
s.router.Route("/api/auth", func(r chi.Router) {
|
|
r.Get("/providers", s.handleGetAuthProviders)
|
|
r.Get("/google/login", s.handleGoogleLogin)
|
|
r.Get("/google/callback", s.handleGoogleCallback)
|
|
r.Get("/discord/login", s.handleDiscordLogin)
|
|
r.Get("/discord/callback", s.handleDiscordCallback)
|
|
r.Get("/local/available", s.handleLocalLoginAvailable)
|
|
r.Post("/local/register", s.handleLocalRegister)
|
|
r.Post("/local/login", s.handleLocalLogin)
|
|
r.Post("/logout", s.handleLogout)
|
|
r.Get("/me", s.handleGetMe)
|
|
r.Post("/change-password", s.handleChangePassword)
|
|
})
|
|
|
|
// Protected routes
|
|
s.router.Route("/api/jobs", func(r chi.Router) {
|
|
r.Use(func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(s.auth.Middleware(next.ServeHTTP))
|
|
})
|
|
r.Post("/", s.handleCreateJob)
|
|
r.Get("/", s.handleListJobs)
|
|
r.Get("/{id}", s.handleGetJob)
|
|
r.Delete("/{id}", s.handleCancelJob)
|
|
r.Post("/{id}/delete", s.handleDeleteJob)
|
|
r.Post("/{id}/upload", s.handleUploadJobFile)
|
|
r.Get("/{id}/files", s.handleListJobFiles)
|
|
r.Get("/{id}/files/{fileId}/download", s.handleDownloadJobFile)
|
|
r.Get("/{id}/video", s.handleStreamVideo)
|
|
r.Get("/{id}/metadata", s.handleGetJobMetadata)
|
|
r.Get("/{id}/tasks", s.handleListJobTasks)
|
|
r.Get("/{id}/tasks/{taskId}/logs", s.handleGetTaskLogs)
|
|
// WebSocket route - no timeout middleware (long-lived connection)
|
|
r.With(func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Remove timeout middleware for WebSocket
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}).Get("/{id}/tasks/{taskId}/logs/ws", s.handleStreamTaskLogsWebSocket)
|
|
r.Get("/{id}/tasks/{taskId}/steps", s.handleGetTaskSteps)
|
|
r.Post("/{id}/tasks/{taskId}/retry", s.handleRetryTask)
|
|
})
|
|
|
|
// Admin routes
|
|
s.router.Route("/api/admin", func(r chi.Router) {
|
|
r.Use(func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(s.auth.AdminMiddleware(next.ServeHTTP))
|
|
})
|
|
r.Route("/runners", func(r chi.Router) {
|
|
r.Route("/tokens", func(r chi.Router) {
|
|
r.Post("/", s.handleGenerateRegistrationToken)
|
|
r.Get("/", s.handleListRegistrationTokens)
|
|
r.Delete("/{id}", s.handleRevokeRegistrationToken)
|
|
})
|
|
r.Get("/", s.handleListRunnersAdmin)
|
|
r.Post("/{id}/verify", s.handleVerifyRunner)
|
|
r.Delete("/{id}", s.handleDeleteRunner)
|
|
})
|
|
r.Route("/users", func(r chi.Router) {
|
|
r.Get("/", s.handleListUsers)
|
|
r.Get("/{id}/jobs", s.handleGetUserJobs)
|
|
r.Post("/{id}/admin", s.handleSetUserAdminStatus)
|
|
})
|
|
r.Route("/settings", func(r chi.Router) {
|
|
r.Get("/registration", s.handleGetRegistrationEnabled)
|
|
r.Post("/registration", s.handleSetRegistrationEnabled)
|
|
})
|
|
})
|
|
|
|
// Runner API
|
|
s.router.Route("/api/runner", func(r chi.Router) {
|
|
// Registration doesn't require auth (uses token)
|
|
r.With(middleware.Timeout(60*time.Second)).Post("/register", s.handleRegisterRunner)
|
|
|
|
// WebSocket endpoint (auth handled in handler) - no timeout middleware
|
|
r.Get("/ws", s.handleRunnerWebSocket)
|
|
|
|
// File operations still use HTTP (WebSocket not suitable for large files)
|
|
r.Group(func(r chi.Router) {
|
|
r.Use(func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(s.runnerAuthMiddleware(next.ServeHTTP))
|
|
})
|
|
r.Post("/tasks/{id}/progress", s.handleUpdateTaskProgress)
|
|
r.Post("/tasks/{id}/steps", s.handleUpdateTaskStep)
|
|
r.Get("/files/{jobId}/*", s.handleDownloadFileForRunner)
|
|
r.Post("/files/{jobId}/upload", s.handleUploadFileFromRunner)
|
|
r.Get("/jobs/{jobId}/status", s.handleGetJobStatusForRunner)
|
|
r.Get("/jobs/{jobId}/files", s.handleGetJobFilesForRunner)
|
|
r.Post("/jobs/{jobId}/metadata", s.handleSubmitMetadata)
|
|
})
|
|
})
|
|
|
|
// Serve static files (built React app)
|
|
s.router.Handle("/*", http.FileServer(http.Dir("./web/dist")))
|
|
}
|
|
|
|
// ServeHTTP implements http.Handler
|
|
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
s.router.ServeHTTP(w, r)
|
|
}
|
|
|
|
// JSON response helpers
|
|
func (s *Server) respondJSON(w http.ResponseWriter, status int, data interface{}) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(status)
|
|
if err := json.NewEncoder(w).Encode(data); err != nil {
|
|
log.Printf("Failed to encode JSON response: %v", err)
|
|
}
|
|
}
|
|
|
|
func (s *Server) respondError(w http.ResponseWriter, status int, message string) {
|
|
s.respondJSON(w, status, map[string]string{"error": message})
|
|
}
|
|
|
|
// Auth handlers
|
|
func (s *Server) handleGoogleLogin(w http.ResponseWriter, r *http.Request) {
|
|
url, err := s.auth.GoogleLoginURL()
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, err.Error())
|
|
return
|
|
}
|
|
http.Redirect(w, r, url, http.StatusFound)
|
|
}
|
|
|
|
func (s *Server) handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
|
|
code := r.URL.Query().Get("code")
|
|
if code == "" {
|
|
s.respondError(w, http.StatusBadRequest, "Missing code parameter")
|
|
return
|
|
}
|
|
|
|
session, err := s.auth.GoogleCallback(r.Context(), code)
|
|
if err != nil {
|
|
// If registration is disabled, redirect back to login with error
|
|
if err.Error() == "registration is disabled" {
|
|
http.Redirect(w, r, "/?error=registration_disabled", http.StatusFound)
|
|
return
|
|
}
|
|
s.respondError(w, http.StatusInternalServerError, err.Error())
|
|
return
|
|
}
|
|
|
|
sessionID := s.auth.CreateSession(session)
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: "session_id",
|
|
Value: sessionID,
|
|
Path: "/",
|
|
MaxAge: 86400,
|
|
HttpOnly: true,
|
|
SameSite: http.SameSiteLaxMode,
|
|
})
|
|
|
|
http.Redirect(w, r, "/", http.StatusFound)
|
|
}
|
|
|
|
func (s *Server) handleDiscordLogin(w http.ResponseWriter, r *http.Request) {
|
|
url, err := s.auth.DiscordLoginURL()
|
|
if err != nil {
|
|
s.respondError(w, http.StatusInternalServerError, err.Error())
|
|
return
|
|
}
|
|
http.Redirect(w, r, url, http.StatusFound)
|
|
}
|
|
|
|
func (s *Server) handleDiscordCallback(w http.ResponseWriter, r *http.Request) {
|
|
code := r.URL.Query().Get("code")
|
|
if code == "" {
|
|
s.respondError(w, http.StatusBadRequest, "Missing code parameter")
|
|
return
|
|
}
|
|
|
|
session, err := s.auth.DiscordCallback(r.Context(), code)
|
|
if err != nil {
|
|
// If registration is disabled, redirect back to login with error
|
|
if err.Error() == "registration is disabled" {
|
|
http.Redirect(w, r, "/?error=registration_disabled", http.StatusFound)
|
|
return
|
|
}
|
|
s.respondError(w, http.StatusInternalServerError, err.Error())
|
|
return
|
|
}
|
|
|
|
sessionID := s.auth.CreateSession(session)
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: "session_id",
|
|
Value: sessionID,
|
|
Path: "/",
|
|
MaxAge: 86400,
|
|
HttpOnly: true,
|
|
SameSite: http.SameSiteLaxMode,
|
|
})
|
|
|
|
http.Redirect(w, r, "/", http.StatusFound)
|
|
}
|
|
|
|
func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
|
|
cookie, err := r.Cookie("session_id")
|
|
if err == nil {
|
|
s.auth.DeleteSession(cookie.Value)
|
|
}
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: "session_id",
|
|
Value: "",
|
|
Path: "/",
|
|
MaxAge: -1,
|
|
HttpOnly: true,
|
|
})
|
|
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Logged out"})
|
|
}
|
|
|
|
func (s *Server) handleGetMe(w http.ResponseWriter, r *http.Request) {
|
|
cookie, err := r.Cookie("session_id")
|
|
if err != nil {
|
|
s.respondError(w, http.StatusUnauthorized, "Not authenticated")
|
|
return
|
|
}
|
|
|
|
session, ok := s.auth.GetSession(cookie.Value)
|
|
if !ok {
|
|
s.respondError(w, http.StatusUnauthorized, "Invalid session")
|
|
return
|
|
}
|
|
|
|
s.respondJSON(w, http.StatusOK, map[string]interface{}{
|
|
"id": session.UserID,
|
|
"email": session.Email,
|
|
"name": session.Name,
|
|
"is_admin": session.IsAdmin,
|
|
})
|
|
}
|
|
|
|
func (s *Server) handleGetAuthProviders(w http.ResponseWriter, r *http.Request) {
|
|
s.respondJSON(w, http.StatusOK, map[string]bool{
|
|
"google": s.auth.IsGoogleOAuthConfigured(),
|
|
"discord": s.auth.IsDiscordOAuthConfigured(),
|
|
"local": s.auth.IsLocalLoginEnabled(),
|
|
})
|
|
}
|
|
|
|
func (s *Server) handleLocalLoginAvailable(w http.ResponseWriter, r *http.Request) {
|
|
s.respondJSON(w, http.StatusOK, map[string]bool{
|
|
"available": s.auth.IsLocalLoginEnabled(),
|
|
})
|
|
}
|
|
|
|
func (s *Server) handleLocalRegister(w http.ResponseWriter, r *http.Request) {
|
|
var req struct {
|
|
Email string `json:"email"`
|
|
Name string `json:"name"`
|
|
Password string `json:"password"`
|
|
}
|
|
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
|
return
|
|
}
|
|
|
|
if req.Email == "" || req.Name == "" || req.Password == "" {
|
|
s.respondError(w, http.StatusBadRequest, "Email, name, and password are required")
|
|
return
|
|
}
|
|
|
|
if len(req.Password) < 8 {
|
|
s.respondError(w, http.StatusBadRequest, "Password must be at least 8 characters long")
|
|
return
|
|
}
|
|
|
|
session, err := s.auth.RegisterLocalUser(req.Email, req.Name, req.Password)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
|
|
sessionID := s.auth.CreateSession(session)
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: "session_id",
|
|
Value: sessionID,
|
|
Path: "/",
|
|
MaxAge: 86400,
|
|
HttpOnly: true,
|
|
SameSite: http.SameSiteLaxMode,
|
|
})
|
|
|
|
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
|
|
"message": "Registration successful",
|
|
"user": map[string]interface{}{
|
|
"id": session.UserID,
|
|
"email": session.Email,
|
|
"name": session.Name,
|
|
"is_admin": session.IsAdmin,
|
|
},
|
|
})
|
|
}
|
|
|
|
func (s *Server) handleLocalLogin(w http.ResponseWriter, r *http.Request) {
|
|
var req struct {
|
|
Username string `json:"username"`
|
|
Password string `json:"password"`
|
|
}
|
|
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
|
return
|
|
}
|
|
|
|
if req.Username == "" || req.Password == "" {
|
|
s.respondError(w, http.StatusBadRequest, "Username and password are required")
|
|
return
|
|
}
|
|
|
|
session, err := s.auth.LocalLogin(req.Username, req.Password)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusUnauthorized, "Invalid credentials")
|
|
return
|
|
}
|
|
|
|
sessionID := s.auth.CreateSession(session)
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: "session_id",
|
|
Value: sessionID,
|
|
Path: "/",
|
|
MaxAge: 86400,
|
|
HttpOnly: true,
|
|
SameSite: http.SameSiteLaxMode,
|
|
})
|
|
|
|
s.respondJSON(w, http.StatusOK, map[string]interface{}{
|
|
"message": "Login successful",
|
|
"user": map[string]interface{}{
|
|
"id": session.UserID,
|
|
"email": session.Email,
|
|
"name": session.Name,
|
|
"is_admin": session.IsAdmin,
|
|
},
|
|
})
|
|
}
|
|
|
|
func (s *Server) handleChangePassword(w http.ResponseWriter, r *http.Request) {
|
|
userID, err := getUserID(r)
|
|
if err != nil {
|
|
s.respondError(w, http.StatusUnauthorized, err.Error())
|
|
return
|
|
}
|
|
|
|
var req struct {
|
|
OldPassword string `json:"old_password"`
|
|
NewPassword string `json:"new_password"`
|
|
TargetUserID *int64 `json:"target_user_id,omitempty"` // For admin to change other users' passwords
|
|
}
|
|
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
|
return
|
|
}
|
|
|
|
if req.NewPassword == "" {
|
|
s.respondError(w, http.StatusBadRequest, "New password is required")
|
|
return
|
|
}
|
|
|
|
if len(req.NewPassword) < 8 {
|
|
s.respondError(w, http.StatusBadRequest, "Password must be at least 8 characters long")
|
|
return
|
|
}
|
|
|
|
isAdmin := authpkg.IsAdmin(r.Context())
|
|
|
|
// If target_user_id is provided and user is admin, allow changing other user's password
|
|
if req.TargetUserID != nil && isAdmin {
|
|
if err := s.auth.AdminChangePassword(*req.TargetUserID, req.NewPassword); err != nil {
|
|
s.respondError(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Password changed successfully"})
|
|
return
|
|
}
|
|
|
|
// Otherwise, user is changing their own password (requires old password)
|
|
if req.OldPassword == "" {
|
|
s.respondError(w, http.StatusBadRequest, "Old password is required")
|
|
return
|
|
}
|
|
|
|
if err := s.auth.ChangePassword(userID, req.OldPassword, req.NewPassword); err != nil {
|
|
s.respondError(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
|
|
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Password changed successfully"})
|
|
}
|
|
|
|
// Helper to get user ID from context
|
|
func getUserID(r *http.Request) (int64, error) {
|
|
userID, ok := authpkg.GetUserID(r.Context())
|
|
if !ok {
|
|
return 0, fmt.Errorf("user ID not found in context")
|
|
}
|
|
return userID, nil
|
|
}
|
|
|
|
// Helper to parse ID from URL
|
|
func parseID(r *http.Request, param string) (int64, error) {
|
|
idStr := chi.URLParam(r, param)
|
|
id, err := strconv.ParseInt(idStr, 10, 64)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("invalid ID: %s", idStr)
|
|
}
|
|
return id, nil
|
|
}
|
|
|
|
// StartBackgroundTasks starts background goroutines for error recovery
|
|
func (s *Server) StartBackgroundTasks() {
|
|
go s.recoverStuckTasks()
|
|
go s.cleanupOldMetadataJobs()
|
|
}
|
|
|
|
// recoverStuckTasks periodically checks for dead runners and stuck tasks
|
|
func (s *Server) recoverStuckTasks() {
|
|
ticker := time.NewTicker(10 * time.Second)
|
|
defer ticker.Stop()
|
|
|
|
// Also distribute tasks every 10 seconds (reduced frequency since we have event-driven distribution)
|
|
distributeTicker := time.NewTicker(10 * time.Second)
|
|
defer distributeTicker.Stop()
|
|
|
|
go func() {
|
|
for range distributeTicker.C {
|
|
s.distributeTasksToRunners()
|
|
}
|
|
}()
|
|
|
|
for range ticker.C {
|
|
func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
log.Printf("Panic in recoverStuckTasks: %v", r)
|
|
}
|
|
}()
|
|
|
|
// Find dead runners (no heartbeat for 90 seconds)
|
|
// But only mark as dead if they're not actually connected via WebSocket
|
|
rows, err := s.db.Query(
|
|
`SELECT id FROM runners
|
|
WHERE last_heartbeat < CURRENT_TIMESTAMP - INTERVAL '90 seconds'
|
|
AND status = ?`,
|
|
types.RunnerStatusOnline,
|
|
)
|
|
if err != nil {
|
|
log.Printf("Failed to query dead runners: %v", err)
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
var deadRunnerIDs []int64
|
|
s.runnerConnsMu.RLock()
|
|
for rows.Next() {
|
|
var runnerID int64
|
|
if err := rows.Scan(&runnerID); err == nil {
|
|
// Only mark as dead if not actually connected via WebSocket
|
|
// The WebSocket connection is the source of truth
|
|
if _, stillConnected := s.runnerConns[runnerID]; !stillConnected {
|
|
deadRunnerIDs = append(deadRunnerIDs, runnerID)
|
|
}
|
|
// If still connected, heartbeat should be updated by pong handler or heartbeat message
|
|
// No need to manually update here - if it's stale, the pong handler isn't working
|
|
}
|
|
}
|
|
s.runnerConnsMu.RUnlock()
|
|
rows.Close()
|
|
|
|
if len(deadRunnerIDs) == 0 {
|
|
// Check for task timeouts
|
|
s.recoverTaskTimeouts()
|
|
return
|
|
}
|
|
|
|
// Reset tasks assigned to dead runners
|
|
for _, runnerID := range deadRunnerIDs {
|
|
s.redistributeRunnerTasks(runnerID)
|
|
|
|
// Mark runner as offline
|
|
_, _ = s.db.Exec(
|
|
`UPDATE runners SET status = ? WHERE id = ?`,
|
|
types.RunnerStatusOffline, runnerID,
|
|
)
|
|
}
|
|
|
|
// Check for task timeouts
|
|
s.recoverTaskTimeouts()
|
|
|
|
// Distribute newly recovered tasks
|
|
s.distributeTasksToRunners()
|
|
}()
|
|
}
|
|
}
|
|
|
|
// recoverTaskTimeouts handles tasks that have exceeded their timeout
|
|
func (s *Server) recoverTaskTimeouts() {
|
|
// Find tasks running longer than their timeout
|
|
rows, err := s.db.Query(
|
|
`SELECT t.id, t.runner_id, t.retry_count, t.max_retries, t.timeout_seconds, t.started_at
|
|
FROM tasks t
|
|
WHERE t.status = ?
|
|
AND t.started_at IS NOT NULL
|
|
AND (t.timeout_seconds IS NULL OR
|
|
t.started_at + INTERVAL (t.timeout_seconds || ' seconds') < CURRENT_TIMESTAMP)`,
|
|
types.TaskStatusRunning,
|
|
)
|
|
if err != nil {
|
|
log.Printf("Failed to query timed out tasks: %v", err)
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var taskID int64
|
|
var runnerID sql.NullInt64
|
|
var retryCount, maxRetries int
|
|
var timeoutSeconds sql.NullInt64
|
|
var startedAt time.Time
|
|
|
|
err := rows.Scan(&taskID, &runnerID, &retryCount, &maxRetries, &timeoutSeconds, &startedAt)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
// Use default timeout if not set (5 minutes for frame tasks, 24 hours for FFmpeg)
|
|
timeout := 300 // 5 minutes default
|
|
if timeoutSeconds.Valid {
|
|
timeout = int(timeoutSeconds.Int64)
|
|
}
|
|
|
|
// Check if actually timed out
|
|
if time.Since(startedAt).Seconds() < float64(timeout) {
|
|
continue
|
|
}
|
|
|
|
if retryCount >= maxRetries {
|
|
// Mark as failed
|
|
_, err = s.db.Exec(
|
|
`UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL
|
|
WHERE id = ?`,
|
|
types.TaskStatusFailed, "Task timeout exceeded, max retries reached", taskID,
|
|
)
|
|
if err != nil {
|
|
log.Printf("Failed to mark task %d as failed: %v", taskID, err)
|
|
}
|
|
} else {
|
|
// Reset to pending
|
|
_, err = s.db.Exec(
|
|
`UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL,
|
|
retry_count = retry_count + 1 WHERE id = ?`,
|
|
types.TaskStatusPending, taskID,
|
|
)
|
|
if err == nil {
|
|
// Add log entry using the helper function
|
|
s.logTaskEvent(taskID, nil, types.LogLevelWarn, fmt.Sprintf("Task timeout exceeded, resetting (retry %d/%d)", retryCount+1, maxRetries), "")
|
|
}
|
|
}
|
|
}
|
|
}
|