massive changes and it works

This commit is contained in:
2025-11-23 10:58:24 -06:00
parent 30aa969433
commit 2a0ff98834
3499 changed files with 7770 additions and 634687 deletions

View File

@@ -7,7 +7,7 @@ import (
"net/http"
"time"
"fuego/pkg/types"
"jiggablend/pkg/types"
)
// handleGenerateRegistrationToken generates a new registration token
@@ -126,7 +126,7 @@ func (s *Server) handleDeleteRunner(w http.ResponseWriter, r *http.Request) {
func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request) {
rows, err := s.db.Query(
`SELECT id, name, hostname, ip_address, status, last_heartbeat, capabilities,
registration_token, verified, created_at
registration_token, verified, priority, created_at
FROM runners ORDER BY created_at DESC`,
)
if err != nil {
@@ -144,7 +144,7 @@ func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request)
err := rows.Scan(
&runner.ID, &runner.Name, &runner.Hostname, &runner.IPAddress,
&runner.Status, &runner.LastHeartbeat, &runner.Capabilities,
&registrationToken, &verified, &runner.CreatedAt,
&registrationToken, &verified, &runner.Priority, &runner.CreatedAt,
)
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan runner: %v", err))
@@ -161,9 +161,206 @@ func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request)
"capabilities": runner.Capabilities,
"registration_token": registrationToken.String,
"verified": verified,
"priority": runner.Priority,
"created_at": runner.CreatedAt,
})
}
s.respondJSON(w, http.StatusOK, runners)
}
// handleListUsers lists all users
func (s *Server) handleListUsers(w http.ResponseWriter, r *http.Request) {
// Get first user ID to mark it in the response
firstUserID, err := s.auth.GetFirstUserID()
if err != nil {
// If no users exist, firstUserID will be 0, which is fine
firstUserID = 0
}
rows, err := s.db.Query(
`SELECT id, email, name, oauth_provider, is_admin, created_at
FROM users ORDER BY created_at DESC`,
)
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query users: %v", err))
return
}
defer rows.Close()
users := []map[string]interface{}{}
for rows.Next() {
var userID int64
var email, name, oauthProvider string
var isAdmin bool
var createdAt time.Time
err := rows.Scan(&userID, &email, &name, &oauthProvider, &isAdmin, &createdAt)
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan user: %v", err))
return
}
// Get job count for this user
var jobCount int
err = s.db.QueryRow("SELECT COUNT(*) FROM jobs WHERE user_id = ?", userID).Scan(&jobCount)
if err != nil {
jobCount = 0 // Default to 0 if query fails
}
users = append(users, map[string]interface{}{
"id": userID,
"email": email,
"name": name,
"oauth_provider": oauthProvider,
"is_admin": isAdmin,
"created_at": createdAt,
"job_count": jobCount,
"is_first_user": userID == firstUserID,
})
}
s.respondJSON(w, http.StatusOK, users)
}
// handleGetUserJobs gets all jobs for a specific user
func (s *Server) handleGetUserJobs(w http.ResponseWriter, r *http.Request) {
userID, err := parseID(r, "id")
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
// Verify user exists
var exists bool
err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", userID).Scan(&exists)
if err != nil || !exists {
s.respondError(w, http.StatusNotFound, "User not found")
return
}
rows, err := s.db.Query(
`SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format,
allow_parallel_runners, timeout_seconds, blend_metadata, created_at, started_at, completed_at, error_message
FROM jobs WHERE user_id = ? ORDER BY created_at DESC`,
userID,
)
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query jobs: %v", err))
return
}
defer rows.Close()
jobs := []types.Job{}
for rows.Next() {
var job types.Job
var jobType string
var startedAt, completedAt sql.NullTime
var blendMetadataJSON sql.NullString
var errorMessage sql.NullString
var frameStart, frameEnd sql.NullInt64
var outputFormat sql.NullString
var allowParallelRunners sql.NullBool
err := rows.Scan(
&job.ID, &job.UserID, &jobType, &job.Name, &job.Status, &job.Progress,
&frameStart, &frameEnd, &outputFormat, &allowParallelRunners, &job.TimeoutSeconds,
&blendMetadataJSON, &job.CreatedAt, &startedAt, &completedAt, &errorMessage,
)
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan job: %v", err))
return
}
job.JobType = types.JobType(jobType)
if frameStart.Valid {
fs := int(frameStart.Int64)
job.FrameStart = &fs
}
if frameEnd.Valid {
fe := int(frameEnd.Int64)
job.FrameEnd = &fe
}
if outputFormat.Valid {
job.OutputFormat = &outputFormat.String
}
if allowParallelRunners.Valid {
job.AllowParallelRunners = &allowParallelRunners.Bool
}
if startedAt.Valid {
job.StartedAt = &startedAt.Time
}
if completedAt.Valid {
job.CompletedAt = &completedAt.Time
}
if blendMetadataJSON.Valid && blendMetadataJSON.String != "" {
var metadata types.BlendMetadata
if err := json.Unmarshal([]byte(blendMetadataJSON.String), &metadata); err == nil {
job.BlendMetadata = &metadata
}
}
if errorMessage.Valid {
job.ErrorMessage = errorMessage.String
}
jobs = append(jobs, job)
}
s.respondJSON(w, http.StatusOK, jobs)
}
// handleGetRegistrationEnabled gets the registration enabled setting
func (s *Server) handleGetRegistrationEnabled(w http.ResponseWriter, r *http.Request) {
enabled, err := s.auth.IsRegistrationEnabled()
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to get registration setting: %v", err))
return
}
s.respondJSON(w, http.StatusOK, map[string]bool{"enabled": enabled})
}
// handleSetRegistrationEnabled sets the registration enabled setting
func (s *Server) handleSetRegistrationEnabled(w http.ResponseWriter, r *http.Request) {
var req struct {
Enabled bool `json:"enabled"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
if err := s.auth.SetRegistrationEnabled(req.Enabled); err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to set registration setting: %v", err))
return
}
s.respondJSON(w, http.StatusOK, map[string]bool{"enabled": req.Enabled})
}
// handleSetUserAdminStatus sets a user's admin status (admin only)
func (s *Server) handleSetUserAdminStatus(w http.ResponseWriter, r *http.Request) {
targetUserID, err := parseID(r, "id")
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
var req struct {
IsAdmin bool `json:"is_admin"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
if err := s.auth.SetUserAdminStatus(targetUserID, req.IsAdmin); err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
s.respondJSON(w, http.StatusOK, map[string]interface{}{
"user_id": targetUserID,
"is_admin": req.IsAdmin,
"message": "Admin status updated successfully",
})
}

File diff suppressed because it is too large Load Diff

View File

@@ -7,7 +7,7 @@ import (
"log"
"net/http"
"fuego/pkg/types"
"jiggablend/pkg/types"
)
// handleSubmitMetadata handles metadata submission from runner
@@ -19,7 +19,7 @@ func (s *Server) handleSubmitMetadata(w http.ResponseWriter, r *http.Request) {
}
// Get runner ID from context (set by runnerAuthMiddleware)
runnerID, ok := r.Context().Value("runner_id").(int64)
runnerID, ok := r.Context().Value(runnerIDContextKey).(int64)
if !ok {
s.respondError(w, http.StatusUnauthorized, "runner_id not found in context")
return
@@ -44,16 +44,32 @@ func (s *Server) handleSubmitMetadata(w http.ResponseWriter, r *http.Request) {
}
// Find the metadata extraction task for this job
// First try to find task assigned to this runner, then fall back to any metadata task for this job
var taskID int64
err = s.db.QueryRow(
`SELECT id FROM tasks WHERE job_id = ? AND task_type = ? AND runner_id = ?`,
jobID, types.TaskTypeMetadata, runnerID,
).Scan(&taskID)
if err == sql.ErrNoRows {
s.respondError(w, http.StatusNotFound, "Metadata extraction task not found")
return
}
if err != nil {
// Fall back to any metadata task for this job (in case assignment changed)
err = s.db.QueryRow(
`SELECT id FROM tasks WHERE job_id = ? AND task_type = ? ORDER BY created_at DESC LIMIT 1`,
jobID, types.TaskTypeMetadata,
).Scan(&taskID)
if err == sql.ErrNoRows {
s.respondError(w, http.StatusNotFound, "Metadata extraction task not found")
return
}
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to find task: %v", err))
return
}
// Update the task to be assigned to this runner if it wasn't already
s.db.Exec(
`UPDATE tasks SET runner_id = ? WHERE id = ? AND runner_id IS NULL`,
runnerID, taskID,
)
} else if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to find task: %v", err))
return
}
@@ -82,6 +98,9 @@ func (s *Server) handleSubmitMetadata(w http.ResponseWriter, r *http.Request) {
)
if err != nil {
log.Printf("Failed to mark metadata task as completed: %v", err)
} else {
// Update job status and progress after metadata task completes
s.updateJobStatusFromTasks(jobID)
}
log.Printf("Metadata extracted for job %d: frame_start=%d, frame_end=%d", jobID, metadata.FrameStart, metadata.FrameEnd)

File diff suppressed because it is too large Load Diff

View File

@@ -10,10 +10,10 @@ import (
"sync"
"time"
authpkg "fuego/internal/auth"
"fuego/internal/database"
"fuego/internal/storage"
"fuego/pkg/types"
authpkg "jiggablend/internal/auth"
"jiggablend/internal/database"
"jiggablend/internal/storage"
"jiggablend/pkg/types"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
@@ -35,6 +35,12 @@ type Server struct {
runnerConnsMu sync.RWMutex
frontendConns map[string]*websocket.Conn // key: "jobId:taskId"
frontendConnsMu sync.RWMutex
// Mutexes for each frontend connection to serialize writes
frontendConnsWriteMu map[string]*sync.Mutex // key: "jobId:taskId"
frontendConnsWriteMuMu sync.RWMutex
// Throttling for progress updates (per job)
progressUpdateTimes map[int64]time.Time // key: jobID
progressUpdateTimesMu sync.RWMutex
}
// NewServer creates a new API server
@@ -57,8 +63,10 @@ func NewServer(db *database.DB, auth *authpkg.Auth, storage *storage.Storage) (*
ReadBufferSize: 1024,
WriteBufferSize: 1024,
},
runnerConns: make(map[int64]*websocket.Conn),
frontendConns: make(map[string]*websocket.Conn),
runnerConns: make(map[int64]*websocket.Conn),
frontendConns: make(map[string]*websocket.Conn),
frontendConnsWriteMu: make(map[string]*sync.Mutex),
progressUpdateTimes: make(map[int64]time.Time),
}
s.setupMiddleware()
@@ -72,13 +80,14 @@ func NewServer(db *database.DB, auth *authpkg.Auth, storage *storage.Storage) (*
func (s *Server) setupMiddleware() {
s.router.Use(middleware.Logger)
s.router.Use(middleware.Recoverer)
s.router.Use(middleware.Timeout(60 * time.Second))
// Note: Timeout middleware is NOT applied globally to avoid conflicts with WebSocket connections
// WebSocket connections are long-lived and should not have HTTP timeouts
s.router.Use(cors.Handler(cors.Options{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type"},
ExposedHeaders: []string{"Link"},
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "Range"},
ExposedHeaders: []string{"Link", "Content-Range", "Accept-Ranges", "Content-Length"},
AllowCredentials: true,
MaxAge: 300,
}))
@@ -88,12 +97,17 @@ func (s *Server) setupMiddleware() {
func (s *Server) setupRoutes() {
// Public routes
s.router.Route("/api/auth", func(r chi.Router) {
r.Get("/providers", s.handleGetAuthProviders)
r.Get("/google/login", s.handleGoogleLogin)
r.Get("/google/callback", s.handleGoogleCallback)
r.Get("/discord/login", s.handleDiscordLogin)
r.Get("/discord/callback", s.handleDiscordCallback)
r.Get("/local/available", s.handleLocalLoginAvailable)
r.Post("/local/register", s.handleLocalRegister)
r.Post("/local/login", s.handleLocalLogin)
r.Post("/logout", s.handleLogout)
r.Get("/me", s.handleGetMe)
r.Post("/change-password", s.handleChangePassword)
})
// Protected routes
@@ -105,6 +119,7 @@ func (s *Server) setupRoutes() {
r.Get("/", s.handleListJobs)
r.Get("/{id}", s.handleGetJob)
r.Delete("/{id}", s.handleCancelJob)
r.Post("/{id}/delete", s.handleDeleteJob)
r.Post("/{id}/upload", s.handleUploadJobFile)
r.Get("/{id}/files", s.handleListJobFiles)
r.Get("/{id}/files/{fileId}/download", s.handleDownloadJobFile)
@@ -112,18 +127,17 @@ func (s *Server) setupRoutes() {
r.Get("/{id}/metadata", s.handleGetJobMetadata)
r.Get("/{id}/tasks", s.handleListJobTasks)
r.Get("/{id}/tasks/{taskId}/logs", s.handleGetTaskLogs)
r.Get("/{id}/tasks/{taskId}/logs/ws", s.handleStreamTaskLogsWebSocket)
// WebSocket route - no timeout middleware (long-lived connection)
r.With(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Remove timeout middleware for WebSocket
next.ServeHTTP(w, r)
})
}).Get("/{id}/tasks/{taskId}/logs/ws", s.handleStreamTaskLogsWebSocket)
r.Get("/{id}/tasks/{taskId}/steps", s.handleGetTaskSteps)
r.Post("/{id}/tasks/{taskId}/retry", s.handleRetryTask)
})
s.router.Route("/api/runners", func(r chi.Router) {
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(s.auth.Middleware(next.ServeHTTP))
})
r.Get("/", s.handleListRunners)
})
// Admin routes
s.router.Route("/api/admin", func(r chi.Router) {
r.Use(func(next http.Handler) http.Handler {
@@ -139,14 +153,23 @@ func (s *Server) setupRoutes() {
r.Post("/{id}/verify", s.handleVerifyRunner)
r.Delete("/{id}", s.handleDeleteRunner)
})
r.Route("/users", func(r chi.Router) {
r.Get("/", s.handleListUsers)
r.Get("/{id}/jobs", s.handleGetUserJobs)
r.Post("/{id}/admin", s.handleSetUserAdminStatus)
})
r.Route("/settings", func(r chi.Router) {
r.Get("/registration", s.handleGetRegistrationEnabled)
r.Post("/registration", s.handleSetRegistrationEnabled)
})
})
// Runner API
s.router.Route("/api/runner", func(r chi.Router) {
// Registration doesn't require auth (uses token)
r.Post("/register", s.handleRegisterRunner)
r.With(middleware.Timeout(60*time.Second)).Post("/register", s.handleRegisterRunner)
// WebSocket endpoint (auth handled in handler)
// WebSocket endpoint (auth handled in handler) - no timeout middleware
r.Get("/ws", s.handleRunnerWebSocket)
// File operations still use HTTP (WebSocket not suitable for large files)
@@ -156,7 +179,7 @@ func (s *Server) setupRoutes() {
})
r.Post("/tasks/{id}/progress", s.handleUpdateTaskProgress)
r.Post("/tasks/{id}/steps", s.handleUpdateTaskStep)
r.Get("/files/{jobId}/{fileName}", s.handleDownloadFileForRunner)
r.Get("/files/{jobId}/*", s.handleDownloadFileForRunner)
r.Post("/files/{jobId}/upload", s.handleUploadFileFromRunner)
r.Get("/jobs/{jobId}/status", s.handleGetJobStatusForRunner)
r.Get("/jobs/{jobId}/files", s.handleGetJobFilesForRunner)
@@ -205,6 +228,11 @@ func (s *Server) handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
session, err := s.auth.GoogleCallback(r.Context(), code)
if err != nil {
// If registration is disabled, redirect back to login with error
if err.Error() == "registration is disabled" {
http.Redirect(w, r, "/?error=registration_disabled", http.StatusFound)
return
}
s.respondError(w, http.StatusInternalServerError, err.Error())
return
}
@@ -240,6 +268,11 @@ func (s *Server) handleDiscordCallback(w http.ResponseWriter, r *http.Request) {
session, err := s.auth.DiscordCallback(r.Context(), code)
if err != nil {
// If registration is disabled, redirect back to login with error
if err.Error() == "registration is disabled" {
http.Redirect(w, r, "/?error=registration_disabled", http.StatusFound)
return
}
s.respondError(w, http.StatusInternalServerError, err.Error())
return
}
@@ -293,6 +326,166 @@ func (s *Server) handleGetMe(w http.ResponseWriter, r *http.Request) {
})
}
func (s *Server) handleGetAuthProviders(w http.ResponseWriter, r *http.Request) {
s.respondJSON(w, http.StatusOK, map[string]bool{
"google": s.auth.IsGoogleOAuthConfigured(),
"discord": s.auth.IsDiscordOAuthConfigured(),
"local": s.auth.IsLocalLoginEnabled(),
})
}
func (s *Server) handleLocalLoginAvailable(w http.ResponseWriter, r *http.Request) {
s.respondJSON(w, http.StatusOK, map[string]bool{
"available": s.auth.IsLocalLoginEnabled(),
})
}
func (s *Server) handleLocalRegister(w http.ResponseWriter, r *http.Request) {
var req struct {
Email string `json:"email"`
Name string `json:"name"`
Password string `json:"password"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
if req.Email == "" || req.Name == "" || req.Password == "" {
s.respondError(w, http.StatusBadRequest, "Email, name, and password are required")
return
}
if len(req.Password) < 8 {
s.respondError(w, http.StatusBadRequest, "Password must be at least 8 characters long")
return
}
session, err := s.auth.RegisterLocalUser(req.Email, req.Name, req.Password)
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
sessionID := s.auth.CreateSession(session)
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: sessionID,
Path: "/",
MaxAge: 86400,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
})
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
"message": "Registration successful",
"user": map[string]interface{}{
"id": session.UserID,
"email": session.Email,
"name": session.Name,
"is_admin": session.IsAdmin,
},
})
}
func (s *Server) handleLocalLogin(w http.ResponseWriter, r *http.Request) {
var req struct {
Username string `json:"username"`
Password string `json:"password"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
if req.Username == "" || req.Password == "" {
s.respondError(w, http.StatusBadRequest, "Username and password are required")
return
}
session, err := s.auth.LocalLogin(req.Username, req.Password)
if err != nil {
s.respondError(w, http.StatusUnauthorized, "Invalid credentials")
return
}
sessionID := s.auth.CreateSession(session)
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: sessionID,
Path: "/",
MaxAge: 86400,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
})
s.respondJSON(w, http.StatusOK, map[string]interface{}{
"message": "Login successful",
"user": map[string]interface{}{
"id": session.UserID,
"email": session.Email,
"name": session.Name,
"is_admin": session.IsAdmin,
},
})
}
func (s *Server) handleChangePassword(w http.ResponseWriter, r *http.Request) {
userID, err := getUserID(r)
if err != nil {
s.respondError(w, http.StatusUnauthorized, err.Error())
return
}
var req struct {
OldPassword string `json:"old_password"`
NewPassword string `json:"new_password"`
TargetUserID *int64 `json:"target_user_id,omitempty"` // For admin to change other users' passwords
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
if req.NewPassword == "" {
s.respondError(w, http.StatusBadRequest, "New password is required")
return
}
if len(req.NewPassword) < 8 {
s.respondError(w, http.StatusBadRequest, "Password must be at least 8 characters long")
return
}
isAdmin := authpkg.IsAdmin(r.Context())
// If target_user_id is provided and user is admin, allow changing other user's password
if req.TargetUserID != nil && isAdmin {
if err := s.auth.AdminChangePassword(*req.TargetUserID, req.NewPassword); err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Password changed successfully"})
return
}
// Otherwise, user is changing their own password (requires old password)
if req.OldPassword == "" {
s.respondError(w, http.StatusBadRequest, "Old password is required")
return
}
if err := s.auth.ChangePassword(userID, req.OldPassword, req.NewPassword); err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Password changed successfully"})
}
// Helper to get user ID from context
func getUserID(r *http.Request) (int64, error) {
userID, ok := authpkg.GetUserID(r.Context())
@@ -315,6 +508,7 @@ func parseID(r *http.Request, param string) (int64, error) {
// StartBackgroundTasks starts background goroutines for error recovery
func (s *Server) StartBackgroundTasks() {
go s.recoverStuckTasks()
go s.cleanupOldMetadataJobs()
}
// recoverStuckTasks periodically checks for dead runners and stuck tasks
@@ -322,8 +516,8 @@ func (s *Server) recoverStuckTasks() {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
// Also distribute tasks every 5 seconds
distributeTicker := time.NewTicker(5 * time.Second)
// Also distribute tasks every 10 seconds (reduced frequency since we have event-driven distribution)
distributeTicker := time.NewTicker(10 * time.Second)
defer distributeTicker.Stop()
go func() {
@@ -341,9 +535,10 @@ func (s *Server) recoverStuckTasks() {
}()
// Find dead runners (no heartbeat for 90 seconds)
// But only mark as dead if they're not actually connected via WebSocket
rows, err := s.db.Query(
`SELECT id FROM runners
WHERE last_heartbeat < datetime('now', '-90 seconds')
WHERE last_heartbeat < CURRENT_TIMESTAMP - INTERVAL '90 seconds'
AND status = ?`,
types.RunnerStatusOnline,
)
@@ -354,12 +549,20 @@ func (s *Server) recoverStuckTasks() {
defer rows.Close()
var deadRunnerIDs []int64
s.runnerConnsMu.RLock()
for rows.Next() {
var runnerID int64
if err := rows.Scan(&runnerID); err == nil {
deadRunnerIDs = append(deadRunnerIDs, runnerID)
// Only mark as dead if not actually connected via WebSocket
// The WebSocket connection is the source of truth
if _, stillConnected := s.runnerConns[runnerID]; !stillConnected {
deadRunnerIDs = append(deadRunnerIDs, runnerID)
}
// If still connected, heartbeat should be updated by pong handler or heartbeat message
// No need to manually update here - if it's stale, the pong handler isn't working
}
}
s.runnerConnsMu.RUnlock()
rows.Close()
if len(deadRunnerIDs) == 0 {
@@ -370,67 +573,7 @@ func (s *Server) recoverStuckTasks() {
// Reset tasks assigned to dead runners
for _, runnerID := range deadRunnerIDs {
// Get tasks assigned to this runner
taskRows, err := s.db.Query(
`SELECT id, retry_count, max_retries FROM tasks
WHERE runner_id = ? AND status = ?`,
runnerID, types.TaskStatusRunning,
)
if err != nil {
log.Printf("Failed to query tasks for runner %d: %v", runnerID, err)
continue
}
var tasksToReset []struct {
ID int64
RetryCount int
MaxRetries int
}
for taskRows.Next() {
var t struct {
ID int64
RetryCount int
MaxRetries int
}
if err := taskRows.Scan(&t.ID, &t.RetryCount, &t.MaxRetries); err == nil {
tasksToReset = append(tasksToReset, t)
}
}
taskRows.Close()
// Reset or fail tasks
for _, task := range tasksToReset {
if task.RetryCount >= task.MaxRetries {
// Mark as failed
_, err = s.db.Exec(
`UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL
WHERE id = ?`,
types.TaskStatusFailed, "Runner died, max retries exceeded", task.ID,
)
if err != nil {
log.Printf("Failed to mark task %d as failed: %v", task.ID, err)
}
} else {
// Reset to pending
_, err = s.db.Exec(
`UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL,
retry_count = retry_count + 1 WHERE id = ?`,
types.TaskStatusPending, task.ID,
)
if err != nil {
log.Printf("Failed to reset task %d: %v", task.ID, err)
} else {
// Add log entry
_, _ = s.db.Exec(
`INSERT INTO task_logs (task_id, log_level, message, step_name, created_at)
VALUES (?, ?, ?, ?, ?)`,
task.ID, types.LogLevelWarn, fmt.Sprintf("Runner died, task reset (retry %d/%d)", task.RetryCount+1, task.MaxRetries),
"", time.Now(),
)
}
}
}
s.redistributeRunnerTasks(runnerID)
// Mark runner as offline
_, _ = s.db.Exec(
@@ -457,7 +600,7 @@ func (s *Server) recoverTaskTimeouts() {
WHERE t.status = ?
AND t.started_at IS NOT NULL
AND (t.timeout_seconds IS NULL OR
datetime(t.started_at, '+' || t.timeout_seconds || ' seconds') < datetime('now'))`,
t.started_at + INTERVAL (t.timeout_seconds || ' seconds') < CURRENT_TIMESTAMP)`,
types.TaskStatusRunning,
)
if err != nil {
@@ -507,13 +650,8 @@ func (s *Server) recoverTaskTimeouts() {
types.TaskStatusPending, taskID,
)
if err == nil {
// Add log entry
_, _ = s.db.Exec(
`INSERT INTO task_logs (task_id, log_level, message, step_name, created_at)
VALUES (?, ?, ?, ?, ?)`,
taskID, types.LogLevelWarn, fmt.Sprintf("Task timeout exceeded, resetting (retry %d/%d)", retryCount+1, maxRetries),
"", time.Now(),
)
// Add log entry using the helper function
s.logTaskEvent(taskID, nil, types.LogLevelWarn, fmt.Sprintf("Task timeout exceeded, resetting (retry %d/%d)", retryCount+1, maxRetries), "")
}
}
}