massive changes and it works
This commit is contained in:
@@ -7,7 +7,7 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"fuego/pkg/types"
|
||||
"jiggablend/pkg/types"
|
||||
)
|
||||
|
||||
// handleGenerateRegistrationToken generates a new registration token
|
||||
@@ -126,7 +126,7 @@ func (s *Server) handleDeleteRunner(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, name, hostname, ip_address, status, last_heartbeat, capabilities,
|
||||
registration_token, verified, created_at
|
||||
registration_token, verified, priority, created_at
|
||||
FROM runners ORDER BY created_at DESC`,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -144,7 +144,7 @@ func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request)
|
||||
err := rows.Scan(
|
||||
&runner.ID, &runner.Name, &runner.Hostname, &runner.IPAddress,
|
||||
&runner.Status, &runner.LastHeartbeat, &runner.Capabilities,
|
||||
®istrationToken, &verified, &runner.CreatedAt,
|
||||
®istrationToken, &verified, &runner.Priority, &runner.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan runner: %v", err))
|
||||
@@ -161,9 +161,206 @@ func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request)
|
||||
"capabilities": runner.Capabilities,
|
||||
"registration_token": registrationToken.String,
|
||||
"verified": verified,
|
||||
"priority": runner.Priority,
|
||||
"created_at": runner.CreatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, runners)
|
||||
}
|
||||
|
||||
// handleListUsers lists all users
|
||||
func (s *Server) handleListUsers(w http.ResponseWriter, r *http.Request) {
|
||||
// Get first user ID to mark it in the response
|
||||
firstUserID, err := s.auth.GetFirstUserID()
|
||||
if err != nil {
|
||||
// If no users exist, firstUserID will be 0, which is fine
|
||||
firstUserID = 0
|
||||
}
|
||||
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, email, name, oauth_provider, is_admin, created_at
|
||||
FROM users ORDER BY created_at DESC`,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query users: %v", err))
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
users := []map[string]interface{}{}
|
||||
for rows.Next() {
|
||||
var userID int64
|
||||
var email, name, oauthProvider string
|
||||
var isAdmin bool
|
||||
var createdAt time.Time
|
||||
|
||||
err := rows.Scan(&userID, &email, &name, &oauthProvider, &isAdmin, &createdAt)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan user: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Get job count for this user
|
||||
var jobCount int
|
||||
err = s.db.QueryRow("SELECT COUNT(*) FROM jobs WHERE user_id = ?", userID).Scan(&jobCount)
|
||||
if err != nil {
|
||||
jobCount = 0 // Default to 0 if query fails
|
||||
}
|
||||
|
||||
users = append(users, map[string]interface{}{
|
||||
"id": userID,
|
||||
"email": email,
|
||||
"name": name,
|
||||
"oauth_provider": oauthProvider,
|
||||
"is_admin": isAdmin,
|
||||
"created_at": createdAt,
|
||||
"job_count": jobCount,
|
||||
"is_first_user": userID == firstUserID,
|
||||
})
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, users)
|
||||
}
|
||||
|
||||
// handleGetUserJobs gets all jobs for a specific user
|
||||
func (s *Server) handleGetUserJobs(w http.ResponseWriter, r *http.Request) {
|
||||
userID, err := parseID(r, "id")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Verify user exists
|
||||
var exists bool
|
||||
err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", userID).Scan(&exists)
|
||||
if err != nil || !exists {
|
||||
s.respondError(w, http.StatusNotFound, "User not found")
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format,
|
||||
allow_parallel_runners, timeout_seconds, blend_metadata, created_at, started_at, completed_at, error_message
|
||||
FROM jobs WHERE user_id = ? ORDER BY created_at DESC`,
|
||||
userID,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query jobs: %v", err))
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
jobs := []types.Job{}
|
||||
for rows.Next() {
|
||||
var job types.Job
|
||||
var jobType string
|
||||
var startedAt, completedAt sql.NullTime
|
||||
var blendMetadataJSON sql.NullString
|
||||
var errorMessage sql.NullString
|
||||
var frameStart, frameEnd sql.NullInt64
|
||||
var outputFormat sql.NullString
|
||||
var allowParallelRunners sql.NullBool
|
||||
|
||||
err := rows.Scan(
|
||||
&job.ID, &job.UserID, &jobType, &job.Name, &job.Status, &job.Progress,
|
||||
&frameStart, &frameEnd, &outputFormat, &allowParallelRunners, &job.TimeoutSeconds,
|
||||
&blendMetadataJSON, &job.CreatedAt, &startedAt, &completedAt, &errorMessage,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan job: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
job.JobType = types.JobType(jobType)
|
||||
if frameStart.Valid {
|
||||
fs := int(frameStart.Int64)
|
||||
job.FrameStart = &fs
|
||||
}
|
||||
if frameEnd.Valid {
|
||||
fe := int(frameEnd.Int64)
|
||||
job.FrameEnd = &fe
|
||||
}
|
||||
if outputFormat.Valid {
|
||||
job.OutputFormat = &outputFormat.String
|
||||
}
|
||||
if allowParallelRunners.Valid {
|
||||
job.AllowParallelRunners = &allowParallelRunners.Bool
|
||||
}
|
||||
if startedAt.Valid {
|
||||
job.StartedAt = &startedAt.Time
|
||||
}
|
||||
if completedAt.Valid {
|
||||
job.CompletedAt = &completedAt.Time
|
||||
}
|
||||
if blendMetadataJSON.Valid && blendMetadataJSON.String != "" {
|
||||
var metadata types.BlendMetadata
|
||||
if err := json.Unmarshal([]byte(blendMetadataJSON.String), &metadata); err == nil {
|
||||
job.BlendMetadata = &metadata
|
||||
}
|
||||
}
|
||||
if errorMessage.Valid {
|
||||
job.ErrorMessage = errorMessage.String
|
||||
}
|
||||
|
||||
jobs = append(jobs, job)
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, jobs)
|
||||
}
|
||||
|
||||
// handleGetRegistrationEnabled gets the registration enabled setting
|
||||
func (s *Server) handleGetRegistrationEnabled(w http.ResponseWriter, r *http.Request) {
|
||||
enabled, err := s.auth.IsRegistrationEnabled()
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to get registration setting: %v", err))
|
||||
return
|
||||
}
|
||||
s.respondJSON(w, http.StatusOK, map[string]bool{"enabled": enabled})
|
||||
}
|
||||
|
||||
// handleSetRegistrationEnabled sets the registration enabled setting
|
||||
func (s *Server) handleSetRegistrationEnabled(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.auth.SetRegistrationEnabled(req.Enabled); err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to set registration setting: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, map[string]bool{"enabled": req.Enabled})
|
||||
}
|
||||
|
||||
// handleSetUserAdminStatus sets a user's admin status (admin only)
|
||||
func (s *Server) handleSetUserAdminStatus(w http.ResponseWriter, r *http.Request) {
|
||||
targetUserID, err := parseID(r, "id")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.auth.SetUserAdminStatus(targetUserID, req.IsAdmin); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"user_id": targetUserID,
|
||||
"is_admin": req.IsAdmin,
|
||||
"message": "Admin status updated successfully",
|
||||
})
|
||||
}
|
||||
|
||||
1246
internal/api/jobs.go
1246
internal/api/jobs.go
File diff suppressed because it is too large
Load Diff
@@ -7,7 +7,7 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"fuego/pkg/types"
|
||||
"jiggablend/pkg/types"
|
||||
)
|
||||
|
||||
// handleSubmitMetadata handles metadata submission from runner
|
||||
@@ -19,7 +19,7 @@ func (s *Server) handleSubmitMetadata(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Get runner ID from context (set by runnerAuthMiddleware)
|
||||
runnerID, ok := r.Context().Value("runner_id").(int64)
|
||||
runnerID, ok := r.Context().Value(runnerIDContextKey).(int64)
|
||||
if !ok {
|
||||
s.respondError(w, http.StatusUnauthorized, "runner_id not found in context")
|
||||
return
|
||||
@@ -44,16 +44,32 @@ func (s *Server) handleSubmitMetadata(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Find the metadata extraction task for this job
|
||||
// First try to find task assigned to this runner, then fall back to any metadata task for this job
|
||||
var taskID int64
|
||||
err = s.db.QueryRow(
|
||||
`SELECT id FROM tasks WHERE job_id = ? AND task_type = ? AND runner_id = ?`,
|
||||
jobID, types.TaskTypeMetadata, runnerID,
|
||||
).Scan(&taskID)
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "Metadata extraction task not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
// Fall back to any metadata task for this job (in case assignment changed)
|
||||
err = s.db.QueryRow(
|
||||
`SELECT id FROM tasks WHERE job_id = ? AND task_type = ? ORDER BY created_at DESC LIMIT 1`,
|
||||
jobID, types.TaskTypeMetadata,
|
||||
).Scan(&taskID)
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "Metadata extraction task not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to find task: %v", err))
|
||||
return
|
||||
}
|
||||
// Update the task to be assigned to this runner if it wasn't already
|
||||
s.db.Exec(
|
||||
`UPDATE tasks SET runner_id = ? WHERE id = ? AND runner_id IS NULL`,
|
||||
runnerID, taskID,
|
||||
)
|
||||
} else if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to find task: %v", err))
|
||||
return
|
||||
}
|
||||
@@ -82,6 +98,9 @@ func (s *Server) handleSubmitMetadata(w http.ResponseWriter, r *http.Request) {
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Failed to mark metadata task as completed: %v", err)
|
||||
} else {
|
||||
// Update job status and progress after metadata task completes
|
||||
s.updateJobStatusFromTasks(jobID)
|
||||
}
|
||||
|
||||
log.Printf("Metadata extracted for job %d: frame_start=%d, frame_end=%d", jobID, metadata.FrameStart, metadata.FrameEnd)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -10,10 +10,10 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
authpkg "fuego/internal/auth"
|
||||
"fuego/internal/database"
|
||||
"fuego/internal/storage"
|
||||
"fuego/pkg/types"
|
||||
authpkg "jiggablend/internal/auth"
|
||||
"jiggablend/internal/database"
|
||||
"jiggablend/internal/storage"
|
||||
"jiggablend/pkg/types"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
@@ -35,6 +35,12 @@ type Server struct {
|
||||
runnerConnsMu sync.RWMutex
|
||||
frontendConns map[string]*websocket.Conn // key: "jobId:taskId"
|
||||
frontendConnsMu sync.RWMutex
|
||||
// Mutexes for each frontend connection to serialize writes
|
||||
frontendConnsWriteMu map[string]*sync.Mutex // key: "jobId:taskId"
|
||||
frontendConnsWriteMuMu sync.RWMutex
|
||||
// Throttling for progress updates (per job)
|
||||
progressUpdateTimes map[int64]time.Time // key: jobID
|
||||
progressUpdateTimesMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewServer creates a new API server
|
||||
@@ -57,8 +63,10 @@ func NewServer(db *database.DB, auth *authpkg.Auth, storage *storage.Storage) (*
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
},
|
||||
runnerConns: make(map[int64]*websocket.Conn),
|
||||
frontendConns: make(map[string]*websocket.Conn),
|
||||
runnerConns: make(map[int64]*websocket.Conn),
|
||||
frontendConns: make(map[string]*websocket.Conn),
|
||||
frontendConnsWriteMu: make(map[string]*sync.Mutex),
|
||||
progressUpdateTimes: make(map[int64]time.Time),
|
||||
}
|
||||
|
||||
s.setupMiddleware()
|
||||
@@ -72,13 +80,14 @@ func NewServer(db *database.DB, auth *authpkg.Auth, storage *storage.Storage) (*
|
||||
func (s *Server) setupMiddleware() {
|
||||
s.router.Use(middleware.Logger)
|
||||
s.router.Use(middleware.Recoverer)
|
||||
s.router.Use(middleware.Timeout(60 * time.Second))
|
||||
// Note: Timeout middleware is NOT applied globally to avoid conflicts with WebSocket connections
|
||||
// WebSocket connections are long-lived and should not have HTTP timeouts
|
||||
|
||||
s.router.Use(cors.Handler(cors.Options{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
||||
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type"},
|
||||
ExposedHeaders: []string{"Link"},
|
||||
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "Range"},
|
||||
ExposedHeaders: []string{"Link", "Content-Range", "Accept-Ranges", "Content-Length"},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 300,
|
||||
}))
|
||||
@@ -88,12 +97,17 @@ func (s *Server) setupMiddleware() {
|
||||
func (s *Server) setupRoutes() {
|
||||
// Public routes
|
||||
s.router.Route("/api/auth", func(r chi.Router) {
|
||||
r.Get("/providers", s.handleGetAuthProviders)
|
||||
r.Get("/google/login", s.handleGoogleLogin)
|
||||
r.Get("/google/callback", s.handleGoogleCallback)
|
||||
r.Get("/discord/login", s.handleDiscordLogin)
|
||||
r.Get("/discord/callback", s.handleDiscordCallback)
|
||||
r.Get("/local/available", s.handleLocalLoginAvailable)
|
||||
r.Post("/local/register", s.handleLocalRegister)
|
||||
r.Post("/local/login", s.handleLocalLogin)
|
||||
r.Post("/logout", s.handleLogout)
|
||||
r.Get("/me", s.handleGetMe)
|
||||
r.Post("/change-password", s.handleChangePassword)
|
||||
})
|
||||
|
||||
// Protected routes
|
||||
@@ -105,6 +119,7 @@ func (s *Server) setupRoutes() {
|
||||
r.Get("/", s.handleListJobs)
|
||||
r.Get("/{id}", s.handleGetJob)
|
||||
r.Delete("/{id}", s.handleCancelJob)
|
||||
r.Post("/{id}/delete", s.handleDeleteJob)
|
||||
r.Post("/{id}/upload", s.handleUploadJobFile)
|
||||
r.Get("/{id}/files", s.handleListJobFiles)
|
||||
r.Get("/{id}/files/{fileId}/download", s.handleDownloadJobFile)
|
||||
@@ -112,18 +127,17 @@ func (s *Server) setupRoutes() {
|
||||
r.Get("/{id}/metadata", s.handleGetJobMetadata)
|
||||
r.Get("/{id}/tasks", s.handleListJobTasks)
|
||||
r.Get("/{id}/tasks/{taskId}/logs", s.handleGetTaskLogs)
|
||||
r.Get("/{id}/tasks/{taskId}/logs/ws", s.handleStreamTaskLogsWebSocket)
|
||||
// WebSocket route - no timeout middleware (long-lived connection)
|
||||
r.With(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Remove timeout middleware for WebSocket
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}).Get("/{id}/tasks/{taskId}/logs/ws", s.handleStreamTaskLogsWebSocket)
|
||||
r.Get("/{id}/tasks/{taskId}/steps", s.handleGetTaskSteps)
|
||||
r.Post("/{id}/tasks/{taskId}/retry", s.handleRetryTask)
|
||||
})
|
||||
|
||||
s.router.Route("/api/runners", func(r chi.Router) {
|
||||
r.Use(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(s.auth.Middleware(next.ServeHTTP))
|
||||
})
|
||||
r.Get("/", s.handleListRunners)
|
||||
})
|
||||
|
||||
// Admin routes
|
||||
s.router.Route("/api/admin", func(r chi.Router) {
|
||||
r.Use(func(next http.Handler) http.Handler {
|
||||
@@ -139,14 +153,23 @@ func (s *Server) setupRoutes() {
|
||||
r.Post("/{id}/verify", s.handleVerifyRunner)
|
||||
r.Delete("/{id}", s.handleDeleteRunner)
|
||||
})
|
||||
r.Route("/users", func(r chi.Router) {
|
||||
r.Get("/", s.handleListUsers)
|
||||
r.Get("/{id}/jobs", s.handleGetUserJobs)
|
||||
r.Post("/{id}/admin", s.handleSetUserAdminStatus)
|
||||
})
|
||||
r.Route("/settings", func(r chi.Router) {
|
||||
r.Get("/registration", s.handleGetRegistrationEnabled)
|
||||
r.Post("/registration", s.handleSetRegistrationEnabled)
|
||||
})
|
||||
})
|
||||
|
||||
// Runner API
|
||||
s.router.Route("/api/runner", func(r chi.Router) {
|
||||
// Registration doesn't require auth (uses token)
|
||||
r.Post("/register", s.handleRegisterRunner)
|
||||
r.With(middleware.Timeout(60*time.Second)).Post("/register", s.handleRegisterRunner)
|
||||
|
||||
// WebSocket endpoint (auth handled in handler)
|
||||
// WebSocket endpoint (auth handled in handler) - no timeout middleware
|
||||
r.Get("/ws", s.handleRunnerWebSocket)
|
||||
|
||||
// File operations still use HTTP (WebSocket not suitable for large files)
|
||||
@@ -156,7 +179,7 @@ func (s *Server) setupRoutes() {
|
||||
})
|
||||
r.Post("/tasks/{id}/progress", s.handleUpdateTaskProgress)
|
||||
r.Post("/tasks/{id}/steps", s.handleUpdateTaskStep)
|
||||
r.Get("/files/{jobId}/{fileName}", s.handleDownloadFileForRunner)
|
||||
r.Get("/files/{jobId}/*", s.handleDownloadFileForRunner)
|
||||
r.Post("/files/{jobId}/upload", s.handleUploadFileFromRunner)
|
||||
r.Get("/jobs/{jobId}/status", s.handleGetJobStatusForRunner)
|
||||
r.Get("/jobs/{jobId}/files", s.handleGetJobFilesForRunner)
|
||||
@@ -205,6 +228,11 @@ func (s *Server) handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
session, err := s.auth.GoogleCallback(r.Context(), code)
|
||||
if err != nil {
|
||||
// If registration is disabled, redirect back to login with error
|
||||
if err.Error() == "registration is disabled" {
|
||||
http.Redirect(w, r, "/?error=registration_disabled", http.StatusFound)
|
||||
return
|
||||
}
|
||||
s.respondError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
@@ -240,6 +268,11 @@ func (s *Server) handleDiscordCallback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
session, err := s.auth.DiscordCallback(r.Context(), code)
|
||||
if err != nil {
|
||||
// If registration is disabled, redirect back to login with error
|
||||
if err.Error() == "registration is disabled" {
|
||||
http.Redirect(w, r, "/?error=registration_disabled", http.StatusFound)
|
||||
return
|
||||
}
|
||||
s.respondError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
@@ -293,6 +326,166 @@ func (s *Server) handleGetMe(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleGetAuthProviders(w http.ResponseWriter, r *http.Request) {
|
||||
s.respondJSON(w, http.StatusOK, map[string]bool{
|
||||
"google": s.auth.IsGoogleOAuthConfigured(),
|
||||
"discord": s.auth.IsDiscordOAuthConfigured(),
|
||||
"local": s.auth.IsLocalLoginEnabled(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleLocalLoginAvailable(w http.ResponseWriter, r *http.Request) {
|
||||
s.respondJSON(w, http.StatusOK, map[string]bool{
|
||||
"available": s.auth.IsLocalLoginEnabled(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleLocalRegister(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Email == "" || req.Name == "" || req.Password == "" {
|
||||
s.respondError(w, http.StatusBadRequest, "Email, name, and password are required")
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Password) < 8 {
|
||||
s.respondError(w, http.StatusBadRequest, "Password must be at least 8 characters long")
|
||||
return
|
||||
}
|
||||
|
||||
session, err := s.auth.RegisterLocalUser(req.Email, req.Name, req.Password)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sessionID := s.auth.CreateSession(session)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_id",
|
||||
Value: sessionID,
|
||||
Path: "/",
|
||||
MaxAge: 86400,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
|
||||
"message": "Registration successful",
|
||||
"user": map[string]interface{}{
|
||||
"id": session.UserID,
|
||||
"email": session.Email,
|
||||
"name": session.Name,
|
||||
"is_admin": session.IsAdmin,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleLocalLogin(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Username == "" || req.Password == "" {
|
||||
s.respondError(w, http.StatusBadRequest, "Username and password are required")
|
||||
return
|
||||
}
|
||||
|
||||
session, err := s.auth.LocalLogin(req.Username, req.Password)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusUnauthorized, "Invalid credentials")
|
||||
return
|
||||
}
|
||||
|
||||
sessionID := s.auth.CreateSession(session)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_id",
|
||||
Value: sessionID,
|
||||
Path: "/",
|
||||
MaxAge: 86400,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
s.respondJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"message": "Login successful",
|
||||
"user": map[string]interface{}{
|
||||
"id": session.UserID,
|
||||
"email": session.Email,
|
||||
"name": session.Name,
|
||||
"is_admin": session.IsAdmin,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleChangePassword(w http.ResponseWriter, r *http.Request) {
|
||||
userID, err := getUserID(r)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
OldPassword string `json:"old_password"`
|
||||
NewPassword string `json:"new_password"`
|
||||
TargetUserID *int64 `json:"target_user_id,omitempty"` // For admin to change other users' passwords
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.NewPassword == "" {
|
||||
s.respondError(w, http.StatusBadRequest, "New password is required")
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.NewPassword) < 8 {
|
||||
s.respondError(w, http.StatusBadRequest, "Password must be at least 8 characters long")
|
||||
return
|
||||
}
|
||||
|
||||
isAdmin := authpkg.IsAdmin(r.Context())
|
||||
|
||||
// If target_user_id is provided and user is admin, allow changing other user's password
|
||||
if req.TargetUserID != nil && isAdmin {
|
||||
if err := s.auth.AdminChangePassword(*req.TargetUserID, req.NewPassword); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Password changed successfully"})
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, user is changing their own password (requires old password)
|
||||
if req.OldPassword == "" {
|
||||
s.respondError(w, http.StatusBadRequest, "Old password is required")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.auth.ChangePassword(userID, req.OldPassword, req.NewPassword); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Password changed successfully"})
|
||||
}
|
||||
|
||||
// Helper to get user ID from context
|
||||
func getUserID(r *http.Request) (int64, error) {
|
||||
userID, ok := authpkg.GetUserID(r.Context())
|
||||
@@ -315,6 +508,7 @@ func parseID(r *http.Request, param string) (int64, error) {
|
||||
// StartBackgroundTasks starts background goroutines for error recovery
|
||||
func (s *Server) StartBackgroundTasks() {
|
||||
go s.recoverStuckTasks()
|
||||
go s.cleanupOldMetadataJobs()
|
||||
}
|
||||
|
||||
// recoverStuckTasks periodically checks for dead runners and stuck tasks
|
||||
@@ -322,8 +516,8 @@ func (s *Server) recoverStuckTasks() {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Also distribute tasks every 5 seconds
|
||||
distributeTicker := time.NewTicker(5 * time.Second)
|
||||
// Also distribute tasks every 10 seconds (reduced frequency since we have event-driven distribution)
|
||||
distributeTicker := time.NewTicker(10 * time.Second)
|
||||
defer distributeTicker.Stop()
|
||||
|
||||
go func() {
|
||||
@@ -341,9 +535,10 @@ func (s *Server) recoverStuckTasks() {
|
||||
}()
|
||||
|
||||
// Find dead runners (no heartbeat for 90 seconds)
|
||||
// But only mark as dead if they're not actually connected via WebSocket
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id FROM runners
|
||||
WHERE last_heartbeat < datetime('now', '-90 seconds')
|
||||
WHERE last_heartbeat < CURRENT_TIMESTAMP - INTERVAL '90 seconds'
|
||||
AND status = ?`,
|
||||
types.RunnerStatusOnline,
|
||||
)
|
||||
@@ -354,12 +549,20 @@ func (s *Server) recoverStuckTasks() {
|
||||
defer rows.Close()
|
||||
|
||||
var deadRunnerIDs []int64
|
||||
s.runnerConnsMu.RLock()
|
||||
for rows.Next() {
|
||||
var runnerID int64
|
||||
if err := rows.Scan(&runnerID); err == nil {
|
||||
deadRunnerIDs = append(deadRunnerIDs, runnerID)
|
||||
// Only mark as dead if not actually connected via WebSocket
|
||||
// The WebSocket connection is the source of truth
|
||||
if _, stillConnected := s.runnerConns[runnerID]; !stillConnected {
|
||||
deadRunnerIDs = append(deadRunnerIDs, runnerID)
|
||||
}
|
||||
// If still connected, heartbeat should be updated by pong handler or heartbeat message
|
||||
// No need to manually update here - if it's stale, the pong handler isn't working
|
||||
}
|
||||
}
|
||||
s.runnerConnsMu.RUnlock()
|
||||
rows.Close()
|
||||
|
||||
if len(deadRunnerIDs) == 0 {
|
||||
@@ -370,67 +573,7 @@ func (s *Server) recoverStuckTasks() {
|
||||
|
||||
// Reset tasks assigned to dead runners
|
||||
for _, runnerID := range deadRunnerIDs {
|
||||
// Get tasks assigned to this runner
|
||||
taskRows, err := s.db.Query(
|
||||
`SELECT id, retry_count, max_retries FROM tasks
|
||||
WHERE runner_id = ? AND status = ?`,
|
||||
runnerID, types.TaskStatusRunning,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Failed to query tasks for runner %d: %v", runnerID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
var tasksToReset []struct {
|
||||
ID int64
|
||||
RetryCount int
|
||||
MaxRetries int
|
||||
}
|
||||
|
||||
for taskRows.Next() {
|
||||
var t struct {
|
||||
ID int64
|
||||
RetryCount int
|
||||
MaxRetries int
|
||||
}
|
||||
if err := taskRows.Scan(&t.ID, &t.RetryCount, &t.MaxRetries); err == nil {
|
||||
tasksToReset = append(tasksToReset, t)
|
||||
}
|
||||
}
|
||||
taskRows.Close()
|
||||
|
||||
// Reset or fail tasks
|
||||
for _, task := range tasksToReset {
|
||||
if task.RetryCount >= task.MaxRetries {
|
||||
// Mark as failed
|
||||
_, err = s.db.Exec(
|
||||
`UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL
|
||||
WHERE id = ?`,
|
||||
types.TaskStatusFailed, "Runner died, max retries exceeded", task.ID,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Failed to mark task %d as failed: %v", task.ID, err)
|
||||
}
|
||||
} else {
|
||||
// Reset to pending
|
||||
_, err = s.db.Exec(
|
||||
`UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL,
|
||||
retry_count = retry_count + 1 WHERE id = ?`,
|
||||
types.TaskStatusPending, task.ID,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Failed to reset task %d: %v", task.ID, err)
|
||||
} else {
|
||||
// Add log entry
|
||||
_, _ = s.db.Exec(
|
||||
`INSERT INTO task_logs (task_id, log_level, message, step_name, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
task.ID, types.LogLevelWarn, fmt.Sprintf("Runner died, task reset (retry %d/%d)", task.RetryCount+1, task.MaxRetries),
|
||||
"", time.Now(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
s.redistributeRunnerTasks(runnerID)
|
||||
|
||||
// Mark runner as offline
|
||||
_, _ = s.db.Exec(
|
||||
@@ -457,7 +600,7 @@ func (s *Server) recoverTaskTimeouts() {
|
||||
WHERE t.status = ?
|
||||
AND t.started_at IS NOT NULL
|
||||
AND (t.timeout_seconds IS NULL OR
|
||||
datetime(t.started_at, '+' || t.timeout_seconds || ' seconds') < datetime('now'))`,
|
||||
t.started_at + INTERVAL (t.timeout_seconds || ' seconds') < CURRENT_TIMESTAMP)`,
|
||||
types.TaskStatusRunning,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -507,13 +650,8 @@ func (s *Server) recoverTaskTimeouts() {
|
||||
types.TaskStatusPending, taskID,
|
||||
)
|
||||
if err == nil {
|
||||
// Add log entry
|
||||
_, _ = s.db.Exec(
|
||||
`INSERT INTO task_logs (task_id, log_level, message, step_name, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
taskID, types.LogLevelWarn, fmt.Sprintf("Task timeout exceeded, resetting (retry %d/%d)", retryCount+1, maxRetries),
|
||||
"", time.Now(),
|
||||
)
|
||||
// Add log entry using the helper function
|
||||
s.logTaskEvent(taskID, nil, types.LogLevelWarn, fmt.Sprintf("Task timeout exceeded, resetting (retry %d/%d)", retryCount+1, maxRetries), "")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user