// Source file: jiggablend/internal/api/server.go (514 lines, 14 KiB, Go)

package api
import (
"database/sql"
"encoding/json"
"fmt"
"log"
"net/http"
"strconv"
"sync"
"time"
authpkg "fuego/internal/auth"
"fuego/internal/database"
"fuego/internal/storage"
"fuego/pkg/types"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/cors"
"github.com/gorilla/websocket"
)
// Server represents the API server. It bundles the database handle, the auth
// and secrets services, file storage, the chi router, and the registries of
// open WebSocket connections.
type Server struct {
// Backing services; all four are populated by NewServer.
db *database.DB
auth *authpkg.Auth
secrets *authpkg.Secrets
storage *storage.Storage
// router receives every request passed to ServeHTTP.
router *chi.Mux
// WebSocket connections
wsUpgrader websocket.Upgrader
// NOTE(review): the int64 key is presumably the runner ID, and access is
// presumably guarded by runnerConnsMu — confirm against the WebSocket handlers.
runnerConns map[int64]*websocket.Conn
runnerConnsMu sync.RWMutex
// NOTE(review): presumably guarded by frontendConnsMu — confirm against the
// WebSocket handlers.
frontendConns map[string]*websocket.Conn // key: "jobId:taskId"
frontendConnsMu sync.RWMutex
}
// NewServer builds a Server around the given database, auth service and file
// storage. It initializes the secrets store from db, installs middleware and
// routes on a fresh chi router, and starts the background recovery tasks
// before returning.
//
// An error is returned only when the secrets store cannot be initialized.
func NewServer(db *database.DB, auth *authpkg.Auth, storage *storage.Storage) (*Server, error) {
	secrets, err := authpkg.NewSecrets(db.DB)
	if err != nil {
		return nil, fmt.Errorf("failed to initialize secrets: %w", err)
	}

	// WebSocket upgrades are accepted from every origin for now.
	upgrader := websocket.Upgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
		CheckOrigin:     func(*http.Request) bool { return true },
	}

	srv := &Server{
		db:            db,
		auth:          auth,
		secrets:       secrets,
		storage:       storage,
		router:        chi.NewRouter(),
		wsUpgrader:    upgrader,
		runnerConns:   make(map[int64]*websocket.Conn),
		frontendConns: make(map[string]*websocket.Conn),
	}

	srv.setupMiddleware()
	srv.setupRoutes()
	srv.StartBackgroundTasks()

	return srv, nil
}
// setupMiddleware installs the global middleware chain on the router, in this
// order: request logging, panic recovery, a 60-second request timeout, and
// CORS handling.
func (s *Server) setupMiddleware() {
	// NOTE(review): a wildcard origin combined with AllowCredentials is
	// refused by browsers if the library answers with a literal "*" —
	// confirm, and consider an explicit origin list.
	corsOptions := cors.Options{
		AllowedOrigins:   []string{"*"},
		AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
		AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type"},
		ExposedHeaders:   []string{"Link"},
		AllowCredentials: true,
		MaxAge:           300,
	}

	s.router.Use(
		middleware.Logger,
		middleware.Recoverer,
		middleware.Timeout(60*time.Second),
		cors.Handler(corsOptions),
	)
}
// setupRoutes registers every HTTP route on the router: the public auth
// endpoints, the session-protected job and runner listings, the admin-only
// runner management API, the runner-facing API, and finally a catch-all file
// server for ./web/dist.
func (s *Server) setupRoutes() {
	// The auth middlewares wrap plain handler functions; these adapters lift
	// them to the func(http.Handler) http.Handler shape chi expects.
	requireUser := func(next http.Handler) http.Handler {
		return http.HandlerFunc(s.auth.Middleware(next.ServeHTTP))
	}
	requireAdmin := func(next http.Handler) http.Handler {
		return http.HandlerFunc(s.auth.AdminMiddleware(next.ServeHTTP))
	}
	requireRunner := func(next http.Handler) http.Handler {
		return http.HandlerFunc(s.runnerAuthMiddleware(next.ServeHTTP))
	}

	// Public routes
	s.router.Route("/api/auth", func(authRoutes chi.Router) {
		authRoutes.Get("/google/login", s.handleGoogleLogin)
		authRoutes.Get("/google/callback", s.handleGoogleCallback)
		authRoutes.Get("/discord/login", s.handleDiscordLogin)
		authRoutes.Get("/discord/callback", s.handleDiscordCallback)
		authRoutes.Post("/logout", s.handleLogout)
		authRoutes.Get("/me", s.handleGetMe)
	})

	// Protected routes
	s.router.Route("/api/jobs", func(jobs chi.Router) {
		jobs.Use(requireUser)

		jobs.Post("/", s.handleCreateJob)
		jobs.Get("/", s.handleListJobs)
		jobs.Get("/{id}", s.handleGetJob)
		jobs.Delete("/{id}", s.handleCancelJob)
		jobs.Post("/{id}/upload", s.handleUploadJobFile)
		jobs.Get("/{id}/files", s.handleListJobFiles)
		jobs.Get("/{id}/files/{fileId}/download", s.handleDownloadJobFile)
		jobs.Get("/{id}/video", s.handleStreamVideo)
		jobs.Get("/{id}/tasks/{taskId}/logs", s.handleGetTaskLogs)
		jobs.Get("/{id}/tasks/{taskId}/logs/ws", s.handleStreamTaskLogsWebSocket)
		jobs.Get("/{id}/tasks/{taskId}/steps", s.handleGetTaskSteps)
		jobs.Post("/{id}/tasks/{taskId}/retry", s.handleRetryTask)
	})

	s.router.Route("/api/runners", func(runners chi.Router) {
		runners.Use(requireUser)

		runners.Get("/", s.handleListRunners)
	})

	// Admin routes
	s.router.Route("/api/admin", func(admin chi.Router) {
		admin.Use(requireAdmin)

		admin.Route("/runners", func(runners chi.Router) {
			runners.Route("/tokens", func(tokens chi.Router) {
				tokens.Post("/", s.handleGenerateRegistrationToken)
				tokens.Get("/", s.handleListRegistrationTokens)
				tokens.Delete("/{id}", s.handleRevokeRegistrationToken)
			})

			runners.Get("/", s.handleListRunnersAdmin)
			runners.Post("/{id}/verify", s.handleVerifyRunner)
			runners.Delete("/{id}", s.handleDeleteRunner)
		})
	})

	// Runner API
	s.router.Route("/api/runner", func(runner chi.Router) {
		// Registration is authenticated by its token, not by middleware.
		runner.Post("/register", s.handleRegisterRunner)

		// The WebSocket endpoint authenticates inside its handler.
		runner.Get("/ws", s.handleRunnerWebSocket)

		// File operations stay on plain HTTP (WebSocket is not suitable for
		// large files) and sit behind the runner auth middleware.
		runner.Group(func(authed chi.Router) {
			authed.Use(requireRunner)

			authed.Post("/tasks/{id}/progress", s.handleUpdateTaskProgress)
			authed.Get("/files/{jobId}/{fileName}", s.handleDownloadFileForRunner)
			authed.Post("/files/{jobId}/upload", s.handleUploadFileFromRunner)
			authed.Get("/jobs/{jobId}/status", s.handleGetJobStatusForRunner)
			authed.Get("/jobs/{jobId}/files", s.handleGetJobFilesForRunner)
		})
	})

	// Everything else is served from the built React app.
	staticFiles := http.FileServer(http.Dir("./web/dist"))
	s.router.Handle("/*", staticFiles)
}
// ServeHTTP implements http.Handler by handing every request to the server's
// chi router, which applies the configured middleware and routes.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.router.ServeHTTP(w, r)
}
// JSON response helpers

// respondJSON writes data to w as a JSON document with the given HTTP status
// and a Content-Type of application/json.
//
// The payload is serialized before the status line is committed. If
// serialization fails, the failure is logged and the client receives a 500
// with a generic JSON error body, instead of the requested status followed by
// an empty body.
func (s *Server) respondJSON(w http.ResponseWriter, status int, data interface{}) {
	w.Header().Set("Content-Type", "application/json")

	body, err := json.Marshal(data)
	if err != nil {
		log.Printf("Failed to encode JSON response: %v", err)
		w.WriteHeader(http.StatusInternalServerError)
		if _, werr := w.Write([]byte("{\"error\":\"Internal server error\"}\n")); werr != nil {
			log.Printf("Failed to write JSON response: %v", werr)
		}
		return
	}
	// json.Encoder ends every document with a newline; keep that framing so
	// successful responses stay byte-identical.
	body = append(body, '\n')

	w.WriteHeader(status)
	if _, err := w.Write(body); err != nil {
		log.Printf("Failed to write JSON response: %v", err)
	}
}
// respondError sends a JSON error response of the form {"error": message}
// with the given HTTP status.
func (s *Server) respondError(w http.ResponseWriter, status int, message string) {
	payload := map[string]string{"error": message}
	s.respondJSON(w, status, payload)
}
// Auth handlers

// handleGoogleLogin redirects the client (302 Found) to the Google login URL
// produced by the auth service. If the URL cannot be produced, a 500 JSON
// error carrying the error text is returned instead.
func (s *Server) handleGoogleLogin(w http.ResponseWriter, r *http.Request) {
	target, err := s.auth.GoogleLoginURL()
	if err == nil {
		http.Redirect(w, r, target, http.StatusFound)
		return
	}
	s.respondError(w, http.StatusInternalServerError, err.Error())
}
// handleGoogleCallback completes the Google login flow. It reads the "code"
// query parameter (400 if absent), passes it to the auth service (500 with
// the error text on failure), stores the resulting session, sets a one-day
// HttpOnly "session_id" cookie and redirects to "/".
//
// NOTE(review): only "code" is read here; no OAuth "state" value is checked
// in this handler — confirm whether that happens elsewhere.
func (s *Server) handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
	authCode := r.URL.Query().Get("code")
	if len(authCode) == 0 {
		s.respondError(w, http.StatusBadRequest, "Missing code parameter")
		return
	}

	userSession, err := s.auth.GoogleCallback(r.Context(), authCode)
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, err.Error())
		return
	}

	sessionCookie := &http.Cookie{
		Name:     "session_id",
		Value:    s.auth.CreateSession(userSession),
		Path:     "/",
		MaxAge:   86400, // 24 hours, in seconds
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	}
	http.SetCookie(w, sessionCookie)
	http.Redirect(w, r, "/", http.StatusFound)
}
// handleDiscordLogin redirects the client (302 Found) to the Discord login
// URL produced by the auth service. If the URL cannot be produced, a 500 JSON
// error carrying the error text is returned instead.
func (s *Server) handleDiscordLogin(w http.ResponseWriter, r *http.Request) {
	target, err := s.auth.DiscordLoginURL()
	if err == nil {
		http.Redirect(w, r, target, http.StatusFound)
		return
	}
	s.respondError(w, http.StatusInternalServerError, err.Error())
}
// handleDiscordCallback completes the Discord login flow. It reads the "code"
// query parameter (400 if absent), passes it to the auth service (500 with
// the error text on failure), stores the resulting session, sets a one-day
// HttpOnly "session_id" cookie and redirects to "/".
//
// NOTE(review): only "code" is read here; no OAuth "state" value is checked
// in this handler — confirm whether that happens elsewhere.
func (s *Server) handleDiscordCallback(w http.ResponseWriter, r *http.Request) {
	authCode := r.URL.Query().Get("code")
	if len(authCode) == 0 {
		s.respondError(w, http.StatusBadRequest, "Missing code parameter")
		return
	}

	userSession, err := s.auth.DiscordCallback(r.Context(), authCode)
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, err.Error())
		return
	}

	sessionCookie := &http.Cookie{
		Name:     "session_id",
		Value:    s.auth.CreateSession(userSession),
		Path:     "/",
		MaxAge:   86400, // 24 hours, in seconds
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	}
	http.SetCookie(w, sessionCookie)
	http.Redirect(w, r, "/", http.StatusFound)
}
// handleLogout ends the caller's session. When a "session_id" cookie is
// present its session is deleted from the auth service; in every case the
// cookie is overwritten with an expired, empty one and a 200 JSON
// confirmation is returned.
func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
	if current, err := r.Cookie("session_id"); err == nil {
		s.auth.DeleteSession(current.Value)
	}

	// A negative MaxAge tells the client to drop the cookie immediately.
	expired := http.Cookie{
		Name:     "session_id",
		Value:    "",
		Path:     "/",
		MaxAge:   -1,
		HttpOnly: true,
	}
	http.SetCookie(w, &expired)

	confirmation := map[string]string{"message": "Logged out"}
	s.respondJSON(w, http.StatusOK, confirmation)
}
// handleGetMe returns the profile attached to the caller's session as JSON
// (id, email, name, is_admin). It answers 401 when the "session_id" cookie is
// missing or does not map to a known session.
func (s *Server) handleGetMe(w http.ResponseWriter, r *http.Request) {
	sessionCookie, err := r.Cookie("session_id")
	if err != nil {
		s.respondError(w, http.StatusUnauthorized, "Not authenticated")
		return
	}

	current, known := s.auth.GetSession(sessionCookie.Value)
	if !known {
		s.respondError(w, http.StatusUnauthorized, "Invalid session")
		return
	}

	profile := map[string]interface{}{
		"id":       current.UserID,
		"email":    current.Email,
		"name":     current.Name,
		"is_admin": current.IsAdmin,
	}
	s.respondJSON(w, http.StatusOK, profile)
}
// getUserID returns the user ID that the auth package stored in the request
// context, or an error when the context carries none.
func getUserID(r *http.Request) (int64, error) {
	if id, found := authpkg.GetUserID(r.Context()); found {
		return id, nil
	}
	return 0, fmt.Errorf("user ID not found in context")
}
// parseID reads the URL parameter named param from the chi route context and
// parses it as a base-10 int64. A value that does not parse yields an
// "invalid ID" error quoting the raw text.
func parseID(r *http.Request, param string) (int64, error) {
	raw := chi.URLParam(r, param)

	parsed, err := strconv.ParseInt(raw, 10, 64)
	if err == nil {
		return parsed, nil
	}
	return 0, fmt.Errorf("invalid ID: %s", raw)
}
// StartBackgroundTasks starts background goroutines for error recovery.
// It launches recoverStuckTasks in its own goroutine and returns immediately.
// NewServer already calls it once while constructing the server.
// NOTE(review): nothing here offers a way to stop the goroutine — confirm
// that running for the life of the process is intended.
func (s *Server) StartBackgroundTasks() {
go s.recoverStuckTasks()
}
// recoverStuckTasks never returns. It drives two periodic jobs:
//
//   - every 5 seconds, in a child goroutine, tasks are handed to runners via
//     distributeTasksToRunners;
//   - every 10 seconds a recovery pass (recoverDeadRunners) releases the tasks
//     of runners that stopped sending heartbeats and recovers timed-out tasks.
//
// Both jobs run under a recover so that a panic in a single iteration is
// logged instead of terminating the process.
func (s *Server) recoverStuckTasks() {
	ticker := time.NewTicker(10 * time.Second)
	defer ticker.Stop()

	// Also distribute tasks every 5 seconds.
	distributeTicker := time.NewTicker(5 * time.Second)
	defer distributeTicker.Stop()
	go func() {
		for range distributeTicker.C {
			func() {
				defer func() {
					if r := recover(); r != nil {
						log.Printf("Panic in distributeTasksToRunners: %v", r)
					}
				}()
				s.distributeTasksToRunners()
			}()
		}
	}()

	for range ticker.C {
		func() {
			defer func() {
				if r := recover(); r != nil {
					log.Printf("Panic in recoverStuckTasks: %v", r)
				}
			}()
			s.recoverDeadRunners()
		}()
	}
}

// recoverDeadRunners performs one recovery pass: it releases the running
// tasks of every dead runner, marks those runners offline, recovers timed-out
// tasks, and, when at least one dead runner was found, redistributes tasks.
func (s *Server) recoverDeadRunners() {
	deadRunnerIDs, err := s.findDeadRunnerIDs()
	if err != nil {
		log.Printf("Failed to query dead runners: %v", err)
		return
	}

	if len(deadRunnerIDs) == 0 {
		// Check for task timeouts
		s.recoverTaskTimeouts()
		return
	}

	for _, runnerID := range deadRunnerIDs {
		if err := s.releaseTasksOfDeadRunner(runnerID); err != nil {
			// Leave the runner marked online so the next pass retries it.
			log.Printf("Failed to query tasks for runner %d: %v", runnerID, err)
			continue
		}

		// Mark runner as offline
		if _, err := s.db.Exec(
			`UPDATE runners SET status = ? WHERE id = ?`,
			types.RunnerStatusOffline, runnerID,
		); err != nil {
			log.Printf("Failed to mark runner %d as offline: %v", runnerID, err)
		}
	}

	// Check for task timeouts
	s.recoverTaskTimeouts()
	// Distribute newly recovered tasks
	s.distributeTasksToRunners()
}

// findDeadRunnerIDs returns the IDs of runners that are still marked online
// but whose last heartbeat is more than 90 seconds old.
func (s *Server) findDeadRunnerIDs() ([]int64, error) {
	rows, err := s.db.Query(
		`SELECT id FROM runners
		WHERE last_heartbeat < datetime('now', '-90 seconds')
		AND status = ?`,
		types.RunnerStatusOnline,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var ids []int64
	for rows.Next() {
		var runnerID int64
		if err := rows.Scan(&runnerID); err != nil {
			log.Printf("Failed to scan dead runner ID: %v", err)
			continue
		}
		ids = append(ids, runnerID)
	}
	// Next also returns false when iteration fails; Err tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return ids, nil
}

// releaseTasksOfDeadRunner handles every task still marked running on the
// given runner: a task with retries left goes back to pending (and gets a
// warning entry in task_logs), a task whose retries are exhausted is marked
// failed. The returned error reports only a failure to read the task list;
// failures on individual tasks are logged and skipped.
func (s *Server) releaseTasksOfDeadRunner(runnerID int64) error {
	type stuckTask struct {
		id         int64
		retryCount int
		maxRetries int
	}

	rows, err := s.db.Query(
		`SELECT id, retry_count, max_retries FROM tasks
		WHERE runner_id = ? AND status = ?`,
		runnerID, types.TaskStatusRunning,
	)
	if err != nil {
		return err
	}

	// Read the full list and close the cursor before issuing any writes.
	var tasks []stuckTask
	for rows.Next() {
		var t stuckTask
		if err := rows.Scan(&t.id, &t.retryCount, &t.maxRetries); err != nil {
			log.Printf("Failed to scan task for runner %d: %v", runnerID, err)
			continue
		}
		tasks = append(tasks, t)
	}
	iterErr := rows.Err()
	rows.Close()
	if iterErr != nil {
		return iterErr
	}

	for _, task := range tasks {
		// Each UPDATE repeats the runner/status condition of the SELECT so a
		// task that finished or moved in the meantime is left untouched.
		if task.retryCount >= task.maxRetries {
			// Mark as failed
			if _, err := s.db.Exec(
				`UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL
				WHERE id = ? AND runner_id = ? AND status = ?`,
				types.TaskStatusFailed, "Runner died, max retries exceeded",
				task.id, runnerID, types.TaskStatusRunning,
			); err != nil {
				log.Printf("Failed to mark task %d as failed: %v", task.id, err)
			}
			continue
		}

		// Reset to pending
		res, err := s.db.Exec(
			`UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL,
			retry_count = retry_count + 1
			WHERE id = ? AND runner_id = ? AND status = ?`,
			types.TaskStatusPending, task.id, runnerID, types.TaskStatusRunning,
		)
		if err != nil {
			log.Printf("Failed to reset task %d: %v", task.id, err)
			continue
		}
		if n, err := res.RowsAffected(); err == nil && n == 0 {
			// Nothing was reset, so there is nothing to report in the task log.
			continue
		}

		// Add log entry
		if _, err := s.db.Exec(
			`INSERT INTO task_logs (task_id, log_level, message, step_name, created_at)
			VALUES (?, ?, ?, ?, ?)`,
			task.id, types.LogLevelWarn,
			fmt.Sprintf("Runner died, task reset (retry %d/%d)", task.retryCount+1, task.maxRetries),
			"", time.Now(),
		); err != nil {
			log.Printf("Failed to add log entry for task %d: %v", task.id, err)
		}
	}
	return nil
}
// recoverTaskTimeouts re-queues or fails running tasks that have outlived
// their timeout.
//
// A task's timeout is its timeout_seconds column, or 300 seconds when that
// column is NULL. A timed-out task with retries left is reset to pending, its
// retry_count is incremented and a warning is added to task_logs; a task whose
// retries are exhausted is marked failed.
//
// Candidates are read in full and the cursor is closed before any write is
// issued, so no UPDATE or INSERT runs while the SELECT is still open.
// Failures are logged; the function returns nothing.
func (s *Server) recoverTaskTimeouts() {
	type runningTask struct {
		id         int64
		retryCount int
		maxRetries int
		timeout    sql.NullInt64
		startedAt  time.Time
	}

	// Find tasks running longer than their timeout. Rows with a NULL timeout
	// are all returned and filtered against the default below.
	rows, err := s.db.Query(
		`SELECT t.id, t.runner_id, t.retry_count, t.max_retries, t.timeout_seconds, t.started_at
		FROM tasks t
		WHERE t.status = ?
		AND t.started_at IS NOT NULL
		AND (t.timeout_seconds IS NULL OR
		datetime(t.started_at, '+' || t.timeout_seconds || ' seconds') < datetime('now'))`,
		types.TaskStatusRunning,
	)
	if err != nil {
		log.Printf("Failed to query timed out tasks: %v", err)
		return
	}

	var candidates []runningTask
	for rows.Next() {
		var t runningTask
		var runnerID sql.NullInt64 // selected by the query, not needed below
		if err := rows.Scan(&t.id, &runnerID, &t.retryCount, &t.maxRetries, &t.timeout, &t.startedAt); err != nil {
			log.Printf("Failed to scan timed out task: %v", err)
			continue
		}
		candidates = append(candidates, t)
	}
	// Next also returns false when iteration fails; Err tells the two apart.
	iterErr := rows.Err()
	rows.Close()
	if iterErr != nil {
		log.Printf("Failed to query timed out tasks: %v", iterErr)
		return
	}

	for _, task := range candidates {
		// Only a 5-minute default is applied when no timeout is stored.
		timeout := 300
		if task.timeout.Valid {
			timeout = int(task.timeout.Int64)
		}

		// Check if actually timed out
		if time.Since(task.startedAt).Seconds() < float64(timeout) {
			continue
		}

		// Each UPDATE re-checks the status so a task that left the running
		// state after the SELECT is not overwritten.
		if task.retryCount >= task.maxRetries {
			// Mark as failed
			if _, err := s.db.Exec(
				`UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL
				WHERE id = ? AND status = ?`,
				types.TaskStatusFailed, "Task timeout exceeded, max retries reached",
				task.id, types.TaskStatusRunning,
			); err != nil {
				log.Printf("Failed to mark task %d as failed: %v", task.id, err)
			}
			continue
		}

		// Reset to pending
		res, err := s.db.Exec(
			`UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL,
			retry_count = retry_count + 1 WHERE id = ? AND status = ?`,
			types.TaskStatusPending, task.id, types.TaskStatusRunning,
		)
		if err != nil {
			log.Printf("Failed to reset task %d: %v", task.id, err)
			continue
		}
		if n, err := res.RowsAffected(); err == nil && n == 0 {
			// Nothing was reset, so there is nothing to report in the task log.
			continue
		}

		// Add log entry
		if _, err := s.db.Exec(
			`INSERT INTO task_logs (task_id, log_level, message, step_name, created_at)
			VALUES (?, ?, ?, ?, ?)`,
			task.id, types.LogLevelWarn,
			fmt.Sprintf("Task timeout exceeded, resetting (retry %d/%d)", task.retryCount+1, task.maxRetries),
			"", time.Now(),
		); err != nil {
			log.Printf("Failed to add log entry for task %d: %v", task.id, err)
		}
	}
}