massive changes and it works

This commit is contained in:
2025-11-23 10:58:24 -06:00
parent 30aa969433
commit 2a0ff98834
3499 changed files with 7770 additions and 634687 deletions

View File

@@ -7,7 +7,7 @@ import (
"net/http"
"time"
"fuego/pkg/types"
"jiggablend/pkg/types"
)
// handleGenerateRegistrationToken generates a new registration token
@@ -126,7 +126,7 @@ func (s *Server) handleDeleteRunner(w http.ResponseWriter, r *http.Request) {
func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request) {
rows, err := s.db.Query(
`SELECT id, name, hostname, ip_address, status, last_heartbeat, capabilities,
registration_token, verified, created_at
registration_token, verified, priority, created_at
FROM runners ORDER BY created_at DESC`,
)
if err != nil {
@@ -144,7 +144,7 @@ func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request)
err := rows.Scan(
&runner.ID, &runner.Name, &runner.Hostname, &runner.IPAddress,
&runner.Status, &runner.LastHeartbeat, &runner.Capabilities,
&registrationToken, &verified, &runner.CreatedAt,
&registrationToken, &verified, &runner.Priority, &runner.CreatedAt,
)
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan runner: %v", err))
@@ -161,9 +161,206 @@ func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request)
"capabilities": runner.Capabilities,
"registration_token": registrationToken.String,
"verified": verified,
"priority": runner.Priority,
"created_at": runner.CreatedAt,
})
}
s.respondJSON(w, http.StatusOK, runners)
}
// handleListUsers lists all users
func (s *Server) handleListUsers(w http.ResponseWriter, r *http.Request) {
// Get first user ID to mark it in the response
firstUserID, err := s.auth.GetFirstUserID()
if err != nil {
// If no users exist, firstUserID will be 0, which is fine
firstUserID = 0
}
rows, err := s.db.Query(
`SELECT id, email, name, oauth_provider, is_admin, created_at
FROM users ORDER BY created_at DESC`,
)
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query users: %v", err))
return
}
defer rows.Close()
users := []map[string]interface{}{}
for rows.Next() {
var userID int64
var email, name, oauthProvider string
var isAdmin bool
var createdAt time.Time
err := rows.Scan(&userID, &email, &name, &oauthProvider, &isAdmin, &createdAt)
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan user: %v", err))
return
}
// Get job count for this user
var jobCount int
err = s.db.QueryRow("SELECT COUNT(*) FROM jobs WHERE user_id = ?", userID).Scan(&jobCount)
if err != nil {
jobCount = 0 // Default to 0 if query fails
}
users = append(users, map[string]interface{}{
"id": userID,
"email": email,
"name": name,
"oauth_provider": oauthProvider,
"is_admin": isAdmin,
"created_at": createdAt,
"job_count": jobCount,
"is_first_user": userID == firstUserID,
})
}
s.respondJSON(w, http.StatusOK, users)
}
// handleGetUserJobs gets all jobs for a specific user
func (s *Server) handleGetUserJobs(w http.ResponseWriter, r *http.Request) {
userID, err := parseID(r, "id")
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
// Verify user exists
var exists bool
err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", userID).Scan(&exists)
if err != nil || !exists {
s.respondError(w, http.StatusNotFound, "User not found")
return
}
rows, err := s.db.Query(
`SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format,
allow_parallel_runners, timeout_seconds, blend_metadata, created_at, started_at, completed_at, error_message
FROM jobs WHERE user_id = ? ORDER BY created_at DESC`,
userID,
)
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query jobs: %v", err))
return
}
defer rows.Close()
jobs := []types.Job{}
for rows.Next() {
var job types.Job
var jobType string
var startedAt, completedAt sql.NullTime
var blendMetadataJSON sql.NullString
var errorMessage sql.NullString
var frameStart, frameEnd sql.NullInt64
var outputFormat sql.NullString
var allowParallelRunners sql.NullBool
err := rows.Scan(
&job.ID, &job.UserID, &jobType, &job.Name, &job.Status, &job.Progress,
&frameStart, &frameEnd, &outputFormat, &allowParallelRunners, &job.TimeoutSeconds,
&blendMetadataJSON, &job.CreatedAt, &startedAt, &completedAt, &errorMessage,
)
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan job: %v", err))
return
}
job.JobType = types.JobType(jobType)
if frameStart.Valid {
fs := int(frameStart.Int64)
job.FrameStart = &fs
}
if frameEnd.Valid {
fe := int(frameEnd.Int64)
job.FrameEnd = &fe
}
if outputFormat.Valid {
job.OutputFormat = &outputFormat.String
}
if allowParallelRunners.Valid {
job.AllowParallelRunners = &allowParallelRunners.Bool
}
if startedAt.Valid {
job.StartedAt = &startedAt.Time
}
if completedAt.Valid {
job.CompletedAt = &completedAt.Time
}
if blendMetadataJSON.Valid && blendMetadataJSON.String != "" {
var metadata types.BlendMetadata
if err := json.Unmarshal([]byte(blendMetadataJSON.String), &metadata); err == nil {
job.BlendMetadata = &metadata
}
}
if errorMessage.Valid {
job.ErrorMessage = errorMessage.String
}
jobs = append(jobs, job)
}
s.respondJSON(w, http.StatusOK, jobs)
}
// handleGetRegistrationEnabled gets the registration enabled setting
func (s *Server) handleGetRegistrationEnabled(w http.ResponseWriter, r *http.Request) {
enabled, err := s.auth.IsRegistrationEnabled()
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to get registration setting: %v", err))
return
}
s.respondJSON(w, http.StatusOK, map[string]bool{"enabled": enabled})
}
// handleSetRegistrationEnabled sets the registration enabled setting
func (s *Server) handleSetRegistrationEnabled(w http.ResponseWriter, r *http.Request) {
var req struct {
Enabled bool `json:"enabled"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
if err := s.auth.SetRegistrationEnabled(req.Enabled); err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to set registration setting: %v", err))
return
}
s.respondJSON(w, http.StatusOK, map[string]bool{"enabled": req.Enabled})
}
// handleSetUserAdminStatus sets a user's admin status (admin only)
func (s *Server) handleSetUserAdminStatus(w http.ResponseWriter, r *http.Request) {
targetUserID, err := parseID(r, "id")
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
var req struct {
IsAdmin bool `json:"is_admin"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
if err := s.auth.SetUserAdminStatus(targetUserID, req.IsAdmin); err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
s.respondJSON(w, http.StatusOK, map[string]interface{}{
"user_id": targetUserID,
"is_admin": req.IsAdmin,
"message": "Admin status updated successfully",
})
}

File diff suppressed because it is too large Load Diff

View File

@@ -7,7 +7,7 @@ import (
"log"
"net/http"
"fuego/pkg/types"
"jiggablend/pkg/types"
)
// handleSubmitMetadata handles metadata submission from runner
@@ -19,7 +19,7 @@ func (s *Server) handleSubmitMetadata(w http.ResponseWriter, r *http.Request) {
}
// Get runner ID from context (set by runnerAuthMiddleware)
runnerID, ok := r.Context().Value("runner_id").(int64)
runnerID, ok := r.Context().Value(runnerIDContextKey).(int64)
if !ok {
s.respondError(w, http.StatusUnauthorized, "runner_id not found in context")
return
@@ -44,16 +44,32 @@ func (s *Server) handleSubmitMetadata(w http.ResponseWriter, r *http.Request) {
}
// Find the metadata extraction task for this job
// First try to find task assigned to this runner, then fall back to any metadata task for this job
var taskID int64
err = s.db.QueryRow(
`SELECT id FROM tasks WHERE job_id = ? AND task_type = ? AND runner_id = ?`,
jobID, types.TaskTypeMetadata, runnerID,
).Scan(&taskID)
if err == sql.ErrNoRows {
s.respondError(w, http.StatusNotFound, "Metadata extraction task not found")
return
}
if err != nil {
// Fall back to any metadata task for this job (in case assignment changed)
err = s.db.QueryRow(
`SELECT id FROM tasks WHERE job_id = ? AND task_type = ? ORDER BY created_at DESC LIMIT 1`,
jobID, types.TaskTypeMetadata,
).Scan(&taskID)
if err == sql.ErrNoRows {
s.respondError(w, http.StatusNotFound, "Metadata extraction task not found")
return
}
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to find task: %v", err))
return
}
// Update the task to be assigned to this runner if it wasn't already
s.db.Exec(
`UPDATE tasks SET runner_id = ? WHERE id = ? AND runner_id IS NULL`,
runnerID, taskID,
)
} else if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to find task: %v", err))
return
}
@@ -82,6 +98,9 @@ func (s *Server) handleSubmitMetadata(w http.ResponseWriter, r *http.Request) {
)
if err != nil {
log.Printf("Failed to mark metadata task as completed: %v", err)
} else {
// Update job status and progress after metadata task completes
s.updateJobStatusFromTasks(jobID)
}
log.Printf("Metadata extracted for job %d: frame_start=%d, frame_end=%d", jobID, metadata.FrameStart, metadata.FrameEnd)

File diff suppressed because it is too large Load Diff

View File

@@ -10,10 +10,10 @@ import (
"sync"
"time"
authpkg "fuego/internal/auth"
"fuego/internal/database"
"fuego/internal/storage"
"fuego/pkg/types"
authpkg "jiggablend/internal/auth"
"jiggablend/internal/database"
"jiggablend/internal/storage"
"jiggablend/pkg/types"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
@@ -35,6 +35,12 @@ type Server struct {
runnerConnsMu sync.RWMutex
frontendConns map[string]*websocket.Conn // key: "jobId:taskId"
frontendConnsMu sync.RWMutex
// Mutexes for each frontend connection to serialize writes
frontendConnsWriteMu map[string]*sync.Mutex // key: "jobId:taskId"
frontendConnsWriteMuMu sync.RWMutex
// Throttling for progress updates (per job)
progressUpdateTimes map[int64]time.Time // key: jobID
progressUpdateTimesMu sync.RWMutex
}
// NewServer creates a new API server
@@ -57,8 +63,10 @@ func NewServer(db *database.DB, auth *authpkg.Auth, storage *storage.Storage) (*
ReadBufferSize: 1024,
WriteBufferSize: 1024,
},
runnerConns: make(map[int64]*websocket.Conn),
frontendConns: make(map[string]*websocket.Conn),
runnerConns: make(map[int64]*websocket.Conn),
frontendConns: make(map[string]*websocket.Conn),
frontendConnsWriteMu: make(map[string]*sync.Mutex),
progressUpdateTimes: make(map[int64]time.Time),
}
s.setupMiddleware()
@@ -72,13 +80,14 @@ func NewServer(db *database.DB, auth *authpkg.Auth, storage *storage.Storage) (*
func (s *Server) setupMiddleware() {
s.router.Use(middleware.Logger)
s.router.Use(middleware.Recoverer)
s.router.Use(middleware.Timeout(60 * time.Second))
// Note: Timeout middleware is NOT applied globally to avoid conflicts with WebSocket connections
// WebSocket connections are long-lived and should not have HTTP timeouts
s.router.Use(cors.Handler(cors.Options{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type"},
ExposedHeaders: []string{"Link"},
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "Range"},
ExposedHeaders: []string{"Link", "Content-Range", "Accept-Ranges", "Content-Length"},
AllowCredentials: true,
MaxAge: 300,
}))
@@ -88,12 +97,17 @@ func (s *Server) setupMiddleware() {
func (s *Server) setupRoutes() {
// Public routes
s.router.Route("/api/auth", func(r chi.Router) {
r.Get("/providers", s.handleGetAuthProviders)
r.Get("/google/login", s.handleGoogleLogin)
r.Get("/google/callback", s.handleGoogleCallback)
r.Get("/discord/login", s.handleDiscordLogin)
r.Get("/discord/callback", s.handleDiscordCallback)
r.Get("/local/available", s.handleLocalLoginAvailable)
r.Post("/local/register", s.handleLocalRegister)
r.Post("/local/login", s.handleLocalLogin)
r.Post("/logout", s.handleLogout)
r.Get("/me", s.handleGetMe)
r.Post("/change-password", s.handleChangePassword)
})
// Protected routes
@@ -105,6 +119,7 @@ func (s *Server) setupRoutes() {
r.Get("/", s.handleListJobs)
r.Get("/{id}", s.handleGetJob)
r.Delete("/{id}", s.handleCancelJob)
r.Post("/{id}/delete", s.handleDeleteJob)
r.Post("/{id}/upload", s.handleUploadJobFile)
r.Get("/{id}/files", s.handleListJobFiles)
r.Get("/{id}/files/{fileId}/download", s.handleDownloadJobFile)
@@ -112,18 +127,17 @@ func (s *Server) setupRoutes() {
r.Get("/{id}/metadata", s.handleGetJobMetadata)
r.Get("/{id}/tasks", s.handleListJobTasks)
r.Get("/{id}/tasks/{taskId}/logs", s.handleGetTaskLogs)
r.Get("/{id}/tasks/{taskId}/logs/ws", s.handleStreamTaskLogsWebSocket)
// WebSocket route - no timeout middleware (long-lived connection)
r.With(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Remove timeout middleware for WebSocket
next.ServeHTTP(w, r)
})
}).Get("/{id}/tasks/{taskId}/logs/ws", s.handleStreamTaskLogsWebSocket)
r.Get("/{id}/tasks/{taskId}/steps", s.handleGetTaskSteps)
r.Post("/{id}/tasks/{taskId}/retry", s.handleRetryTask)
})
s.router.Route("/api/runners", func(r chi.Router) {
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(s.auth.Middleware(next.ServeHTTP))
})
r.Get("/", s.handleListRunners)
})
// Admin routes
s.router.Route("/api/admin", func(r chi.Router) {
r.Use(func(next http.Handler) http.Handler {
@@ -139,14 +153,23 @@ func (s *Server) setupRoutes() {
r.Post("/{id}/verify", s.handleVerifyRunner)
r.Delete("/{id}", s.handleDeleteRunner)
})
r.Route("/users", func(r chi.Router) {
r.Get("/", s.handleListUsers)
r.Get("/{id}/jobs", s.handleGetUserJobs)
r.Post("/{id}/admin", s.handleSetUserAdminStatus)
})
r.Route("/settings", func(r chi.Router) {
r.Get("/registration", s.handleGetRegistrationEnabled)
r.Post("/registration", s.handleSetRegistrationEnabled)
})
})
// Runner API
s.router.Route("/api/runner", func(r chi.Router) {
// Registration doesn't require auth (uses token)
r.Post("/register", s.handleRegisterRunner)
r.With(middleware.Timeout(60*time.Second)).Post("/register", s.handleRegisterRunner)
// WebSocket endpoint (auth handled in handler)
// WebSocket endpoint (auth handled in handler) - no timeout middleware
r.Get("/ws", s.handleRunnerWebSocket)
// File operations still use HTTP (WebSocket not suitable for large files)
@@ -156,7 +179,7 @@ func (s *Server) setupRoutes() {
})
r.Post("/tasks/{id}/progress", s.handleUpdateTaskProgress)
r.Post("/tasks/{id}/steps", s.handleUpdateTaskStep)
r.Get("/files/{jobId}/{fileName}", s.handleDownloadFileForRunner)
r.Get("/files/{jobId}/*", s.handleDownloadFileForRunner)
r.Post("/files/{jobId}/upload", s.handleUploadFileFromRunner)
r.Get("/jobs/{jobId}/status", s.handleGetJobStatusForRunner)
r.Get("/jobs/{jobId}/files", s.handleGetJobFilesForRunner)
@@ -205,6 +228,11 @@ func (s *Server) handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
session, err := s.auth.GoogleCallback(r.Context(), code)
if err != nil {
// If registration is disabled, redirect back to login with error
if err.Error() == "registration is disabled" {
http.Redirect(w, r, "/?error=registration_disabled", http.StatusFound)
return
}
s.respondError(w, http.StatusInternalServerError, err.Error())
return
}
@@ -240,6 +268,11 @@ func (s *Server) handleDiscordCallback(w http.ResponseWriter, r *http.Request) {
session, err := s.auth.DiscordCallback(r.Context(), code)
if err != nil {
// If registration is disabled, redirect back to login with error
if err.Error() == "registration is disabled" {
http.Redirect(w, r, "/?error=registration_disabled", http.StatusFound)
return
}
s.respondError(w, http.StatusInternalServerError, err.Error())
return
}
@@ -293,6 +326,166 @@ func (s *Server) handleGetMe(w http.ResponseWriter, r *http.Request) {
})
}
func (s *Server) handleGetAuthProviders(w http.ResponseWriter, r *http.Request) {
s.respondJSON(w, http.StatusOK, map[string]bool{
"google": s.auth.IsGoogleOAuthConfigured(),
"discord": s.auth.IsDiscordOAuthConfigured(),
"local": s.auth.IsLocalLoginEnabled(),
})
}
func (s *Server) handleLocalLoginAvailable(w http.ResponseWriter, r *http.Request) {
s.respondJSON(w, http.StatusOK, map[string]bool{
"available": s.auth.IsLocalLoginEnabled(),
})
}
func (s *Server) handleLocalRegister(w http.ResponseWriter, r *http.Request) {
var req struct {
Email string `json:"email"`
Name string `json:"name"`
Password string `json:"password"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
if req.Email == "" || req.Name == "" || req.Password == "" {
s.respondError(w, http.StatusBadRequest, "Email, name, and password are required")
return
}
if len(req.Password) < 8 {
s.respondError(w, http.StatusBadRequest, "Password must be at least 8 characters long")
return
}
session, err := s.auth.RegisterLocalUser(req.Email, req.Name, req.Password)
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
sessionID := s.auth.CreateSession(session)
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: sessionID,
Path: "/",
MaxAge: 86400,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
})
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
"message": "Registration successful",
"user": map[string]interface{}{
"id": session.UserID,
"email": session.Email,
"name": session.Name,
"is_admin": session.IsAdmin,
},
})
}
func (s *Server) handleLocalLogin(w http.ResponseWriter, r *http.Request) {
var req struct {
Username string `json:"username"`
Password string `json:"password"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
if req.Username == "" || req.Password == "" {
s.respondError(w, http.StatusBadRequest, "Username and password are required")
return
}
session, err := s.auth.LocalLogin(req.Username, req.Password)
if err != nil {
s.respondError(w, http.StatusUnauthorized, "Invalid credentials")
return
}
sessionID := s.auth.CreateSession(session)
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: sessionID,
Path: "/",
MaxAge: 86400,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
})
s.respondJSON(w, http.StatusOK, map[string]interface{}{
"message": "Login successful",
"user": map[string]interface{}{
"id": session.UserID,
"email": session.Email,
"name": session.Name,
"is_admin": session.IsAdmin,
},
})
}
func (s *Server) handleChangePassword(w http.ResponseWriter, r *http.Request) {
userID, err := getUserID(r)
if err != nil {
s.respondError(w, http.StatusUnauthorized, err.Error())
return
}
var req struct {
OldPassword string `json:"old_password"`
NewPassword string `json:"new_password"`
TargetUserID *int64 `json:"target_user_id,omitempty"` // For admin to change other users' passwords
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
if req.NewPassword == "" {
s.respondError(w, http.StatusBadRequest, "New password is required")
return
}
if len(req.NewPassword) < 8 {
s.respondError(w, http.StatusBadRequest, "Password must be at least 8 characters long")
return
}
isAdmin := authpkg.IsAdmin(r.Context())
// If target_user_id is provided and user is admin, allow changing other user's password
if req.TargetUserID != nil && isAdmin {
if err := s.auth.AdminChangePassword(*req.TargetUserID, req.NewPassword); err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Password changed successfully"})
return
}
// Otherwise, user is changing their own password (requires old password)
if req.OldPassword == "" {
s.respondError(w, http.StatusBadRequest, "Old password is required")
return
}
if err := s.auth.ChangePassword(userID, req.OldPassword, req.NewPassword); err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Password changed successfully"})
}
// Helper to get user ID from context
func getUserID(r *http.Request) (int64, error) {
userID, ok := authpkg.GetUserID(r.Context())
@@ -315,6 +508,7 @@ func parseID(r *http.Request, param string) (int64, error) {
// StartBackgroundTasks starts background goroutines for error recovery
func (s *Server) StartBackgroundTasks() {
go s.recoverStuckTasks()
go s.cleanupOldMetadataJobs()
}
// recoverStuckTasks periodically checks for dead runners and stuck tasks
@@ -322,8 +516,8 @@ func (s *Server) recoverStuckTasks() {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
// Also distribute tasks every 5 seconds
distributeTicker := time.NewTicker(5 * time.Second)
// Also distribute tasks every 10 seconds (reduced frequency since we have event-driven distribution)
distributeTicker := time.NewTicker(10 * time.Second)
defer distributeTicker.Stop()
go func() {
@@ -341,9 +535,10 @@ func (s *Server) recoverStuckTasks() {
}()
// Find dead runners (no heartbeat for 90 seconds)
// But only mark as dead if they're not actually connected via WebSocket
rows, err := s.db.Query(
`SELECT id FROM runners
WHERE last_heartbeat < datetime('now', '-90 seconds')
WHERE last_heartbeat < CURRENT_TIMESTAMP - INTERVAL '90 seconds'
AND status = ?`,
types.RunnerStatusOnline,
)
@@ -354,12 +549,20 @@ func (s *Server) recoverStuckTasks() {
defer rows.Close()
var deadRunnerIDs []int64
s.runnerConnsMu.RLock()
for rows.Next() {
var runnerID int64
if err := rows.Scan(&runnerID); err == nil {
deadRunnerIDs = append(deadRunnerIDs, runnerID)
// Only mark as dead if not actually connected via WebSocket
// The WebSocket connection is the source of truth
if _, stillConnected := s.runnerConns[runnerID]; !stillConnected {
deadRunnerIDs = append(deadRunnerIDs, runnerID)
}
// If still connected, heartbeat should be updated by pong handler or heartbeat message
// No need to manually update here - if it's stale, the pong handler isn't working
}
}
s.runnerConnsMu.RUnlock()
rows.Close()
if len(deadRunnerIDs) == 0 {
@@ -370,67 +573,7 @@ func (s *Server) recoverStuckTasks() {
// Reset tasks assigned to dead runners
for _, runnerID := range deadRunnerIDs {
// Get tasks assigned to this runner
taskRows, err := s.db.Query(
`SELECT id, retry_count, max_retries FROM tasks
WHERE runner_id = ? AND status = ?`,
runnerID, types.TaskStatusRunning,
)
if err != nil {
log.Printf("Failed to query tasks for runner %d: %v", runnerID, err)
continue
}
var tasksToReset []struct {
ID int64
RetryCount int
MaxRetries int
}
for taskRows.Next() {
var t struct {
ID int64
RetryCount int
MaxRetries int
}
if err := taskRows.Scan(&t.ID, &t.RetryCount, &t.MaxRetries); err == nil {
tasksToReset = append(tasksToReset, t)
}
}
taskRows.Close()
// Reset or fail tasks
for _, task := range tasksToReset {
if task.RetryCount >= task.MaxRetries {
// Mark as failed
_, err = s.db.Exec(
`UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL
WHERE id = ?`,
types.TaskStatusFailed, "Runner died, max retries exceeded", task.ID,
)
if err != nil {
log.Printf("Failed to mark task %d as failed: %v", task.ID, err)
}
} else {
// Reset to pending
_, err = s.db.Exec(
`UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL,
retry_count = retry_count + 1 WHERE id = ?`,
types.TaskStatusPending, task.ID,
)
if err != nil {
log.Printf("Failed to reset task %d: %v", task.ID, err)
} else {
// Add log entry
_, _ = s.db.Exec(
`INSERT INTO task_logs (task_id, log_level, message, step_name, created_at)
VALUES (?, ?, ?, ?, ?)`,
task.ID, types.LogLevelWarn, fmt.Sprintf("Runner died, task reset (retry %d/%d)", task.RetryCount+1, task.MaxRetries),
"", time.Now(),
)
}
}
}
s.redistributeRunnerTasks(runnerID)
// Mark runner as offline
_, _ = s.db.Exec(
@@ -457,7 +600,7 @@ func (s *Server) recoverTaskTimeouts() {
WHERE t.status = ?
AND t.started_at IS NOT NULL
AND (t.timeout_seconds IS NULL OR
datetime(t.started_at, '+' || t.timeout_seconds || ' seconds') < datetime('now'))`,
t.started_at + INTERVAL (t.timeout_seconds || ' seconds') < CURRENT_TIMESTAMP)`,
types.TaskStatusRunning,
)
if err != nil {
@@ -507,13 +650,8 @@ func (s *Server) recoverTaskTimeouts() {
types.TaskStatusPending, taskID,
)
if err == nil {
// Add log entry
_, _ = s.db.Exec(
`INSERT INTO task_logs (task_id, log_level, message, step_name, created_at)
VALUES (?, ?, ?, ?, ?)`,
taskID, types.LogLevelWarn, fmt.Sprintf("Task timeout exceeded, resetting (retry %d/%d)", retryCount+1, maxRetries),
"", time.Now(),
)
// Add log entry using the helper function
s.logTaskEvent(taskID, nil, types.LogLevelWarn, fmt.Sprintf("Task timeout exceeded, resetting (retry %d/%d)", retryCount+1, maxRetries), "")
}
}
}

View File

@@ -5,21 +5,24 @@ import (
"database/sql"
"encoding/json"
"fmt"
"log"
"net/http"
"os"
"strings"
"time"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
// Auth handles authentication
type Auth struct {
db *sql.DB
googleConfig *oauth2.Config
discordConfig *oauth2.Config
sessionStore map[string]*Session
db *sql.DB
googleConfig *oauth2.Config
discordConfig *oauth2.Config
sessionStore map[string]*Session
}
// Session represents a user session
@@ -67,9 +70,95 @@ func NewAuth(db *sql.DB) (*Auth, error) {
}
}
// Initialize admin settings on startup to ensure they persist between boots
if err := auth.initializeSettings(); err != nil {
log.Printf("Warning: Failed to initialize admin settings: %v", err)
// Don't fail startup, but log the warning
}
// Initialize test local user from environment variables (for testing only)
if err := auth.initializeTestUser(); err != nil {
log.Printf("Warning: Failed to initialize test user: %v", err)
// Don't fail startup, but log the warning
}
return auth, nil
}
// initializeSettings ensures all admin settings are initialized with defaults if they don't exist
func (a *Auth) initializeSettings() error {
// Initialize registration_enabled setting (default: true) if it doesn't exist
var settingCount int
err := a.db.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount)
if err != nil {
return fmt.Errorf("failed to check registration_enabled setting: %w", err)
}
if settingCount == 0 {
_, err = a.db.Exec(
`INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)`,
"registration_enabled", "true",
)
if err != nil {
return fmt.Errorf("failed to initialize registration_enabled setting: %w", err)
}
log.Printf("Initialized admin setting: registration_enabled = true")
}
return nil
}
// initializeTestUser creates a test local user from environment variables (for testing only)
func (a *Auth) initializeTestUser() error {
testEmail := os.Getenv("LOCAL_TEST_EMAIL")
testPassword := os.Getenv("LOCAL_TEST_PASSWORD")
if testEmail == "" || testPassword == "" {
// No test user configured, skip
return nil
}
// Check if user already exists
var exists bool
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND oauth_provider = 'local')", testEmail).Scan(&exists)
if err != nil {
return fmt.Errorf("failed to check if test user exists: %w", err)
}
if exists {
// User already exists, skip creation
log.Printf("Test user %s already exists, skipping creation", testEmail)
return nil
}
// Hash password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(testPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash test user password: %w", err)
}
// Check if this is the first user (make them admin)
var userCount int
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
isAdmin := userCount == 0
// Create test user (use email as name if no name is provided)
testName := testEmail
if atIndex := strings.Index(testEmail, "@"); atIndex > 0 {
testName = testEmail[:atIndex]
}
// Create test user
_, err = a.db.Exec(
"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?)",
testEmail, testName, testEmail, string(hashedPassword), isAdmin,
)
if err != nil {
return fmt.Errorf("failed to create test user: %w", err)
}
log.Printf("Created test user: %s (admin: %v)", testEmail, isAdmin)
return nil
}
// GoogleLoginURL returns the Google OAuth login URL
func (a *Auth) GoogleLoginURL() (string, error) {
if a.googleConfig == nil {
@@ -150,36 +239,119 @@ func (a *Auth) DiscordCallback(ctx context.Context, code string) (*Session, erro
return a.getOrCreateUser("discord", userInfo.ID, userInfo.Email, userInfo.Username)
}
// IsRegistrationEnabled checks if new user registration is enabled
func (a *Auth) IsRegistrationEnabled() (bool, error) {
var value string
err := a.db.QueryRow("SELECT value FROM settings WHERE key = ?", "registration_enabled").Scan(&value)
if err == sql.ErrNoRows {
// Default to enabled if setting doesn't exist
return true, nil
}
if err != nil {
return false, fmt.Errorf("failed to check registration setting: %w", err)
}
return value == "true", nil
}
// SetRegistrationEnabled sets whether new user registration is enabled
func (a *Auth) SetRegistrationEnabled(enabled bool) error {
value := "false"
if enabled {
value = "true"
}
// Check if setting exists
var exists bool
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM settings WHERE key = ?)", "registration_enabled").Scan(&exists)
if err != nil {
return fmt.Errorf("failed to check if setting exists: %w", err)
}
if exists {
// Update existing setting
_, err = a.db.Exec(
"UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?",
value, "registration_enabled",
)
if err != nil {
return fmt.Errorf("failed to update setting: %w", err)
}
} else {
// Insert new setting
_, err = a.db.Exec(
"INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)",
"registration_enabled", value,
)
if err != nil {
return fmt.Errorf("failed to insert setting: %w", err)
}
}
return nil
}
// getOrCreateUser gets or creates a user in the database
// Automatically links accounts by email across different OAuth providers and local login
func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session, error) {
var userID int64
var dbEmail, dbName string
var isAdmin bool
var dbProvider, dbOAuthID string
// First, try to find by provider + oauth_id
err := a.db.QueryRow(
"SELECT id, email, name, is_admin FROM users WHERE oauth_provider = ? AND oauth_id = ?",
"SELECT id, email, name, is_admin, oauth_provider, oauth_id FROM users WHERE oauth_provider = ? AND oauth_id = ?",
provider, oauthID,
).Scan(&userID, &dbEmail, &dbName, &isAdmin)
).Scan(&userID, &dbEmail, &dbName, &isAdmin, &dbProvider, &dbOAuthID)
if err == sql.ErrNoRows {
// Check if this is the first user
var userCount int
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
isAdmin = userCount == 0
// Not found by provider+oauth_id, check by email for account linking
err = a.db.QueryRow(
"SELECT id, email, name, is_admin, oauth_provider, oauth_id FROM users WHERE email = ?",
email,
).Scan(&userID, &dbEmail, &dbName, &isAdmin, &dbProvider, &dbOAuthID)
// Create new user
result, err := a.db.Exec(
"INSERT INTO users (email, name, oauth_provider, oauth_id, is_admin) VALUES (?, ?, ?, ?, ?)",
email, name, provider, oauthID, isAdmin,
)
if err != nil {
return nil, fmt.Errorf("failed to create user: %w", err)
if err == sql.ErrNoRows {
// User doesn't exist, check if registration is enabled
registrationEnabled, err := a.IsRegistrationEnabled()
if err != nil {
return nil, fmt.Errorf("failed to check registration setting: %w", err)
}
if !registrationEnabled {
return nil, fmt.Errorf("registration is disabled")
}
// Check if this is the first user
var userCount int
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
isAdmin = userCount == 0
// Create new user
err = a.db.QueryRow(
"INSERT INTO users (email, name, oauth_provider, oauth_id, is_admin) VALUES (?, ?, ?, ?, ?) RETURNING id",
email, name, provider, oauthID, isAdmin,
).Scan(&userID)
if err != nil {
return nil, fmt.Errorf("failed to create user: %w", err)
}
} else if err != nil {
return nil, fmt.Errorf("failed to query user by email: %w", err)
} else {
// User exists with same email but different provider - link accounts by updating provider info
// This allows the user to log in with any provider that has the same email
_, err = a.db.Exec(
"UPDATE users SET oauth_provider = ?, oauth_id = ?, name = ? WHERE id = ?",
provider, oauthID, name, userID,
)
if err != nil {
return nil, fmt.Errorf("failed to link account: %w", err)
}
log.Printf("Linked account: user %d (email: %s) now accessible via %s provider", userID, email, provider)
}
userID, _ = result.LastInsertId()
} else if err != nil {
return nil, fmt.Errorf("failed to query user: %w", err)
} else {
// Update user info if changed
// User found by provider+oauth_id, update info if changed
if dbEmail != email || dbName != name {
_, err = a.db.Exec(
"UPDATE users SET email = ?, name = ? WHERE id = ?",
@@ -238,13 +410,17 @@ func (a *Auth) Middleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie("session_id")
if err != nil {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
return
}
session, ok := a.GetSession(cookie.Value)
if !ok {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
return
}
@@ -275,19 +451,25 @@ func (a *Auth) AdminMiddleware(next http.HandlerFunc) http.HandlerFunc {
// First check authentication
cookie, err := r.Cookie("session_id")
if err != nil {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
return
}
session, ok := a.GetSession(cookie.Value)
if !ok {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
return
}
// Then check admin status
if !session.IsAdmin {
http.Error(w, "Forbidden: Admin access required", http.StatusForbidden)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
json.NewEncoder(w).Encode(map[string]string{"error": "Forbidden: Admin access required"})
return
}
@@ -300,3 +482,221 @@ func (a *Auth) AdminMiddleware(next http.HandlerFunc) http.HandlerFunc {
}
}
// IsLocalLoginEnabled returns whether local login is enabled
// Local login is enabled when ENABLE_LOCAL_AUTH environment variable is set to "true"
func (a *Auth) IsLocalLoginEnabled() bool {
return os.Getenv("ENABLE_LOCAL_AUTH") == "true"
}
// IsGoogleOAuthConfigured returns whether Google OAuth is configured
func (a *Auth) IsGoogleOAuthConfigured() bool {
return a.googleConfig != nil
}
// IsDiscordOAuthConfigured returns whether Discord OAuth is configured
func (a *Auth) IsDiscordOAuthConfigured() bool {
return a.discordConfig != nil
}
// LocalLogin handles local username/password authentication
func (a *Auth) LocalLogin(username, password string) (*Session, error) {
// Find user by email (local users use email as username)
email := username
var userID int64
var dbEmail, dbName, passwordHash string
var isAdmin bool
err := a.db.QueryRow(
"SELECT id, email, name, password_hash, is_admin FROM users WHERE email = ? AND oauth_provider = 'local'",
email,
).Scan(&userID, &dbEmail, &dbName, &passwordHash, &isAdmin)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("invalid credentials")
}
if err != nil {
return nil, fmt.Errorf("failed to query user: %w", err)
}
// Verify password
if passwordHash == "" {
return nil, fmt.Errorf("invalid credentials")
}
err = bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(password))
if err != nil {
return nil, fmt.Errorf("invalid credentials")
}
// Create session
session := &Session{
UserID: userID,
Email: dbEmail,
Name: dbName,
IsAdmin: isAdmin,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
return session, nil
}
// RegisterLocalUser creates a new local user account
func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error) {
// Check if registration is enabled
registrationEnabled, err := a.IsRegistrationEnabled()
if err != nil {
return nil, fmt.Errorf("failed to check registration setting: %w", err)
}
if !registrationEnabled {
return nil, fmt.Errorf("registration is disabled")
}
// Check if user already exists
var exists bool
err = a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ?)", email).Scan(&exists)
if err != nil {
return nil, fmt.Errorf("failed to check if user exists: %w", err)
}
if exists {
return nil, fmt.Errorf("user with this email already exists")
}
// Hash password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("failed to hash password: %w", err)
}
// Check if this is the first user (make them admin)
var userCount int
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
isAdmin := userCount == 0
// Create user
var userID int64
err = a.db.QueryRow(
"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?) RETURNING id",
email, name, email, string(hashedPassword), isAdmin,
).Scan(&userID)
if err != nil {
return nil, fmt.Errorf("failed to create user: %w", err)
}
// Create session
session := &Session{
UserID: userID,
Email: email,
Name: name,
IsAdmin: isAdmin,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
return session, nil
}
// ChangePassword allows a user to change their own password
func (a *Auth) ChangePassword(userID int64, oldPassword, newPassword string) error {
// Get current password hash
var passwordHash string
err := a.db.QueryRow("SELECT password_hash FROM users WHERE id = ? AND oauth_provider = 'local'", userID).Scan(&passwordHash)
if err == sql.ErrNoRows {
return fmt.Errorf("user not found or not a local user")
}
if err != nil {
return fmt.Errorf("failed to query user: %w", err)
}
// Verify old password
if passwordHash == "" {
return fmt.Errorf("user has no password set")
}
err = bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(oldPassword))
if err != nil {
return fmt.Errorf("incorrect old password")
}
// Hash new password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash password: %w", err)
}
// Update password
_, err = a.db.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), userID)
if err != nil {
return fmt.Errorf("failed to update password: %w", err)
}
return nil
}
// AdminChangePassword allows an admin to change any user's password without knowing the old password
func (a *Auth) AdminChangePassword(targetUserID int64, newPassword string) error {
// Verify user exists and is a local user
var exists bool
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ? AND oauth_provider = 'local')", targetUserID).Scan(&exists)
if err != nil {
return fmt.Errorf("failed to check if user exists: %w", err)
}
if !exists {
return fmt.Errorf("user not found or not a local user")
}
// Hash new password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash password: %w", err)
}
// Update password
_, err = a.db.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), targetUserID)
if err != nil {
return fmt.Errorf("failed to update password: %w", err)
}
return nil
}
// GetFirstUserID returns the ID of the first user (user with the lowest ID)
func (a *Auth) GetFirstUserID() (int64, error) {
var firstUserID int64
err := a.db.QueryRow("SELECT id FROM users ORDER BY id ASC LIMIT 1").Scan(&firstUserID)
if err == sql.ErrNoRows {
return 0, fmt.Errorf("no users found")
}
if err != nil {
return 0, fmt.Errorf("failed to get first user ID: %w", err)
}
return firstUserID, nil
}
// SetUserAdminStatus allows an admin to change a user's admin status
func (a *Auth) SetUserAdminStatus(targetUserID int64, isAdmin bool) error {
// Verify user exists
var exists bool
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", targetUserID).Scan(&exists)
if err != nil {
return fmt.Errorf("failed to check if user exists: %w", err)
}
if !exists {
return fmt.Errorf("user not found")
}
// Prevent removing admin status from the first user
firstUserID, err := a.GetFirstUserID()
if err != nil {
return fmt.Errorf("failed to check first user: %w", err)
}
if targetUserID == firstUserID && !isAdmin {
return fmt.Errorf("cannot remove admin status from the first user")
}
// Update admin status
_, err = a.db.Exec("UPDATE users SET is_admin = ? WHERE id = ?", isAdmin, targetUserID)
if err != nil {
return fmt.Errorf("failed to update admin status: %w", err)
}
return nil
}

View File

@@ -8,25 +8,36 @@ import (
"encoding/hex"
"fmt"
"io"
"log"
"net/http"
"os"
"strings"
"time"
)
// Secrets handles secret and token management
type Secrets struct {
db *sql.DB
db *sql.DB
fixedRegistrationToken string // Fixed token from environment variable (reusable, never expires)
}
// NewSecrets creates a new secrets manager
func NewSecrets(db *sql.DB) (*Secrets, error) {
s := &Secrets{db: db}
// Check for fixed registration token from environment
fixedToken := os.Getenv("FIXED_REGISTRATION_TOKEN")
if fixedToken != "" {
s.fixedRegistrationToken = fixedToken
log.Printf("Fixed registration token enabled (from FIXED_REGISTRATION_TOKEN env var)")
log.Printf("WARNING: Fixed registration token is reusable and never expires - use only for testing/development!")
}
// Ensure manager secret exists
if err := s.ensureManagerSecret(); err != nil {
return nil, fmt.Errorf("failed to ensure manager secret: %w", err)
}
return s, nil
}
@@ -37,20 +48,20 @@ func (s *Secrets) ensureManagerSecret() error {
if err != nil {
return fmt.Errorf("failed to check manager secrets: %w", err)
}
if count == 0 {
// Generate new manager secret
secret, err := generateSecret(32)
if err != nil {
return fmt.Errorf("failed to generate manager secret: %w", err)
}
_, err = s.db.Exec("INSERT INTO manager_secrets (secret) VALUES (?)", secret)
if err != nil {
return fmt.Errorf("failed to store manager secret: %w", err)
}
}
return nil
}
@@ -70,9 +81,9 @@ func (s *Secrets) GenerateRegistrationToken(createdBy int64, expiresIn time.Dura
if err != nil {
return "", fmt.Errorf("failed to generate token: %w", err)
}
expiresAt := time.Now().Add(expiresIn)
_, err = s.db.Exec(
"INSERT INTO registration_tokens (token, expires_at, created_by) VALUES (?, ?, ?)",
token, expiresAt, createdBy,
@@ -80,43 +91,67 @@ func (s *Secrets) GenerateRegistrationToken(createdBy int64, expiresIn time.Dura
if err != nil {
return "", fmt.Errorf("failed to store registration token: %w", err)
}
return token, nil
}
// TokenValidationResult represents the result of token validation
type TokenValidationResult struct {
Valid bool
Reason string // "valid", "not_found", "already_used", "expired"
Error error
}
// ValidateRegistrationToken validates a registration token
func (s *Secrets) ValidateRegistrationToken(token string) (bool, error) {
result, err := s.ValidateRegistrationTokenDetailed(token)
if err != nil {
return false, err
}
// For backward compatibility, return just the valid boolean
return result.Valid, nil
}
// ValidateRegistrationTokenDetailed validates a registration token and returns detailed result
func (s *Secrets) ValidateRegistrationTokenDetailed(token string) (*TokenValidationResult, error) {
// Check fixed token first (if set) - it's reusable and never expires
if s.fixedRegistrationToken != "" && token == s.fixedRegistrationToken {
log.Printf("Fixed registration token used (from FIXED_REGISTRATION_TOKEN env var)")
return &TokenValidationResult{Valid: true, Reason: "valid"}, nil
}
// Check database tokens
var used bool
var expiresAt time.Time
var id int64
err := s.db.QueryRow(
"SELECT id, expires_at, used FROM registration_tokens WHERE token = ?",
token,
).Scan(&id, &expiresAt, &used)
if err == sql.ErrNoRows {
return false, nil
return &TokenValidationResult{Valid: false, Reason: "not_found"}, nil
}
if err != nil {
return false, fmt.Errorf("failed to query token: %w", err)
return nil, fmt.Errorf("failed to query token: %w", err)
}
if used {
return false, nil
return &TokenValidationResult{Valid: false, Reason: "already_used"}, nil
}
if time.Now().After(expiresAt) {
return false, nil
return &TokenValidationResult{Valid: false, Reason: "expired"}, nil
}
// Mark token as used
_, err = s.db.Exec("UPDATE registration_tokens SET used = 1 WHERE id = ?", id)
if err != nil {
return false, fmt.Errorf("failed to mark token as used: %w", err)
return nil, fmt.Errorf("failed to mark token as used: %w", err)
}
return true, nil
return &TokenValidationResult{Valid: true, Reason: "valid"}, nil
}
// ListRegistrationTokens lists all registration tokens
@@ -130,19 +165,19 @@ func (s *Secrets) ListRegistrationTokens() ([]map[string]interface{}, error) {
return nil, fmt.Errorf("failed to query tokens: %w", err)
}
defer rows.Close()
var tokens []map[string]interface{}
for rows.Next() {
var id, createdBy sql.NullInt64
var token string
var expiresAt, createdAt time.Time
var used bool
err := rows.Scan(&id, &token, &expiresAt, &used, &createdAt, &createdBy)
if err != nil {
continue
}
tokens = append(tokens, map[string]interface{}{
"id": id.Int64,
"token": token,
@@ -152,7 +187,7 @@ func (s *Secrets) ListRegistrationTokens() ([]map[string]interface{}, error) {
"created_by": createdBy.Int64,
})
}
return tokens, nil
}
@@ -181,28 +216,29 @@ func VerifyRequest(r *http.Request, secret string, maxAge time.Duration) (bool,
if signature == "" {
return false, fmt.Errorf("missing signature")
}
timestampStr := r.Header.Get("X-Runner-Timestamp")
if timestampStr == "" {
return false, fmt.Errorf("missing timestamp")
}
var timestamp time.Time
_, err := fmt.Sscanf(timestampStr, "%d", &timestamp)
var timestampUnix int64
_, err := fmt.Sscanf(timestampStr, "%d", &timestampUnix)
if err != nil {
return false, fmt.Errorf("invalid timestamp: %w", err)
}
timestamp := time.Unix(timestampUnix, 0)
// Check timestamp is not too old
if time.Since(timestamp) > maxAge {
return false, fmt.Errorf("request too old")
}
// Check timestamp is not in the future (allow 1 minute clock skew)
if timestamp.After(time.Now().Add(1 * time.Minute)) {
return false, fmt.Errorf("timestamp in future")
}
// Read body
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
@@ -210,10 +246,13 @@ func VerifyRequest(r *http.Request, secret string, maxAge time.Duration) (bool,
}
// Restore body for handler
r.Body = io.NopCloser(strings.NewReader(string(bodyBytes)))
// Verify signature
expectedSig := SignRequest(r.Method, r.URL.Path, string(bodyBytes), secret, timestamp)
// Verify signature - use path without query parameters (query params are not part of signature)
// The runner signs with the path including query params, but we verify with just the path
// This is intentional - query params are for identification, not part of the signature
path := r.URL.Path
expectedSig := SignRequest(r.Method, path, string(bodyBytes), secret, timestamp)
return hmac.Equal([]byte(signature), []byte(expectedSig)), nil
}
@@ -241,4 +280,3 @@ func generateSecret(length int) (string, error) {
}
return hex.EncodeToString(bytes), nil
}

View File

@@ -3,8 +3,9 @@ package database
import (
"database/sql"
"fmt"
"log"
_ "github.com/mattn/go-sqlite3"
_ "github.com/marcboeker/go-duckdb/v2"
)
// DB wraps the database connection
@@ -14,7 +15,7 @@ type DB struct {
// NewDB creates a new database connection
func NewDB(dbPath string) (*DB, error) {
db, err := sql.Open("sqlite3", dbPath+"?_foreign_keys=1")
db, err := sql.Open("duckdb", dbPath)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
@@ -33,114 +34,133 @@ func NewDB(dbPath string) (*DB, error) {
// migrate runs database migrations
func (db *DB) migrate() error {
// Create sequences for auto-incrementing primary keys
sequences := []string{
`CREATE SEQUENCE IF NOT EXISTS seq_users_id START 1`,
`CREATE SEQUENCE IF NOT EXISTS seq_jobs_id START 1`,
`CREATE SEQUENCE IF NOT EXISTS seq_runners_id START 1`,
`CREATE SEQUENCE IF NOT EXISTS seq_tasks_id START 1`,
`CREATE SEQUENCE IF NOT EXISTS seq_job_files_id START 1`,
`CREATE SEQUENCE IF NOT EXISTS seq_manager_secrets_id START 1`,
`CREATE SEQUENCE IF NOT EXISTS seq_registration_tokens_id START 1`,
`CREATE SEQUENCE IF NOT EXISTS seq_task_logs_id START 1`,
`CREATE SEQUENCE IF NOT EXISTS seq_task_steps_id START 1`,
}
for _, seq := range sequences {
if _, err := db.Exec(seq); err != nil {
return fmt.Errorf("failed to create sequence: %w", err)
}
}
schema := `
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
id BIGINT PRIMARY KEY DEFAULT nextval('seq_users_id'),
email TEXT UNIQUE NOT NULL,
name TEXT NOT NULL,
oauth_provider TEXT NOT NULL,
oauth_id TEXT NOT NULL,
is_admin BOOLEAN NOT NULL DEFAULT 0,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
password_hash TEXT,
is_admin BOOLEAN NOT NULL DEFAULT false,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
UNIQUE(oauth_provider, oauth_id)
);
CREATE TABLE IF NOT EXISTS jobs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
id BIGINT PRIMARY KEY DEFAULT nextval('seq_jobs_id'),
user_id BIGINT NOT NULL,
job_type TEXT NOT NULL DEFAULT 'render',
name TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
progress REAL NOT NULL DEFAULT 0.0,
frame_start INTEGER NOT NULL,
frame_end INTEGER NOT NULL,
output_format TEXT NOT NULL DEFAULT 'PNG',
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
started_at DATETIME,
completed_at DATETIME,
frame_start INTEGER,
frame_end INTEGER,
output_format TEXT,
allow_parallel_runners BOOLEAN,
timeout_seconds INTEGER DEFAULT 86400,
blend_metadata TEXT,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
started_at TIMESTAMP,
completed_at TIMESTAMP,
error_message TEXT,
FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE
FOREIGN KEY (user_id) REFERENCES users(id)
);
CREATE TABLE IF NOT EXISTS runners (
id INTEGER PRIMARY KEY AUTOINCREMENT,
id BIGINT PRIMARY KEY DEFAULT nextval('seq_runners_id'),
name TEXT NOT NULL,
hostname TEXT NOT NULL,
ip_address TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'offline',
last_heartbeat DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
last_heartbeat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
capabilities TEXT,
registration_token TEXT,
runner_secret TEXT,
manager_secret TEXT,
verified BOOLEAN NOT NULL DEFAULT 0,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
verified BOOLEAN NOT NULL DEFAULT false,
priority INTEGER NOT NULL DEFAULT 100,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS tasks (
id INTEGER PRIMARY KEY AUTOINCREMENT,
job_id INTEGER NOT NULL,
runner_id INTEGER,
id BIGINT PRIMARY KEY DEFAULT nextval('seq_tasks_id'),
job_id BIGINT NOT NULL,
runner_id BIGINT,
frame_start INTEGER NOT NULL,
frame_end INTEGER NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
output_path TEXT,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
started_at DATETIME,
completed_at DATETIME,
error_message TEXT,
FOREIGN KEY (job_id) REFERENCES jobs(id) ON DELETE CASCADE,
FOREIGN KEY (runner_id) REFERENCES runners(id) ON DELETE SET NULL
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
started_at TIMESTAMP,
completed_at TIMESTAMP,
error_message TEXT
);
CREATE TABLE IF NOT EXISTS job_files (
id INTEGER PRIMARY KEY AUTOINCREMENT,
job_id INTEGER NOT NULL,
id BIGINT PRIMARY KEY DEFAULT nextval('seq_job_files_id'),
job_id BIGINT NOT NULL,
file_type TEXT NOT NULL,
file_path TEXT NOT NULL,
file_name TEXT NOT NULL,
file_size INTEGER NOT NULL,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (job_id) REFERENCES jobs(id) ON DELETE CASCADE
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS manager_secrets (
id INTEGER PRIMARY KEY AUTOINCREMENT,
id BIGINT PRIMARY KEY DEFAULT nextval('seq_manager_secrets_id'),
secret TEXT UNIQUE NOT NULL,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS registration_tokens (
id INTEGER PRIMARY KEY AUTOINCREMENT,
id BIGINT PRIMARY KEY DEFAULT nextval('seq_registration_tokens_id'),
token TEXT UNIQUE NOT NULL,
expires_at DATETIME NOT NULL,
used BOOLEAN NOT NULL DEFAULT 0,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
created_by INTEGER,
FOREIGN KEY (created_by) REFERENCES users(id) ON DELETE SET NULL
expires_at TIMESTAMP NOT NULL,
used BOOLEAN NOT NULL DEFAULT false,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
created_by BIGINT,
FOREIGN KEY (created_by) REFERENCES users(id)
);
CREATE TABLE IF NOT EXISTS task_logs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
task_id INTEGER NOT NULL,
runner_id INTEGER,
id BIGINT PRIMARY KEY DEFAULT nextval('seq_task_logs_id'),
task_id BIGINT NOT NULL,
runner_id BIGINT,
log_level TEXT NOT NULL,
message TEXT NOT NULL,
step_name TEXT,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (task_id) REFERENCES tasks(id) ON DELETE CASCADE,
FOREIGN KEY (runner_id) REFERENCES runners(id) ON DELETE SET NULL
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS task_steps (
id INTEGER PRIMARY KEY AUTOINCREMENT,
task_id INTEGER NOT NULL,
id BIGINT PRIMARY KEY DEFAULT nextval('seq_task_steps_id'),
task_id BIGINT NOT NULL,
step_name TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
started_at DATETIME,
completed_at DATETIME,
started_at TIMESTAMP,
completed_at TIMESTAMP,
duration_ms INTEGER,
error_message TEXT,
FOREIGN KEY (task_id) REFERENCES tasks(id) ON DELETE CASCADE
error_message TEXT
);
CREATE INDEX IF NOT EXISTS idx_jobs_user_id ON jobs(user_id);
@@ -156,6 +176,12 @@ func (db *DB) migrate() error {
CREATE INDEX IF NOT EXISTS idx_task_logs_runner_id ON task_logs(runner_id);
CREATE INDEX IF NOT EXISTS idx_task_steps_task_id ON task_steps(task_id);
CREATE INDEX IF NOT EXISTS idx_runners_last_heartbeat ON runners(last_heartbeat);
CREATE TABLE IF NOT EXISTS settings (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
`
if _, err := db.Exec(schema); err != nil {
@@ -165,30 +191,57 @@ func (db *DB) migrate() error {
// Migrate existing tables to add new columns
migrations := []string{
// Add is_admin to users if it doesn't exist
`ALTER TABLE users ADD COLUMN is_admin BOOLEAN NOT NULL DEFAULT 0`,
`ALTER TABLE users ADD COLUMN IF NOT EXISTS is_admin BOOLEAN NOT NULL DEFAULT false`,
// Add new columns to runners if they don't exist
`ALTER TABLE runners ADD COLUMN registration_token TEXT`,
`ALTER TABLE runners ADD COLUMN runner_secret TEXT`,
`ALTER TABLE runners ADD COLUMN manager_secret TEXT`,
`ALTER TABLE runners ADD COLUMN verified BOOLEAN NOT NULL DEFAULT 0`,
`ALTER TABLE runners ADD COLUMN IF NOT EXISTS registration_token TEXT`,
`ALTER TABLE runners ADD COLUMN IF NOT EXISTS runner_secret TEXT`,
`ALTER TABLE runners ADD COLUMN IF NOT EXISTS manager_secret TEXT`,
`ALTER TABLE runners ADD COLUMN IF NOT EXISTS verified BOOLEAN NOT NULL DEFAULT false`,
`ALTER TABLE runners ADD COLUMN IF NOT EXISTS priority INTEGER NOT NULL DEFAULT 100`,
// Add allow_parallel_runners to jobs if it doesn't exist
`ALTER TABLE jobs ADD COLUMN allow_parallel_runners BOOLEAN NOT NULL DEFAULT 1`,
`ALTER TABLE jobs ADD COLUMN IF NOT EXISTS allow_parallel_runners BOOLEAN NOT NULL DEFAULT true`,
// Add timeout_seconds to jobs if it doesn't exist
`ALTER TABLE jobs ADD COLUMN timeout_seconds INTEGER DEFAULT 86400`,
`ALTER TABLE jobs ADD COLUMN IF NOT EXISTS timeout_seconds INTEGER DEFAULT 86400`,
// Add blend_metadata to jobs if it doesn't exist
`ALTER TABLE jobs ADD COLUMN blend_metadata TEXT`,
`ALTER TABLE jobs ADD COLUMN IF NOT EXISTS blend_metadata TEXT`,
// Add job_type to jobs if it doesn't exist
`ALTER TABLE jobs ADD COLUMN IF NOT EXISTS job_type TEXT DEFAULT 'render'`,
// Add task_type to tasks if it doesn't exist
`ALTER TABLE tasks ADD COLUMN task_type TEXT DEFAULT 'render'`,
`ALTER TABLE tasks ADD COLUMN IF NOT EXISTS task_type TEXT DEFAULT 'render'`,
// Add new columns to tasks if they don't exist
`ALTER TABLE tasks ADD COLUMN current_step TEXT`,
`ALTER TABLE tasks ADD COLUMN retry_count INTEGER DEFAULT 0`,
`ALTER TABLE tasks ADD COLUMN max_retries INTEGER DEFAULT 3`,
`ALTER TABLE tasks ADD COLUMN timeout_seconds INTEGER`,
`ALTER TABLE tasks ADD COLUMN IF NOT EXISTS current_step TEXT`,
`ALTER TABLE tasks ADD COLUMN IF NOT EXISTS retry_count INTEGER DEFAULT 0`,
`ALTER TABLE tasks ADD COLUMN IF NOT EXISTS max_retries INTEGER DEFAULT 3`,
`ALTER TABLE tasks ADD COLUMN IF NOT EXISTS timeout_seconds INTEGER`,
}
for _, migration := range migrations {
// SQLite doesn't support IF NOT EXISTS for ALTER TABLE, so we ignore errors
db.Exec(migration)
// DuckDB supports IF NOT EXISTS for ALTER TABLE, so we can safely execute
if _, err := db.Exec(migration); err != nil {
// Log but don't fail - column might already exist or table might not exist yet
// This is fine for migrations that run after schema creation
}
}
// Initialize registration_enabled setting (default: true) if it doesn't exist
var settingCount int
err := db.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount)
if err == nil && settingCount == 0 {
_, err = db.Exec("INSERT INTO settings (key, value) VALUES (?, ?)", "registration_enabled", "true")
if err != nil {
// Log but don't fail - setting might have been created by another process
log.Printf("Note: Could not initialize registration_enabled setting: %v", err)
}
}
return nil
for _, migration := range migrations {
// DuckDB supports IF NOT EXISTS for ALTER TABLE, so we can safely execute
if _, err := db.Exec(migration); err != nil {
// Log but don't fail - column might already exist or table might not exist yet
// This is fine for migrations that run after schema creation
}
}
return nil
@@ -198,4 +251,3 @@ func (db *DB) migrate() error {
func (db *DB) Close() error {
return db.DB.Close()
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,10 +1,12 @@
package storage
import (
"archive/zip"
"fmt"
"io"
"os"
"path/filepath"
"strings"
)
// Storage handles file storage operations
@@ -135,3 +137,60 @@ func (s *Storage) GetFileSize(filePath string) (int64, error) {
return info.Size(), nil
}
// ExtractZip extracts a ZIP file to the destination directory
// Returns a list of all extracted file paths
func (s *Storage) ExtractZip(zipPath, destDir string) ([]string, error) {
r, err := zip.OpenReader(zipPath)
if err != nil {
return nil, fmt.Errorf("failed to open ZIP file: %w", err)
}
defer r.Close()
var extractedFiles []string
for _, f := range r.File {
// Sanitize file path to prevent directory traversal
destPath := filepath.Join(destDir, f.Name)
if !strings.HasPrefix(destPath, filepath.Clean(destDir)+string(os.PathSeparator)) {
return nil, fmt.Errorf("invalid file path in ZIP: %s", f.Name)
}
// Create directory structure
if f.FileInfo().IsDir() {
if err := os.MkdirAll(destPath, 0755); err != nil {
return nil, fmt.Errorf("failed to create directory: %w", err)
}
continue
}
// Create parent directories
if err := os.MkdirAll(filepath.Dir(destPath), 0755); err != nil {
return nil, fmt.Errorf("failed to create parent directory: %w", err)
}
// Extract file
rc, err := f.Open()
if err != nil {
return nil, fmt.Errorf("failed to open file in ZIP: %w", err)
}
outFile, err := os.Create(destPath)
if err != nil {
rc.Close()
return nil, fmt.Errorf("failed to create file: %w", err)
}
_, err = io.Copy(outFile, rc)
outFile.Close()
rc.Close()
if err != nil {
return nil, fmt.Errorf("failed to extract file: %w", err)
}
extractedFiles = append(extractedFiles, destPath)
}
return extractedFiles, nil
}