massive changes and it works
This commit is contained in:
@@ -7,7 +7,7 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"fuego/pkg/types"
|
||||
"jiggablend/pkg/types"
|
||||
)
|
||||
|
||||
// handleGenerateRegistrationToken generates a new registration token
|
||||
@@ -126,7 +126,7 @@ func (s *Server) handleDeleteRunner(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, name, hostname, ip_address, status, last_heartbeat, capabilities,
|
||||
registration_token, verified, created_at
|
||||
registration_token, verified, priority, created_at
|
||||
FROM runners ORDER BY created_at DESC`,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -144,7 +144,7 @@ func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request)
|
||||
err := rows.Scan(
|
||||
&runner.ID, &runner.Name, &runner.Hostname, &runner.IPAddress,
|
||||
&runner.Status, &runner.LastHeartbeat, &runner.Capabilities,
|
||||
®istrationToken, &verified, &runner.CreatedAt,
|
||||
®istrationToken, &verified, &runner.Priority, &runner.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan runner: %v", err))
|
||||
@@ -161,9 +161,206 @@ func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request)
|
||||
"capabilities": runner.Capabilities,
|
||||
"registration_token": registrationToken.String,
|
||||
"verified": verified,
|
||||
"priority": runner.Priority,
|
||||
"created_at": runner.CreatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, runners)
|
||||
}
|
||||
|
||||
// handleListUsers lists all users
|
||||
func (s *Server) handleListUsers(w http.ResponseWriter, r *http.Request) {
|
||||
// Get first user ID to mark it in the response
|
||||
firstUserID, err := s.auth.GetFirstUserID()
|
||||
if err != nil {
|
||||
// If no users exist, firstUserID will be 0, which is fine
|
||||
firstUserID = 0
|
||||
}
|
||||
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, email, name, oauth_provider, is_admin, created_at
|
||||
FROM users ORDER BY created_at DESC`,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query users: %v", err))
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
users := []map[string]interface{}{}
|
||||
for rows.Next() {
|
||||
var userID int64
|
||||
var email, name, oauthProvider string
|
||||
var isAdmin bool
|
||||
var createdAt time.Time
|
||||
|
||||
err := rows.Scan(&userID, &email, &name, &oauthProvider, &isAdmin, &createdAt)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan user: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Get job count for this user
|
||||
var jobCount int
|
||||
err = s.db.QueryRow("SELECT COUNT(*) FROM jobs WHERE user_id = ?", userID).Scan(&jobCount)
|
||||
if err != nil {
|
||||
jobCount = 0 // Default to 0 if query fails
|
||||
}
|
||||
|
||||
users = append(users, map[string]interface{}{
|
||||
"id": userID,
|
||||
"email": email,
|
||||
"name": name,
|
||||
"oauth_provider": oauthProvider,
|
||||
"is_admin": isAdmin,
|
||||
"created_at": createdAt,
|
||||
"job_count": jobCount,
|
||||
"is_first_user": userID == firstUserID,
|
||||
})
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, users)
|
||||
}
|
||||
|
||||
// handleGetUserJobs gets all jobs for a specific user
|
||||
func (s *Server) handleGetUserJobs(w http.ResponseWriter, r *http.Request) {
|
||||
userID, err := parseID(r, "id")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Verify user exists
|
||||
var exists bool
|
||||
err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", userID).Scan(&exists)
|
||||
if err != nil || !exists {
|
||||
s.respondError(w, http.StatusNotFound, "User not found")
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format,
|
||||
allow_parallel_runners, timeout_seconds, blend_metadata, created_at, started_at, completed_at, error_message
|
||||
FROM jobs WHERE user_id = ? ORDER BY created_at DESC`,
|
||||
userID,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query jobs: %v", err))
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
jobs := []types.Job{}
|
||||
for rows.Next() {
|
||||
var job types.Job
|
||||
var jobType string
|
||||
var startedAt, completedAt sql.NullTime
|
||||
var blendMetadataJSON sql.NullString
|
||||
var errorMessage sql.NullString
|
||||
var frameStart, frameEnd sql.NullInt64
|
||||
var outputFormat sql.NullString
|
||||
var allowParallelRunners sql.NullBool
|
||||
|
||||
err := rows.Scan(
|
||||
&job.ID, &job.UserID, &jobType, &job.Name, &job.Status, &job.Progress,
|
||||
&frameStart, &frameEnd, &outputFormat, &allowParallelRunners, &job.TimeoutSeconds,
|
||||
&blendMetadataJSON, &job.CreatedAt, &startedAt, &completedAt, &errorMessage,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan job: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
job.JobType = types.JobType(jobType)
|
||||
if frameStart.Valid {
|
||||
fs := int(frameStart.Int64)
|
||||
job.FrameStart = &fs
|
||||
}
|
||||
if frameEnd.Valid {
|
||||
fe := int(frameEnd.Int64)
|
||||
job.FrameEnd = &fe
|
||||
}
|
||||
if outputFormat.Valid {
|
||||
job.OutputFormat = &outputFormat.String
|
||||
}
|
||||
if allowParallelRunners.Valid {
|
||||
job.AllowParallelRunners = &allowParallelRunners.Bool
|
||||
}
|
||||
if startedAt.Valid {
|
||||
job.StartedAt = &startedAt.Time
|
||||
}
|
||||
if completedAt.Valid {
|
||||
job.CompletedAt = &completedAt.Time
|
||||
}
|
||||
if blendMetadataJSON.Valid && blendMetadataJSON.String != "" {
|
||||
var metadata types.BlendMetadata
|
||||
if err := json.Unmarshal([]byte(blendMetadataJSON.String), &metadata); err == nil {
|
||||
job.BlendMetadata = &metadata
|
||||
}
|
||||
}
|
||||
if errorMessage.Valid {
|
||||
job.ErrorMessage = errorMessage.String
|
||||
}
|
||||
|
||||
jobs = append(jobs, job)
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, jobs)
|
||||
}
|
||||
|
||||
// handleGetRegistrationEnabled gets the registration enabled setting
|
||||
func (s *Server) handleGetRegistrationEnabled(w http.ResponseWriter, r *http.Request) {
|
||||
enabled, err := s.auth.IsRegistrationEnabled()
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to get registration setting: %v", err))
|
||||
return
|
||||
}
|
||||
s.respondJSON(w, http.StatusOK, map[string]bool{"enabled": enabled})
|
||||
}
|
||||
|
||||
// handleSetRegistrationEnabled sets the registration enabled setting
|
||||
func (s *Server) handleSetRegistrationEnabled(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.auth.SetRegistrationEnabled(req.Enabled); err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to set registration setting: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, map[string]bool{"enabled": req.Enabled})
|
||||
}
|
||||
|
||||
// handleSetUserAdminStatus sets a user's admin status (admin only)
|
||||
func (s *Server) handleSetUserAdminStatus(w http.ResponseWriter, r *http.Request) {
|
||||
targetUserID, err := parseID(r, "id")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.auth.SetUserAdminStatus(targetUserID, req.IsAdmin); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"user_id": targetUserID,
|
||||
"is_admin": req.IsAdmin,
|
||||
"message": "Admin status updated successfully",
|
||||
})
|
||||
}
|
||||
|
||||
1246
internal/api/jobs.go
1246
internal/api/jobs.go
File diff suppressed because it is too large
Load Diff
@@ -7,7 +7,7 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"fuego/pkg/types"
|
||||
"jiggablend/pkg/types"
|
||||
)
|
||||
|
||||
// handleSubmitMetadata handles metadata submission from runner
|
||||
@@ -19,7 +19,7 @@ func (s *Server) handleSubmitMetadata(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Get runner ID from context (set by runnerAuthMiddleware)
|
||||
runnerID, ok := r.Context().Value("runner_id").(int64)
|
||||
runnerID, ok := r.Context().Value(runnerIDContextKey).(int64)
|
||||
if !ok {
|
||||
s.respondError(w, http.StatusUnauthorized, "runner_id not found in context")
|
||||
return
|
||||
@@ -44,16 +44,32 @@ func (s *Server) handleSubmitMetadata(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Find the metadata extraction task for this job
|
||||
// First try to find task assigned to this runner, then fall back to any metadata task for this job
|
||||
var taskID int64
|
||||
err = s.db.QueryRow(
|
||||
`SELECT id FROM tasks WHERE job_id = ? AND task_type = ? AND runner_id = ?`,
|
||||
jobID, types.TaskTypeMetadata, runnerID,
|
||||
).Scan(&taskID)
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "Metadata extraction task not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
// Fall back to any metadata task for this job (in case assignment changed)
|
||||
err = s.db.QueryRow(
|
||||
`SELECT id FROM tasks WHERE job_id = ? AND task_type = ? ORDER BY created_at DESC LIMIT 1`,
|
||||
jobID, types.TaskTypeMetadata,
|
||||
).Scan(&taskID)
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "Metadata extraction task not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to find task: %v", err))
|
||||
return
|
||||
}
|
||||
// Update the task to be assigned to this runner if it wasn't already
|
||||
s.db.Exec(
|
||||
`UPDATE tasks SET runner_id = ? WHERE id = ? AND runner_id IS NULL`,
|
||||
runnerID, taskID,
|
||||
)
|
||||
} else if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to find task: %v", err))
|
||||
return
|
||||
}
|
||||
@@ -82,6 +98,9 @@ func (s *Server) handleSubmitMetadata(w http.ResponseWriter, r *http.Request) {
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Failed to mark metadata task as completed: %v", err)
|
||||
} else {
|
||||
// Update job status and progress after metadata task completes
|
||||
s.updateJobStatusFromTasks(jobID)
|
||||
}
|
||||
|
||||
log.Printf("Metadata extracted for job %d: frame_start=%d, frame_end=%d", jobID, metadata.FrameStart, metadata.FrameEnd)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -10,10 +10,10 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
authpkg "fuego/internal/auth"
|
||||
"fuego/internal/database"
|
||||
"fuego/internal/storage"
|
||||
"fuego/pkg/types"
|
||||
authpkg "jiggablend/internal/auth"
|
||||
"jiggablend/internal/database"
|
||||
"jiggablend/internal/storage"
|
||||
"jiggablend/pkg/types"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
@@ -35,6 +35,12 @@ type Server struct {
|
||||
runnerConnsMu sync.RWMutex
|
||||
frontendConns map[string]*websocket.Conn // key: "jobId:taskId"
|
||||
frontendConnsMu sync.RWMutex
|
||||
// Mutexes for each frontend connection to serialize writes
|
||||
frontendConnsWriteMu map[string]*sync.Mutex // key: "jobId:taskId"
|
||||
frontendConnsWriteMuMu sync.RWMutex
|
||||
// Throttling for progress updates (per job)
|
||||
progressUpdateTimes map[int64]time.Time // key: jobID
|
||||
progressUpdateTimesMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewServer creates a new API server
|
||||
@@ -57,8 +63,10 @@ func NewServer(db *database.DB, auth *authpkg.Auth, storage *storage.Storage) (*
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
},
|
||||
runnerConns: make(map[int64]*websocket.Conn),
|
||||
frontendConns: make(map[string]*websocket.Conn),
|
||||
runnerConns: make(map[int64]*websocket.Conn),
|
||||
frontendConns: make(map[string]*websocket.Conn),
|
||||
frontendConnsWriteMu: make(map[string]*sync.Mutex),
|
||||
progressUpdateTimes: make(map[int64]time.Time),
|
||||
}
|
||||
|
||||
s.setupMiddleware()
|
||||
@@ -72,13 +80,14 @@ func NewServer(db *database.DB, auth *authpkg.Auth, storage *storage.Storage) (*
|
||||
func (s *Server) setupMiddleware() {
|
||||
s.router.Use(middleware.Logger)
|
||||
s.router.Use(middleware.Recoverer)
|
||||
s.router.Use(middleware.Timeout(60 * time.Second))
|
||||
// Note: Timeout middleware is NOT applied globally to avoid conflicts with WebSocket connections
|
||||
// WebSocket connections are long-lived and should not have HTTP timeouts
|
||||
|
||||
s.router.Use(cors.Handler(cors.Options{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
||||
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type"},
|
||||
ExposedHeaders: []string{"Link"},
|
||||
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "Range"},
|
||||
ExposedHeaders: []string{"Link", "Content-Range", "Accept-Ranges", "Content-Length"},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 300,
|
||||
}))
|
||||
@@ -88,12 +97,17 @@ func (s *Server) setupMiddleware() {
|
||||
func (s *Server) setupRoutes() {
|
||||
// Public routes
|
||||
s.router.Route("/api/auth", func(r chi.Router) {
|
||||
r.Get("/providers", s.handleGetAuthProviders)
|
||||
r.Get("/google/login", s.handleGoogleLogin)
|
||||
r.Get("/google/callback", s.handleGoogleCallback)
|
||||
r.Get("/discord/login", s.handleDiscordLogin)
|
||||
r.Get("/discord/callback", s.handleDiscordCallback)
|
||||
r.Get("/local/available", s.handleLocalLoginAvailable)
|
||||
r.Post("/local/register", s.handleLocalRegister)
|
||||
r.Post("/local/login", s.handleLocalLogin)
|
||||
r.Post("/logout", s.handleLogout)
|
||||
r.Get("/me", s.handleGetMe)
|
||||
r.Post("/change-password", s.handleChangePassword)
|
||||
})
|
||||
|
||||
// Protected routes
|
||||
@@ -105,6 +119,7 @@ func (s *Server) setupRoutes() {
|
||||
r.Get("/", s.handleListJobs)
|
||||
r.Get("/{id}", s.handleGetJob)
|
||||
r.Delete("/{id}", s.handleCancelJob)
|
||||
r.Post("/{id}/delete", s.handleDeleteJob)
|
||||
r.Post("/{id}/upload", s.handleUploadJobFile)
|
||||
r.Get("/{id}/files", s.handleListJobFiles)
|
||||
r.Get("/{id}/files/{fileId}/download", s.handleDownloadJobFile)
|
||||
@@ -112,18 +127,17 @@ func (s *Server) setupRoutes() {
|
||||
r.Get("/{id}/metadata", s.handleGetJobMetadata)
|
||||
r.Get("/{id}/tasks", s.handleListJobTasks)
|
||||
r.Get("/{id}/tasks/{taskId}/logs", s.handleGetTaskLogs)
|
||||
r.Get("/{id}/tasks/{taskId}/logs/ws", s.handleStreamTaskLogsWebSocket)
|
||||
// WebSocket route - no timeout middleware (long-lived connection)
|
||||
r.With(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Remove timeout middleware for WebSocket
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}).Get("/{id}/tasks/{taskId}/logs/ws", s.handleStreamTaskLogsWebSocket)
|
||||
r.Get("/{id}/tasks/{taskId}/steps", s.handleGetTaskSteps)
|
||||
r.Post("/{id}/tasks/{taskId}/retry", s.handleRetryTask)
|
||||
})
|
||||
|
||||
s.router.Route("/api/runners", func(r chi.Router) {
|
||||
r.Use(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(s.auth.Middleware(next.ServeHTTP))
|
||||
})
|
||||
r.Get("/", s.handleListRunners)
|
||||
})
|
||||
|
||||
// Admin routes
|
||||
s.router.Route("/api/admin", func(r chi.Router) {
|
||||
r.Use(func(next http.Handler) http.Handler {
|
||||
@@ -139,14 +153,23 @@ func (s *Server) setupRoutes() {
|
||||
r.Post("/{id}/verify", s.handleVerifyRunner)
|
||||
r.Delete("/{id}", s.handleDeleteRunner)
|
||||
})
|
||||
r.Route("/users", func(r chi.Router) {
|
||||
r.Get("/", s.handleListUsers)
|
||||
r.Get("/{id}/jobs", s.handleGetUserJobs)
|
||||
r.Post("/{id}/admin", s.handleSetUserAdminStatus)
|
||||
})
|
||||
r.Route("/settings", func(r chi.Router) {
|
||||
r.Get("/registration", s.handleGetRegistrationEnabled)
|
||||
r.Post("/registration", s.handleSetRegistrationEnabled)
|
||||
})
|
||||
})
|
||||
|
||||
// Runner API
|
||||
s.router.Route("/api/runner", func(r chi.Router) {
|
||||
// Registration doesn't require auth (uses token)
|
||||
r.Post("/register", s.handleRegisterRunner)
|
||||
r.With(middleware.Timeout(60*time.Second)).Post("/register", s.handleRegisterRunner)
|
||||
|
||||
// WebSocket endpoint (auth handled in handler)
|
||||
// WebSocket endpoint (auth handled in handler) - no timeout middleware
|
||||
r.Get("/ws", s.handleRunnerWebSocket)
|
||||
|
||||
// File operations still use HTTP (WebSocket not suitable for large files)
|
||||
@@ -156,7 +179,7 @@ func (s *Server) setupRoutes() {
|
||||
})
|
||||
r.Post("/tasks/{id}/progress", s.handleUpdateTaskProgress)
|
||||
r.Post("/tasks/{id}/steps", s.handleUpdateTaskStep)
|
||||
r.Get("/files/{jobId}/{fileName}", s.handleDownloadFileForRunner)
|
||||
r.Get("/files/{jobId}/*", s.handleDownloadFileForRunner)
|
||||
r.Post("/files/{jobId}/upload", s.handleUploadFileFromRunner)
|
||||
r.Get("/jobs/{jobId}/status", s.handleGetJobStatusForRunner)
|
||||
r.Get("/jobs/{jobId}/files", s.handleGetJobFilesForRunner)
|
||||
@@ -205,6 +228,11 @@ func (s *Server) handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
session, err := s.auth.GoogleCallback(r.Context(), code)
|
||||
if err != nil {
|
||||
// If registration is disabled, redirect back to login with error
|
||||
if err.Error() == "registration is disabled" {
|
||||
http.Redirect(w, r, "/?error=registration_disabled", http.StatusFound)
|
||||
return
|
||||
}
|
||||
s.respondError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
@@ -240,6 +268,11 @@ func (s *Server) handleDiscordCallback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
session, err := s.auth.DiscordCallback(r.Context(), code)
|
||||
if err != nil {
|
||||
// If registration is disabled, redirect back to login with error
|
||||
if err.Error() == "registration is disabled" {
|
||||
http.Redirect(w, r, "/?error=registration_disabled", http.StatusFound)
|
||||
return
|
||||
}
|
||||
s.respondError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
@@ -293,6 +326,166 @@ func (s *Server) handleGetMe(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleGetAuthProviders(w http.ResponseWriter, r *http.Request) {
|
||||
s.respondJSON(w, http.StatusOK, map[string]bool{
|
||||
"google": s.auth.IsGoogleOAuthConfigured(),
|
||||
"discord": s.auth.IsDiscordOAuthConfigured(),
|
||||
"local": s.auth.IsLocalLoginEnabled(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleLocalLoginAvailable(w http.ResponseWriter, r *http.Request) {
|
||||
s.respondJSON(w, http.StatusOK, map[string]bool{
|
||||
"available": s.auth.IsLocalLoginEnabled(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleLocalRegister(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Email == "" || req.Name == "" || req.Password == "" {
|
||||
s.respondError(w, http.StatusBadRequest, "Email, name, and password are required")
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Password) < 8 {
|
||||
s.respondError(w, http.StatusBadRequest, "Password must be at least 8 characters long")
|
||||
return
|
||||
}
|
||||
|
||||
session, err := s.auth.RegisterLocalUser(req.Email, req.Name, req.Password)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sessionID := s.auth.CreateSession(session)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_id",
|
||||
Value: sessionID,
|
||||
Path: "/",
|
||||
MaxAge: 86400,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
|
||||
"message": "Registration successful",
|
||||
"user": map[string]interface{}{
|
||||
"id": session.UserID,
|
||||
"email": session.Email,
|
||||
"name": session.Name,
|
||||
"is_admin": session.IsAdmin,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleLocalLogin(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Username == "" || req.Password == "" {
|
||||
s.respondError(w, http.StatusBadRequest, "Username and password are required")
|
||||
return
|
||||
}
|
||||
|
||||
session, err := s.auth.LocalLogin(req.Username, req.Password)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusUnauthorized, "Invalid credentials")
|
||||
return
|
||||
}
|
||||
|
||||
sessionID := s.auth.CreateSession(session)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_id",
|
||||
Value: sessionID,
|
||||
Path: "/",
|
||||
MaxAge: 86400,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
s.respondJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"message": "Login successful",
|
||||
"user": map[string]interface{}{
|
||||
"id": session.UserID,
|
||||
"email": session.Email,
|
||||
"name": session.Name,
|
||||
"is_admin": session.IsAdmin,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleChangePassword(w http.ResponseWriter, r *http.Request) {
|
||||
userID, err := getUserID(r)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
OldPassword string `json:"old_password"`
|
||||
NewPassword string `json:"new_password"`
|
||||
TargetUserID *int64 `json:"target_user_id,omitempty"` // For admin to change other users' passwords
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.NewPassword == "" {
|
||||
s.respondError(w, http.StatusBadRequest, "New password is required")
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.NewPassword) < 8 {
|
||||
s.respondError(w, http.StatusBadRequest, "Password must be at least 8 characters long")
|
||||
return
|
||||
}
|
||||
|
||||
isAdmin := authpkg.IsAdmin(r.Context())
|
||||
|
||||
// If target_user_id is provided and user is admin, allow changing other user's password
|
||||
if req.TargetUserID != nil && isAdmin {
|
||||
if err := s.auth.AdminChangePassword(*req.TargetUserID, req.NewPassword); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Password changed successfully"})
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, user is changing their own password (requires old password)
|
||||
if req.OldPassword == "" {
|
||||
s.respondError(w, http.StatusBadRequest, "Old password is required")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.auth.ChangePassword(userID, req.OldPassword, req.NewPassword); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Password changed successfully"})
|
||||
}
|
||||
|
||||
// Helper to get user ID from context
|
||||
func getUserID(r *http.Request) (int64, error) {
|
||||
userID, ok := authpkg.GetUserID(r.Context())
|
||||
@@ -315,6 +508,7 @@ func parseID(r *http.Request, param string) (int64, error) {
|
||||
// StartBackgroundTasks starts background goroutines for error recovery
|
||||
func (s *Server) StartBackgroundTasks() {
|
||||
go s.recoverStuckTasks()
|
||||
go s.cleanupOldMetadataJobs()
|
||||
}
|
||||
|
||||
// recoverStuckTasks periodically checks for dead runners and stuck tasks
|
||||
@@ -322,8 +516,8 @@ func (s *Server) recoverStuckTasks() {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Also distribute tasks every 5 seconds
|
||||
distributeTicker := time.NewTicker(5 * time.Second)
|
||||
// Also distribute tasks every 10 seconds (reduced frequency since we have event-driven distribution)
|
||||
distributeTicker := time.NewTicker(10 * time.Second)
|
||||
defer distributeTicker.Stop()
|
||||
|
||||
go func() {
|
||||
@@ -341,9 +535,10 @@ func (s *Server) recoverStuckTasks() {
|
||||
}()
|
||||
|
||||
// Find dead runners (no heartbeat for 90 seconds)
|
||||
// But only mark as dead if they're not actually connected via WebSocket
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id FROM runners
|
||||
WHERE last_heartbeat < datetime('now', '-90 seconds')
|
||||
WHERE last_heartbeat < CURRENT_TIMESTAMP - INTERVAL '90 seconds'
|
||||
AND status = ?`,
|
||||
types.RunnerStatusOnline,
|
||||
)
|
||||
@@ -354,12 +549,20 @@ func (s *Server) recoverStuckTasks() {
|
||||
defer rows.Close()
|
||||
|
||||
var deadRunnerIDs []int64
|
||||
s.runnerConnsMu.RLock()
|
||||
for rows.Next() {
|
||||
var runnerID int64
|
||||
if err := rows.Scan(&runnerID); err == nil {
|
||||
deadRunnerIDs = append(deadRunnerIDs, runnerID)
|
||||
// Only mark as dead if not actually connected via WebSocket
|
||||
// The WebSocket connection is the source of truth
|
||||
if _, stillConnected := s.runnerConns[runnerID]; !stillConnected {
|
||||
deadRunnerIDs = append(deadRunnerIDs, runnerID)
|
||||
}
|
||||
// If still connected, heartbeat should be updated by pong handler or heartbeat message
|
||||
// No need to manually update here - if it's stale, the pong handler isn't working
|
||||
}
|
||||
}
|
||||
s.runnerConnsMu.RUnlock()
|
||||
rows.Close()
|
||||
|
||||
if len(deadRunnerIDs) == 0 {
|
||||
@@ -370,67 +573,7 @@ func (s *Server) recoverStuckTasks() {
|
||||
|
||||
// Reset tasks assigned to dead runners
|
||||
for _, runnerID := range deadRunnerIDs {
|
||||
// Get tasks assigned to this runner
|
||||
taskRows, err := s.db.Query(
|
||||
`SELECT id, retry_count, max_retries FROM tasks
|
||||
WHERE runner_id = ? AND status = ?`,
|
||||
runnerID, types.TaskStatusRunning,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Failed to query tasks for runner %d: %v", runnerID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
var tasksToReset []struct {
|
||||
ID int64
|
||||
RetryCount int
|
||||
MaxRetries int
|
||||
}
|
||||
|
||||
for taskRows.Next() {
|
||||
var t struct {
|
||||
ID int64
|
||||
RetryCount int
|
||||
MaxRetries int
|
||||
}
|
||||
if err := taskRows.Scan(&t.ID, &t.RetryCount, &t.MaxRetries); err == nil {
|
||||
tasksToReset = append(tasksToReset, t)
|
||||
}
|
||||
}
|
||||
taskRows.Close()
|
||||
|
||||
// Reset or fail tasks
|
||||
for _, task := range tasksToReset {
|
||||
if task.RetryCount >= task.MaxRetries {
|
||||
// Mark as failed
|
||||
_, err = s.db.Exec(
|
||||
`UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL
|
||||
WHERE id = ?`,
|
||||
types.TaskStatusFailed, "Runner died, max retries exceeded", task.ID,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Failed to mark task %d as failed: %v", task.ID, err)
|
||||
}
|
||||
} else {
|
||||
// Reset to pending
|
||||
_, err = s.db.Exec(
|
||||
`UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL,
|
||||
retry_count = retry_count + 1 WHERE id = ?`,
|
||||
types.TaskStatusPending, task.ID,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Failed to reset task %d: %v", task.ID, err)
|
||||
} else {
|
||||
// Add log entry
|
||||
_, _ = s.db.Exec(
|
||||
`INSERT INTO task_logs (task_id, log_level, message, step_name, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
task.ID, types.LogLevelWarn, fmt.Sprintf("Runner died, task reset (retry %d/%d)", task.RetryCount+1, task.MaxRetries),
|
||||
"", time.Now(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
s.redistributeRunnerTasks(runnerID)
|
||||
|
||||
// Mark runner as offline
|
||||
_, _ = s.db.Exec(
|
||||
@@ -457,7 +600,7 @@ func (s *Server) recoverTaskTimeouts() {
|
||||
WHERE t.status = ?
|
||||
AND t.started_at IS NOT NULL
|
||||
AND (t.timeout_seconds IS NULL OR
|
||||
datetime(t.started_at, '+' || t.timeout_seconds || ' seconds') < datetime('now'))`,
|
||||
t.started_at + INTERVAL (t.timeout_seconds || ' seconds') < CURRENT_TIMESTAMP)`,
|
||||
types.TaskStatusRunning,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -507,13 +650,8 @@ func (s *Server) recoverTaskTimeouts() {
|
||||
types.TaskStatusPending, taskID,
|
||||
)
|
||||
if err == nil {
|
||||
// Add log entry
|
||||
_, _ = s.db.Exec(
|
||||
`INSERT INTO task_logs (task_id, log_level, message, step_name, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
taskID, types.LogLevelWarn, fmt.Sprintf("Task timeout exceeded, resetting (retry %d/%d)", retryCount+1, maxRetries),
|
||||
"", time.Now(),
|
||||
)
|
||||
// Add log entry using the helper function
|
||||
s.logTaskEvent(taskID, nil, types.LogLevelWarn, fmt.Sprintf("Task timeout exceeded, resetting (retry %d/%d)", retryCount+1, maxRetries), "")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,21 +5,24 @@ import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
// Auth handles authentication
|
||||
type Auth struct {
|
||||
db *sql.DB
|
||||
googleConfig *oauth2.Config
|
||||
discordConfig *oauth2.Config
|
||||
sessionStore map[string]*Session
|
||||
db *sql.DB
|
||||
googleConfig *oauth2.Config
|
||||
discordConfig *oauth2.Config
|
||||
sessionStore map[string]*Session
|
||||
}
|
||||
|
||||
// Session represents a user session
|
||||
@@ -67,9 +70,95 @@ func NewAuth(db *sql.DB) (*Auth, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize admin settings on startup to ensure they persist between boots
|
||||
if err := auth.initializeSettings(); err != nil {
|
||||
log.Printf("Warning: Failed to initialize admin settings: %v", err)
|
||||
// Don't fail startup, but log the warning
|
||||
}
|
||||
|
||||
// Initialize test local user from environment variables (for testing only)
|
||||
if err := auth.initializeTestUser(); err != nil {
|
||||
log.Printf("Warning: Failed to initialize test user: %v", err)
|
||||
// Don't fail startup, but log the warning
|
||||
}
|
||||
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// initializeSettings ensures all admin settings are initialized with defaults if they don't exist
|
||||
func (a *Auth) initializeSettings() error {
|
||||
// Initialize registration_enabled setting (default: true) if it doesn't exist
|
||||
var settingCount int
|
||||
err := a.db.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check registration_enabled setting: %w", err)
|
||||
}
|
||||
if settingCount == 0 {
|
||||
_, err = a.db.Exec(
|
||||
`INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)`,
|
||||
"registration_enabled", "true",
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize registration_enabled setting: %w", err)
|
||||
}
|
||||
log.Printf("Initialized admin setting: registration_enabled = true")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// initializeTestUser creates a test local user from environment variables (for testing only)
|
||||
func (a *Auth) initializeTestUser() error {
|
||||
testEmail := os.Getenv("LOCAL_TEST_EMAIL")
|
||||
testPassword := os.Getenv("LOCAL_TEST_PASSWORD")
|
||||
|
||||
if testEmail == "" || testPassword == "" {
|
||||
// No test user configured, skip
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if user already exists
|
||||
var exists bool
|
||||
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND oauth_provider = 'local')", testEmail).Scan(&exists)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if test user exists: %w", err)
|
||||
}
|
||||
|
||||
if exists {
|
||||
// User already exists, skip creation
|
||||
log.Printf("Test user %s already exists, skipping creation", testEmail)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Hash password
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(testPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash test user password: %w", err)
|
||||
}
|
||||
|
||||
// Check if this is the first user (make them admin)
|
||||
var userCount int
|
||||
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
||||
isAdmin := userCount == 0
|
||||
|
||||
// Create test user (use email as name if no name is provided)
|
||||
testName := testEmail
|
||||
if atIndex := strings.Index(testEmail, "@"); atIndex > 0 {
|
||||
testName = testEmail[:atIndex]
|
||||
}
|
||||
|
||||
// Create test user
|
||||
_, err = a.db.Exec(
|
||||
"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?)",
|
||||
testEmail, testName, testEmail, string(hashedPassword), isAdmin,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create test user: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Created test user: %s (admin: %v)", testEmail, isAdmin)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GoogleLoginURL returns the Google OAuth login URL
|
||||
func (a *Auth) GoogleLoginURL() (string, error) {
|
||||
if a.googleConfig == nil {
|
||||
@@ -150,36 +239,119 @@ func (a *Auth) DiscordCallback(ctx context.Context, code string) (*Session, erro
|
||||
return a.getOrCreateUser("discord", userInfo.ID, userInfo.Email, userInfo.Username)
|
||||
}
|
||||
|
||||
// IsRegistrationEnabled checks if new user registration is enabled
|
||||
func (a *Auth) IsRegistrationEnabled() (bool, error) {
|
||||
var value string
|
||||
err := a.db.QueryRow("SELECT value FROM settings WHERE key = ?", "registration_enabled").Scan(&value)
|
||||
if err == sql.ErrNoRows {
|
||||
// Default to enabled if setting doesn't exist
|
||||
return true, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check registration setting: %w", err)
|
||||
}
|
||||
return value == "true", nil
|
||||
}
|
||||
|
||||
// SetRegistrationEnabled sets whether new user registration is enabled
|
||||
func (a *Auth) SetRegistrationEnabled(enabled bool) error {
|
||||
value := "false"
|
||||
if enabled {
|
||||
value = "true"
|
||||
}
|
||||
|
||||
// Check if setting exists
|
||||
var exists bool
|
||||
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM settings WHERE key = ?)", "registration_enabled").Scan(&exists)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if setting exists: %w", err)
|
||||
}
|
||||
|
||||
if exists {
|
||||
// Update existing setting
|
||||
_, err = a.db.Exec(
|
||||
"UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?",
|
||||
value, "registration_enabled",
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update setting: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Insert new setting
|
||||
_, err = a.db.Exec(
|
||||
"INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)",
|
||||
"registration_enabled", value,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert setting: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getOrCreateUser gets or creates a user in the database
|
||||
// Automatically links accounts by email across different OAuth providers and local login
|
||||
func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session, error) {
|
||||
var userID int64
|
||||
var dbEmail, dbName string
|
||||
var isAdmin bool
|
||||
var dbProvider, dbOAuthID string
|
||||
|
||||
// First, try to find by provider + oauth_id
|
||||
err := a.db.QueryRow(
|
||||
"SELECT id, email, name, is_admin FROM users WHERE oauth_provider = ? AND oauth_id = ?",
|
||||
"SELECT id, email, name, is_admin, oauth_provider, oauth_id FROM users WHERE oauth_provider = ? AND oauth_id = ?",
|
||||
provider, oauthID,
|
||||
).Scan(&userID, &dbEmail, &dbName, &isAdmin)
|
||||
).Scan(&userID, &dbEmail, &dbName, &isAdmin, &dbProvider, &dbOAuthID)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
// Check if this is the first user
|
||||
var userCount int
|
||||
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
||||
isAdmin = userCount == 0
|
||||
// Not found by provider+oauth_id, check by email for account linking
|
||||
err = a.db.QueryRow(
|
||||
"SELECT id, email, name, is_admin, oauth_provider, oauth_id FROM users WHERE email = ?",
|
||||
email,
|
||||
).Scan(&userID, &dbEmail, &dbName, &isAdmin, &dbProvider, &dbOAuthID)
|
||||
|
||||
// Create new user
|
||||
result, err := a.db.Exec(
|
||||
"INSERT INTO users (email, name, oauth_provider, oauth_id, is_admin) VALUES (?, ?, ?, ?, ?)",
|
||||
email, name, provider, oauthID, isAdmin,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
if err == sql.ErrNoRows {
|
||||
// User doesn't exist, check if registration is enabled
|
||||
registrationEnabled, err := a.IsRegistrationEnabled()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check registration setting: %w", err)
|
||||
}
|
||||
if !registrationEnabled {
|
||||
return nil, fmt.Errorf("registration is disabled")
|
||||
}
|
||||
|
||||
// Check if this is the first user
|
||||
var userCount int
|
||||
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
||||
isAdmin = userCount == 0
|
||||
|
||||
// Create new user
|
||||
err = a.db.QueryRow(
|
||||
"INSERT INTO users (email, name, oauth_provider, oauth_id, is_admin) VALUES (?, ?, ?, ?, ?) RETURNING id",
|
||||
email, name, provider, oauthID, isAdmin,
|
||||
).Scan(&userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("failed to query user by email: %w", err)
|
||||
} else {
|
||||
// User exists with same email but different provider - link accounts by updating provider info
|
||||
// This allows the user to log in with any provider that has the same email
|
||||
_, err = a.db.Exec(
|
||||
"UPDATE users SET oauth_provider = ?, oauth_id = ?, name = ? WHERE id = ?",
|
||||
provider, oauthID, name, userID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to link account: %w", err)
|
||||
}
|
||||
log.Printf("Linked account: user %d (email: %s) now accessible via %s provider", userID, email, provider)
|
||||
}
|
||||
userID, _ = result.LastInsertId()
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("failed to query user: %w", err)
|
||||
} else {
|
||||
// Update user info if changed
|
||||
// User found by provider+oauth_id, update info if changed
|
||||
if dbEmail != email || dbName != name {
|
||||
_, err = a.db.Exec(
|
||||
"UPDATE users SET email = ?, name = ? WHERE id = ?",
|
||||
@@ -238,13 +410,17 @@ func (a *Auth) Middleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("session_id")
|
||||
if err != nil {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
session, ok := a.GetSession(cookie.Value)
|
||||
if !ok {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -275,19 +451,25 @@ func (a *Auth) AdminMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
// First check authentication
|
||||
cookie, err := r.Cookie("session_id")
|
||||
if err != nil {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
session, ok := a.GetSession(cookie.Value)
|
||||
if !ok {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
// Then check admin status
|
||||
if !session.IsAdmin {
|
||||
http.Error(w, "Forbidden: Admin access required", http.StatusForbidden)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "Forbidden: Admin access required"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -300,3 +482,221 @@ func (a *Auth) AdminMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// IsLocalLoginEnabled returns whether local login is enabled
|
||||
// Local login is enabled when ENABLE_LOCAL_AUTH environment variable is set to "true"
|
||||
func (a *Auth) IsLocalLoginEnabled() bool {
|
||||
return os.Getenv("ENABLE_LOCAL_AUTH") == "true"
|
||||
}
|
||||
|
||||
// IsGoogleOAuthConfigured returns whether Google OAuth is configured
|
||||
func (a *Auth) IsGoogleOAuthConfigured() bool {
|
||||
return a.googleConfig != nil
|
||||
}
|
||||
|
||||
// IsDiscordOAuthConfigured returns whether Discord OAuth is configured
|
||||
func (a *Auth) IsDiscordOAuthConfigured() bool {
|
||||
return a.discordConfig != nil
|
||||
}
|
||||
|
||||
// LocalLogin handles local username/password authentication
|
||||
func (a *Auth) LocalLogin(username, password string) (*Session, error) {
|
||||
// Find user by email (local users use email as username)
|
||||
email := username
|
||||
var userID int64
|
||||
var dbEmail, dbName, passwordHash string
|
||||
var isAdmin bool
|
||||
|
||||
err := a.db.QueryRow(
|
||||
"SELECT id, email, name, password_hash, is_admin FROM users WHERE email = ? AND oauth_provider = 'local'",
|
||||
email,
|
||||
).Scan(&userID, &dbEmail, &dbName, &passwordHash, &isAdmin)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("invalid credentials")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query user: %w", err)
|
||||
}
|
||||
|
||||
// Verify password
|
||||
if passwordHash == "" {
|
||||
return nil, fmt.Errorf("invalid credentials")
|
||||
}
|
||||
|
||||
err = bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(password))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid credentials")
|
||||
}
|
||||
|
||||
// Create session
|
||||
session := &Session{
|
||||
UserID: userID,
|
||||
Email: dbEmail,
|
||||
Name: dbName,
|
||||
IsAdmin: isAdmin,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// RegisterLocalUser creates a new local user account
|
||||
func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error) {
|
||||
// Check if registration is enabled
|
||||
registrationEnabled, err := a.IsRegistrationEnabled()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check registration setting: %w", err)
|
||||
}
|
||||
if !registrationEnabled {
|
||||
return nil, fmt.Errorf("registration is disabled")
|
||||
}
|
||||
|
||||
// Check if user already exists
|
||||
var exists bool
|
||||
err = a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ?)", email).Scan(&exists)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if user exists: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return nil, fmt.Errorf("user with this email already exists")
|
||||
}
|
||||
|
||||
// Hash password
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
// Check if this is the first user (make them admin)
|
||||
var userCount int
|
||||
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
||||
isAdmin := userCount == 0
|
||||
|
||||
// Create user
|
||||
var userID int64
|
||||
err = a.db.QueryRow(
|
||||
"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?) RETURNING id",
|
||||
email, name, email, string(hashedPassword), isAdmin,
|
||||
).Scan(&userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
// Create session
|
||||
session := &Session{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
Name: name,
|
||||
IsAdmin: isAdmin,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// ChangePassword allows a user to change their own password
|
||||
func (a *Auth) ChangePassword(userID int64, oldPassword, newPassword string) error {
|
||||
// Get current password hash
|
||||
var passwordHash string
|
||||
err := a.db.QueryRow("SELECT password_hash FROM users WHERE id = ? AND oauth_provider = 'local'", userID).Scan(&passwordHash)
|
||||
if err == sql.ErrNoRows {
|
||||
return fmt.Errorf("user not found or not a local user")
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query user: %w", err)
|
||||
}
|
||||
|
||||
// Verify old password
|
||||
if passwordHash == "" {
|
||||
return fmt.Errorf("user has no password set")
|
||||
}
|
||||
|
||||
err = bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(oldPassword))
|
||||
if err != nil {
|
||||
return fmt.Errorf("incorrect old password")
|
||||
}
|
||||
|
||||
// Hash new password
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
// Update password
|
||||
_, err = a.db.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update password: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AdminChangePassword allows an admin to change any user's password without knowing the old password
|
||||
func (a *Auth) AdminChangePassword(targetUserID int64, newPassword string) error {
|
||||
// Verify user exists and is a local user
|
||||
var exists bool
|
||||
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ? AND oauth_provider = 'local')", targetUserID).Scan(&exists)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if user exists: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return fmt.Errorf("user not found or not a local user")
|
||||
}
|
||||
|
||||
// Hash new password
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
// Update password
|
||||
_, err = a.db.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), targetUserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update password: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFirstUserID returns the ID of the first user (user with the lowest ID)
|
||||
func (a *Auth) GetFirstUserID() (int64, error) {
|
||||
var firstUserID int64
|
||||
err := a.db.QueryRow("SELECT id FROM users ORDER BY id ASC LIMIT 1").Scan(&firstUserID)
|
||||
if err == sql.ErrNoRows {
|
||||
return 0, fmt.Errorf("no users found")
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get first user ID: %w", err)
|
||||
}
|
||||
return firstUserID, nil
|
||||
}
|
||||
|
||||
// SetUserAdminStatus allows an admin to change a user's admin status
|
||||
func (a *Auth) SetUserAdminStatus(targetUserID int64, isAdmin bool) error {
|
||||
// Verify user exists
|
||||
var exists bool
|
||||
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", targetUserID).Scan(&exists)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if user exists: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return fmt.Errorf("user not found")
|
||||
}
|
||||
|
||||
// Prevent removing admin status from the first user
|
||||
firstUserID, err := a.GetFirstUserID()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check first user: %w", err)
|
||||
}
|
||||
if targetUserID == firstUserID && !isAdmin {
|
||||
return fmt.Errorf("cannot remove admin status from the first user")
|
||||
}
|
||||
|
||||
// Update admin status
|
||||
_, err = a.db.Exec("UPDATE users SET is_admin = ? WHERE id = ?", isAdmin, targetUserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update admin status: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,25 +8,36 @@ import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Secrets handles secret and token management
|
||||
type Secrets struct {
|
||||
db *sql.DB
|
||||
db *sql.DB
|
||||
fixedRegistrationToken string // Fixed token from environment variable (reusable, never expires)
|
||||
}
|
||||
|
||||
// NewSecrets creates a new secrets manager
|
||||
func NewSecrets(db *sql.DB) (*Secrets, error) {
|
||||
s := &Secrets{db: db}
|
||||
|
||||
|
||||
// Check for fixed registration token from environment
|
||||
fixedToken := os.Getenv("FIXED_REGISTRATION_TOKEN")
|
||||
if fixedToken != "" {
|
||||
s.fixedRegistrationToken = fixedToken
|
||||
log.Printf("Fixed registration token enabled (from FIXED_REGISTRATION_TOKEN env var)")
|
||||
log.Printf("WARNING: Fixed registration token is reusable and never expires - use only for testing/development!")
|
||||
}
|
||||
|
||||
// Ensure manager secret exists
|
||||
if err := s.ensureManagerSecret(); err != nil {
|
||||
return nil, fmt.Errorf("failed to ensure manager secret: %w", err)
|
||||
}
|
||||
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
@@ -37,20 +48,20 @@ func (s *Secrets) ensureManagerSecret() error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check manager secrets: %w", err)
|
||||
}
|
||||
|
||||
|
||||
if count == 0 {
|
||||
// Generate new manager secret
|
||||
secret, err := generateSecret(32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate manager secret: %w", err)
|
||||
}
|
||||
|
||||
|
||||
_, err = s.db.Exec("INSERT INTO manager_secrets (secret) VALUES (?)", secret)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store manager secret: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -70,9 +81,9 @@ func (s *Secrets) GenerateRegistrationToken(createdBy int64, expiresIn time.Dura
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate token: %w", err)
|
||||
}
|
||||
|
||||
|
||||
expiresAt := time.Now().Add(expiresIn)
|
||||
|
||||
|
||||
_, err = s.db.Exec(
|
||||
"INSERT INTO registration_tokens (token, expires_at, created_by) VALUES (?, ?, ?)",
|
||||
token, expiresAt, createdBy,
|
||||
@@ -80,43 +91,67 @@ func (s *Secrets) GenerateRegistrationToken(createdBy int64, expiresIn time.Dura
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to store registration token: %w", err)
|
||||
}
|
||||
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// TokenValidationResult represents the result of token validation
|
||||
type TokenValidationResult struct {
|
||||
Valid bool
|
||||
Reason string // "valid", "not_found", "already_used", "expired"
|
||||
Error error
|
||||
}
|
||||
|
||||
// ValidateRegistrationToken validates a registration token
|
||||
func (s *Secrets) ValidateRegistrationToken(token string) (bool, error) {
|
||||
result, err := s.ValidateRegistrationTokenDetailed(token)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
// For backward compatibility, return just the valid boolean
|
||||
return result.Valid, nil
|
||||
}
|
||||
|
||||
// ValidateRegistrationTokenDetailed validates a registration token and returns detailed result
|
||||
func (s *Secrets) ValidateRegistrationTokenDetailed(token string) (*TokenValidationResult, error) {
|
||||
// Check fixed token first (if set) - it's reusable and never expires
|
||||
if s.fixedRegistrationToken != "" && token == s.fixedRegistrationToken {
|
||||
log.Printf("Fixed registration token used (from FIXED_REGISTRATION_TOKEN env var)")
|
||||
return &TokenValidationResult{Valid: true, Reason: "valid"}, nil
|
||||
}
|
||||
|
||||
// Check database tokens
|
||||
var used bool
|
||||
var expiresAt time.Time
|
||||
var id int64
|
||||
|
||||
|
||||
err := s.db.QueryRow(
|
||||
"SELECT id, expires_at, used FROM registration_tokens WHERE token = ?",
|
||||
token,
|
||||
).Scan(&id, &expiresAt, &used)
|
||||
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
return &TokenValidationResult{Valid: false, Reason: "not_found"}, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to query token: %w", err)
|
||||
return nil, fmt.Errorf("failed to query token: %w", err)
|
||||
}
|
||||
|
||||
|
||||
if used {
|
||||
return false, nil
|
||||
return &TokenValidationResult{Valid: false, Reason: "already_used"}, nil
|
||||
}
|
||||
|
||||
|
||||
if time.Now().After(expiresAt) {
|
||||
return false, nil
|
||||
return &TokenValidationResult{Valid: false, Reason: "expired"}, nil
|
||||
}
|
||||
|
||||
|
||||
// Mark token as used
|
||||
_, err = s.db.Exec("UPDATE registration_tokens SET used = 1 WHERE id = ?", id)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to mark token as used: %w", err)
|
||||
return nil, fmt.Errorf("failed to mark token as used: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
|
||||
return &TokenValidationResult{Valid: true, Reason: "valid"}, nil
|
||||
}
|
||||
|
||||
// ListRegistrationTokens lists all registration tokens
|
||||
@@ -130,19 +165,19 @@ func (s *Secrets) ListRegistrationTokens() ([]map[string]interface{}, error) {
|
||||
return nil, fmt.Errorf("failed to query tokens: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
|
||||
var tokens []map[string]interface{}
|
||||
for rows.Next() {
|
||||
var id, createdBy sql.NullInt64
|
||||
var token string
|
||||
var expiresAt, createdAt time.Time
|
||||
var used bool
|
||||
|
||||
|
||||
err := rows.Scan(&id, &token, &expiresAt, &used, &createdAt, &createdBy)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
tokens = append(tokens, map[string]interface{}{
|
||||
"id": id.Int64,
|
||||
"token": token,
|
||||
@@ -152,7 +187,7 @@ func (s *Secrets) ListRegistrationTokens() ([]map[string]interface{}, error) {
|
||||
"created_by": createdBy.Int64,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
@@ -181,28 +216,29 @@ func VerifyRequest(r *http.Request, secret string, maxAge time.Duration) (bool,
|
||||
if signature == "" {
|
||||
return false, fmt.Errorf("missing signature")
|
||||
}
|
||||
|
||||
|
||||
timestampStr := r.Header.Get("X-Runner-Timestamp")
|
||||
if timestampStr == "" {
|
||||
return false, fmt.Errorf("missing timestamp")
|
||||
}
|
||||
|
||||
var timestamp time.Time
|
||||
_, err := fmt.Sscanf(timestampStr, "%d", ×tamp)
|
||||
|
||||
var timestampUnix int64
|
||||
_, err := fmt.Sscanf(timestampStr, "%d", ×tampUnix)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("invalid timestamp: %w", err)
|
||||
}
|
||||
|
||||
timestamp := time.Unix(timestampUnix, 0)
|
||||
|
||||
// Check timestamp is not too old
|
||||
if time.Since(timestamp) > maxAge {
|
||||
return false, fmt.Errorf("request too old")
|
||||
}
|
||||
|
||||
|
||||
// Check timestamp is not in the future (allow 1 minute clock skew)
|
||||
if timestamp.After(time.Now().Add(1 * time.Minute)) {
|
||||
return false, fmt.Errorf("timestamp in future")
|
||||
}
|
||||
|
||||
|
||||
// Read body
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
@@ -210,10 +246,13 @@ func VerifyRequest(r *http.Request, secret string, maxAge time.Duration) (bool,
|
||||
}
|
||||
// Restore body for handler
|
||||
r.Body = io.NopCloser(strings.NewReader(string(bodyBytes)))
|
||||
|
||||
// Verify signature
|
||||
expectedSig := SignRequest(r.Method, r.URL.Path, string(bodyBytes), secret, timestamp)
|
||||
|
||||
|
||||
// Verify signature - use path without query parameters (query params are not part of signature)
|
||||
// The runner signs with the path including query params, but we verify with just the path
|
||||
// This is intentional - query params are for identification, not part of the signature
|
||||
path := r.URL.Path
|
||||
expectedSig := SignRequest(r.Method, path, string(bodyBytes), secret, timestamp)
|
||||
|
||||
return hmac.Equal([]byte(signature), []byte(expectedSig)), nil
|
||||
}
|
||||
|
||||
@@ -241,4 +280,3 @@ func generateSecret(length int) (string, error) {
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -3,8 +3,9 @@ package database
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
_ "github.com/marcboeker/go-duckdb/v2"
|
||||
)
|
||||
|
||||
// DB wraps the database connection
|
||||
@@ -14,7 +15,7 @@ type DB struct {
|
||||
|
||||
// NewDB creates a new database connection
|
||||
func NewDB(dbPath string) (*DB, error) {
|
||||
db, err := sql.Open("sqlite3", dbPath+"?_foreign_keys=1")
|
||||
db, err := sql.Open("duckdb", dbPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
@@ -33,114 +34,133 @@ func NewDB(dbPath string) (*DB, error) {
|
||||
|
||||
// migrate runs database migrations
|
||||
func (db *DB) migrate() error {
|
||||
// Create sequences for auto-incrementing primary keys
|
||||
sequences := []string{
|
||||
`CREATE SEQUENCE IF NOT EXISTS seq_users_id START 1`,
|
||||
`CREATE SEQUENCE IF NOT EXISTS seq_jobs_id START 1`,
|
||||
`CREATE SEQUENCE IF NOT EXISTS seq_runners_id START 1`,
|
||||
`CREATE SEQUENCE IF NOT EXISTS seq_tasks_id START 1`,
|
||||
`CREATE SEQUENCE IF NOT EXISTS seq_job_files_id START 1`,
|
||||
`CREATE SEQUENCE IF NOT EXISTS seq_manager_secrets_id START 1`,
|
||||
`CREATE SEQUENCE IF NOT EXISTS seq_registration_tokens_id START 1`,
|
||||
`CREATE SEQUENCE IF NOT EXISTS seq_task_logs_id START 1`,
|
||||
`CREATE SEQUENCE IF NOT EXISTS seq_task_steps_id START 1`,
|
||||
}
|
||||
|
||||
for _, seq := range sequences {
|
||||
if _, err := db.Exec(seq); err != nil {
|
||||
return fmt.Errorf("failed to create sequence: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
schema := `
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
id BIGINT PRIMARY KEY DEFAULT nextval('seq_users_id'),
|
||||
email TEXT UNIQUE NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
oauth_provider TEXT NOT NULL,
|
||||
oauth_id TEXT NOT NULL,
|
||||
is_admin BOOLEAN NOT NULL DEFAULT 0,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
password_hash TEXT,
|
||||
is_admin BOOLEAN NOT NULL DEFAULT false,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(oauth_provider, oauth_id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS jobs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL,
|
||||
id BIGINT PRIMARY KEY DEFAULT nextval('seq_jobs_id'),
|
||||
user_id BIGINT NOT NULL,
|
||||
job_type TEXT NOT NULL DEFAULT 'render',
|
||||
name TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
progress REAL NOT NULL DEFAULT 0.0,
|
||||
frame_start INTEGER NOT NULL,
|
||||
frame_end INTEGER NOT NULL,
|
||||
output_format TEXT NOT NULL DEFAULT 'PNG',
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
started_at DATETIME,
|
||||
completed_at DATETIME,
|
||||
frame_start INTEGER,
|
||||
frame_end INTEGER,
|
||||
output_format TEXT,
|
||||
allow_parallel_runners BOOLEAN,
|
||||
timeout_seconds INTEGER DEFAULT 86400,
|
||||
blend_metadata TEXT,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
started_at TIMESTAMP,
|
||||
completed_at TIMESTAMP,
|
||||
error_message TEXT,
|
||||
FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS runners (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
id BIGINT PRIMARY KEY DEFAULT nextval('seq_runners_id'),
|
||||
name TEXT NOT NULL,
|
||||
hostname TEXT NOT NULL,
|
||||
ip_address TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'offline',
|
||||
last_heartbeat DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
last_heartbeat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
capabilities TEXT,
|
||||
registration_token TEXT,
|
||||
runner_secret TEXT,
|
||||
manager_secret TEXT,
|
||||
verified BOOLEAN NOT NULL DEFAULT 0,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
verified BOOLEAN NOT NULL DEFAULT false,
|
||||
priority INTEGER NOT NULL DEFAULT 100,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS tasks (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
job_id INTEGER NOT NULL,
|
||||
runner_id INTEGER,
|
||||
id BIGINT PRIMARY KEY DEFAULT nextval('seq_tasks_id'),
|
||||
job_id BIGINT NOT NULL,
|
||||
runner_id BIGINT,
|
||||
frame_start INTEGER NOT NULL,
|
||||
frame_end INTEGER NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
output_path TEXT,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
started_at DATETIME,
|
||||
completed_at DATETIME,
|
||||
error_message TEXT,
|
||||
FOREIGN KEY (job_id) REFERENCES jobs(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (runner_id) REFERENCES runners(id) ON DELETE SET NULL
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
started_at TIMESTAMP,
|
||||
completed_at TIMESTAMP,
|
||||
error_message TEXT
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS job_files (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
job_id INTEGER NOT NULL,
|
||||
id BIGINT PRIMARY KEY DEFAULT nextval('seq_job_files_id'),
|
||||
job_id BIGINT NOT NULL,
|
||||
file_type TEXT NOT NULL,
|
||||
file_path TEXT NOT NULL,
|
||||
file_name TEXT NOT NULL,
|
||||
file_size INTEGER NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (job_id) REFERENCES jobs(id) ON DELETE CASCADE
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS manager_secrets (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
id BIGINT PRIMARY KEY DEFAULT nextval('seq_manager_secrets_id'),
|
||||
secret TEXT UNIQUE NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS registration_tokens (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
id BIGINT PRIMARY KEY DEFAULT nextval('seq_registration_tokens_id'),
|
||||
token TEXT UNIQUE NOT NULL,
|
||||
expires_at DATETIME NOT NULL,
|
||||
used BOOLEAN NOT NULL DEFAULT 0,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
created_by INTEGER,
|
||||
FOREIGN KEY (created_by) REFERENCES users(id) ON DELETE SET NULL
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
used BOOLEAN NOT NULL DEFAULT false,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
created_by BIGINT,
|
||||
FOREIGN KEY (created_by) REFERENCES users(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS task_logs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
task_id INTEGER NOT NULL,
|
||||
runner_id INTEGER,
|
||||
id BIGINT PRIMARY KEY DEFAULT nextval('seq_task_logs_id'),
|
||||
task_id BIGINT NOT NULL,
|
||||
runner_id BIGINT,
|
||||
log_level TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
step_name TEXT,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (task_id) REFERENCES tasks(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (runner_id) REFERENCES runners(id) ON DELETE SET NULL
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS task_steps (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
task_id INTEGER NOT NULL,
|
||||
id BIGINT PRIMARY KEY DEFAULT nextval('seq_task_steps_id'),
|
||||
task_id BIGINT NOT NULL,
|
||||
step_name TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
started_at DATETIME,
|
||||
completed_at DATETIME,
|
||||
started_at TIMESTAMP,
|
||||
completed_at TIMESTAMP,
|
||||
duration_ms INTEGER,
|
||||
error_message TEXT,
|
||||
FOREIGN KEY (task_id) REFERENCES tasks(id) ON DELETE CASCADE
|
||||
error_message TEXT
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_jobs_user_id ON jobs(user_id);
|
||||
@@ -156,6 +176,12 @@ func (db *DB) migrate() error {
|
||||
CREATE INDEX IF NOT EXISTS idx_task_logs_runner_id ON task_logs(runner_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_task_steps_task_id ON task_steps(task_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_runners_last_heartbeat ON runners(last_heartbeat);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS settings (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
`
|
||||
|
||||
if _, err := db.Exec(schema); err != nil {
|
||||
@@ -165,30 +191,57 @@ func (db *DB) migrate() error {
|
||||
// Migrate existing tables to add new columns
|
||||
migrations := []string{
|
||||
// Add is_admin to users if it doesn't exist
|
||||
`ALTER TABLE users ADD COLUMN is_admin BOOLEAN NOT NULL DEFAULT 0`,
|
||||
`ALTER TABLE users ADD COLUMN IF NOT EXISTS is_admin BOOLEAN NOT NULL DEFAULT false`,
|
||||
// Add new columns to runners if they don't exist
|
||||
`ALTER TABLE runners ADD COLUMN registration_token TEXT`,
|
||||
`ALTER TABLE runners ADD COLUMN runner_secret TEXT`,
|
||||
`ALTER TABLE runners ADD COLUMN manager_secret TEXT`,
|
||||
`ALTER TABLE runners ADD COLUMN verified BOOLEAN NOT NULL DEFAULT 0`,
|
||||
`ALTER TABLE runners ADD COLUMN IF NOT EXISTS registration_token TEXT`,
|
||||
`ALTER TABLE runners ADD COLUMN IF NOT EXISTS runner_secret TEXT`,
|
||||
`ALTER TABLE runners ADD COLUMN IF NOT EXISTS manager_secret TEXT`,
|
||||
`ALTER TABLE runners ADD COLUMN IF NOT EXISTS verified BOOLEAN NOT NULL DEFAULT false`,
|
||||
`ALTER TABLE runners ADD COLUMN IF NOT EXISTS priority INTEGER NOT NULL DEFAULT 100`,
|
||||
// Add allow_parallel_runners to jobs if it doesn't exist
|
||||
`ALTER TABLE jobs ADD COLUMN allow_parallel_runners BOOLEAN NOT NULL DEFAULT 1`,
|
||||
`ALTER TABLE jobs ADD COLUMN IF NOT EXISTS allow_parallel_runners BOOLEAN NOT NULL DEFAULT true`,
|
||||
// Add timeout_seconds to jobs if it doesn't exist
|
||||
`ALTER TABLE jobs ADD COLUMN timeout_seconds INTEGER DEFAULT 86400`,
|
||||
`ALTER TABLE jobs ADD COLUMN IF NOT EXISTS timeout_seconds INTEGER DEFAULT 86400`,
|
||||
// Add blend_metadata to jobs if it doesn't exist
|
||||
`ALTER TABLE jobs ADD COLUMN blend_metadata TEXT`,
|
||||
`ALTER TABLE jobs ADD COLUMN IF NOT EXISTS blend_metadata TEXT`,
|
||||
// Add job_type to jobs if it doesn't exist
|
||||
`ALTER TABLE jobs ADD COLUMN IF NOT EXISTS job_type TEXT DEFAULT 'render'`,
|
||||
// Add task_type to tasks if it doesn't exist
|
||||
`ALTER TABLE tasks ADD COLUMN task_type TEXT DEFAULT 'render'`,
|
||||
`ALTER TABLE tasks ADD COLUMN IF NOT EXISTS task_type TEXT DEFAULT 'render'`,
|
||||
// Add new columns to tasks if they don't exist
|
||||
`ALTER TABLE tasks ADD COLUMN current_step TEXT`,
|
||||
`ALTER TABLE tasks ADD COLUMN retry_count INTEGER DEFAULT 0`,
|
||||
`ALTER TABLE tasks ADD COLUMN max_retries INTEGER DEFAULT 3`,
|
||||
`ALTER TABLE tasks ADD COLUMN timeout_seconds INTEGER`,
|
||||
`ALTER TABLE tasks ADD COLUMN IF NOT EXISTS current_step TEXT`,
|
||||
`ALTER TABLE tasks ADD COLUMN IF NOT EXISTS retry_count INTEGER DEFAULT 0`,
|
||||
`ALTER TABLE tasks ADD COLUMN IF NOT EXISTS max_retries INTEGER DEFAULT 3`,
|
||||
`ALTER TABLE tasks ADD COLUMN IF NOT EXISTS timeout_seconds INTEGER`,
|
||||
}
|
||||
|
||||
for _, migration := range migrations {
|
||||
// SQLite doesn't support IF NOT EXISTS for ALTER TABLE, so we ignore errors
|
||||
db.Exec(migration)
|
||||
// DuckDB supports IF NOT EXISTS for ALTER TABLE, so we can safely execute
|
||||
if _, err := db.Exec(migration); err != nil {
|
||||
// Log but don't fail - column might already exist or table might not exist yet
|
||||
// This is fine for migrations that run after schema creation
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize registration_enabled setting (default: true) if it doesn't exist
|
||||
var settingCount int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount)
|
||||
if err == nil && settingCount == 0 {
|
||||
_, err = db.Exec("INSERT INTO settings (key, value) VALUES (?, ?)", "registration_enabled", "true")
|
||||
if err != nil {
|
||||
// Log but don't fail - setting might have been created by another process
|
||||
log.Printf("Note: Could not initialize registration_enabled setting: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
for _, migration := range migrations {
|
||||
// DuckDB supports IF NOT EXISTS for ALTER TABLE, so we can safely execute
|
||||
if _, err := db.Exec(migration); err != nil {
|
||||
// Log but don't fail - column might already exist or table might not exist yet
|
||||
// This is fine for migrations that run after schema creation
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -198,4 +251,3 @@ func (db *DB) migrate() error {
|
||||
func (db *DB) Close() error {
|
||||
return db.DB.Close()
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,10 +1,12 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Storage handles file storage operations
|
||||
@@ -135,3 +137,60 @@ func (s *Storage) GetFileSize(filePath string) (int64, error) {
|
||||
return info.Size(), nil
|
||||
}
|
||||
|
||||
// ExtractZip extracts a ZIP file to the destination directory
|
||||
// Returns a list of all extracted file paths
|
||||
func (s *Storage) ExtractZip(zipPath, destDir string) ([]string, error) {
|
||||
r, err := zip.OpenReader(zipPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open ZIP file: %w", err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
var extractedFiles []string
|
||||
|
||||
for _, f := range r.File {
|
||||
// Sanitize file path to prevent directory traversal
|
||||
destPath := filepath.Join(destDir, f.Name)
|
||||
if !strings.HasPrefix(destPath, filepath.Clean(destDir)+string(os.PathSeparator)) {
|
||||
return nil, fmt.Errorf("invalid file path in ZIP: %s", f.Name)
|
||||
}
|
||||
|
||||
// Create directory structure
|
||||
if f.FileInfo().IsDir() {
|
||||
if err := os.MkdirAll(destPath, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create directory: %w", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Create parent directories
|
||||
if err := os.MkdirAll(filepath.Dir(destPath), 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create parent directory: %w", err)
|
||||
}
|
||||
|
||||
// Extract file
|
||||
rc, err := f.Open()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open file in ZIP: %w", err)
|
||||
}
|
||||
|
||||
outFile, err := os.Create(destPath)
|
||||
if err != nil {
|
||||
rc.Close()
|
||||
return nil, fmt.Errorf("failed to create file: %w", err)
|
||||
}
|
||||
|
||||
_, err = io.Copy(outFile, rc)
|
||||
outFile.Close()
|
||||
rc.Close()
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to extract file: %w", err)
|
||||
}
|
||||
|
||||
extractedFiles = append(extractedFiles, destPath)
|
||||
}
|
||||
|
||||
return extractedFiles, nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user