diff --git a/.gitignore b/.gitignore
index f21d121..102880c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -27,7 +27,11 @@ go.work
 jiggablend.db
 jiggablend.db.wal
 jiggablend.db-shm
+jiggablend.db-journal
+# Log files
+*.log
+logs/
 
 # Secrets and configuration
 runner-secrets.json
 runner-secrets-*.json
diff --git a/Makefile b/Makefile
index d6c9970..19bfe8b 100644
--- a/Makefile
+++ b/Makefile
@@ -36,20 +36,20 @@ run: cleanup build init-test
 	@echo "Starting manager and runner in parallel..."
 	@echo "Press Ctrl+C to stop both..."
 	@trap 'kill $$MANAGER_PID $$RUNNER_PID 2>/dev/null; exit' INT TERM; \
-	bin/jiggablend manager & \
+	bin/jiggablend manager -l manager.log & \
 	MANAGER_PID=$$!; \
 	sleep 2; \
-	bin/jiggablend runner --api-key=jk_r0_test_key_123456789012345678901234567890 & \
+	bin/jiggablend runner -l runner.log --api-key=jk_r0_test_key_123456789012345678901234567890 & \
 	RUNNER_PID=$$!; \
 	wait $$MANAGER_PID $$RUNNER_PID
 
 # Run manager server
 run-manager: cleanup-manager build init-test
-	bin/jiggablend manager
+	bin/jiggablend manager -l manager.log
 
 # Run runner
 run-runner: cleanup-runner build
-	bin/jiggablend runner --api-key=jk_r0_test_key_123456789012345678901234567890
+	bin/jiggablend runner -l runner.log --api-key=jk_r0_test_key_123456789012345678901234567890
 
 # Initialize for testing (first run setup)
 init-test: build
diff --git a/cmd/jiggablend/cmd/manager.go b/cmd/jiggablend/cmd/manager.go
index 96ac197..4f37aca 100644
--- a/cmd/jiggablend/cmd/manager.go
+++ b/cmd/jiggablend/cmd/manager.go
@@ -6,11 +6,11 @@ import (
 	"os/exec"
 	"strings"
 
-	"jiggablend/internal/api"
 	"jiggablend/internal/auth"
 	"jiggablend/internal/config"
 	"jiggablend/internal/database"
 	"jiggablend/internal/logger"
+	manager "jiggablend/internal/manager"
 	"jiggablend/internal/storage"
 
 	"github.com/spf13/cobra"
@@ -117,8 +117,16 @@ func runManager(cmd *cobra.Command, args []string) {
 	}
 	logger.Info("Blender is available")
 
-	// Create API server
-	server, err := api.NewServer(db, cfg, authHandler, storageHandler)
+	// Check if ImageMagick is available
+	if err := checkImageMagickAvailable(); err != nil {
+		logger.Fatalf("ImageMagick is not available: %v\n"+
+			"The manager requires ImageMagick to be installed and in PATH for EXR preview conversion.\n"+
+			"Please install ImageMagick and ensure 'magick' or 'convert' command is accessible.", err)
+	}
+	logger.Info("ImageMagick is available")
+
+	// Create manager server
+	server, err := manager.NewManager(db, cfg, authHandler, storageHandler)
 	if err != nil {
 		logger.Fatalf("Failed to create server: %v", err)
 	}
@@ -150,3 +158,20 @@ func checkBlenderAvailable() error {
 	}
 	return nil
 }
+
+func checkImageMagickAvailable() error {
+	// Try 'magick' first (ImageMagick 7+)
+	cmd := exec.Command("magick", "--version")
+	output, err := cmd.CombinedOutput()
+	if err == nil {
+		return nil
+	}
+
+	// Fall back to 'convert' (ImageMagick 6 or legacy mode)
+	cmd = exec.Command("convert", "--version")
+	output, err = cmd.CombinedOutput()
+	if err != nil {
+		return fmt.Errorf("failed to run 'magick --version' or 'convert --version': %w (output: %s)", err, string(output))
+	}
+	return nil
+}
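Reviewer note: the manager now hard-fails at startup without ImageMagick, which it needs for EXR preview conversion. The conversion code itself isn't in this diff, so the sketch below only illustrates the shape such a helper could take, reusing the same magick/convert fallback as `checkImageMagickAvailable`; `imagemagickBinary` and `convertEXRToPNG` are hypothetical names, not part of this PR.

```go
package preview

import (
	"fmt"
	"os/exec"
)

// imagemagickBinary returns "magick" (ImageMagick 7+) if present,
// otherwise falls back to the legacy "convert" entry point.
func imagemagickBinary() (string, error) {
	if path, err := exec.LookPath("magick"); err == nil {
		return path, nil
	}
	if path, err := exec.LookPath("convert"); err == nil {
		return path, nil
	}
	return "", fmt.Errorf("neither 'magick' nor 'convert' found in PATH")
}

// convertEXRToPNG renders an EXR frame to a PNG preview.
// ImageMagick picks the formats from the file extensions.
func convertEXRToPNG(exrPath, pngPath string) error {
	bin, err := imagemagickBinary()
	if err != nil {
		return err
	}
	out, err := exec.Command(bin, exrPath, pngPath).CombinedOutput()
	if err != nil {
		return fmt.Errorf("imagemagick conversion failed: %w (output: %s)", err, out)
	}
	return nil
}
```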
runnerCmd.Flags().BoolP("verbose", "v", false, "Enable verbose logging (same as --log-level=debug)") + runnerCmd.Flags().Duration("poll-interval", 5*time.Second, "Job polling interval") // Bind flags to viper with JIGGABLEND_ prefix runnerViper.SetEnvPrefix("JIGGABLEND") @@ -49,6 +50,7 @@ func init() { runnerViper.BindPFlag("log_file", runnerCmd.Flags().Lookup("log-file")) runnerViper.BindPFlag("log_level", runnerCmd.Flags().Lookup("log-level")) runnerViper.BindPFlag("verbose", runnerCmd.Flags().Lookup("verbose")) + runnerViper.BindPFlag("poll_interval", runnerCmd.Flags().Lookup("poll-interval")) } func runRunner(cmd *cobra.Command, args []string) { @@ -60,14 +62,15 @@ func runRunner(cmd *cobra.Command, args []string) { logFile := runnerViper.GetString("log_file") logLevel := runnerViper.GetString("log_level") verbose := runnerViper.GetBool("verbose") + pollInterval := runnerViper.GetDuration("poll_interval") - var client *runner.Client + var r *runner.Runner defer func() { - if r := recover(); r != nil { - logger.Errorf("Runner panicked: %v", r) - if client != nil { - client.CleanupWorkspace() + if rec := recover(); rec != nil { + logger.Errorf("Runner panicked: %v", rec) + if r != nil { + r.Cleanup() } os.Exit(1) } @@ -77,7 +80,7 @@ func runRunner(cmd *cobra.Command, args []string) { hostname, _ = os.Hostname() } - // Generate unique runner ID + // Generate unique runner ID suffix runnerIDStr := generateShortID() // Generate runner name with ID if not provided @@ -114,23 +117,24 @@ func runRunner(cmd *cobra.Command, args []string) { logger.Infof("Logging to file: %s", logFile) } - client = runner.NewClient(managerURL, name, hostname) + // Create runner + r = runner.New(managerURL, name, hostname) + + // Check for required tools early to fail fast + if err := r.CheckRequiredTools(); err != nil { + logger.Fatalf("Required tool check failed: %v", err) + } // Clean up orphaned workspace directories - client.CleanupWorkspace() + r.Cleanup() - // Probe capabilities + // Probe capabilities and log them logger.Debug("Probing runner capabilities...") - client.ProbeCapabilities() - capabilities := client.GetCapabilities() + capabilities := r.ProbeCapabilities() capList := []string{} for cap, value := range capabilities { if enabled, ok := value.(bool); ok && enabled { capList = append(capList, cap) - } else if count, ok := value.(int); ok && count > 0 { - capList = append(capList, fmt.Sprintf("%s=%d", cap, count)) - } else if count, ok := value.(float64); ok && count > 0 { - capList = append(capList, fmt.Sprintf("%s=%.0f", cap, count)) } } if len(capList) > 0 { @@ -154,7 +158,7 @@ func runRunner(cmd *cobra.Command, args []string) { for { var err error - runnerID, _, _, err = client.Register(apiKey) + runnerID, err = r.Register(apiKey) if err == nil { logger.Infof("Registered runner with ID: %d", runnerID) break @@ -178,14 +182,6 @@ func runRunner(cmd *cobra.Command, args []string) { } } - // Start WebSocket connection - go client.ConnectWebSocketWithReconnect() - - // Start heartbeat loop - go client.HeartbeatLoop() - - logger.Info("Runner started, connecting to manager via WebSocket...") - // Signal handlers sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) @@ -193,13 +189,14 @@ func runRunner(cmd *cobra.Command, args []string) { go func() { sig := <-sigChan logger.Infof("Received signal: %v, killing all processes and cleaning up...", sig) - client.KillAllProcesses() - client.CleanupWorkspace() + r.KillAllProcesses() + r.Cleanup() os.Exit(0) }() - // Block 
forever - select {} + // Start polling for jobs + logger.Infof("Runner started, polling for jobs (interval: %v)...", pollInterval) + r.Start(pollInterval) } func generateShortID() string { diff --git a/examples/frame_0800.exr b/examples/frame_0800.exr new file mode 100644 index 0000000..b345c63 Binary files /dev/null and b/examples/frame_0800.exr differ diff --git a/examples/frame_0800.png b/examples/frame_0800.png new file mode 100644 index 0000000..6b84436 Binary files /dev/null and b/examples/frame_0800.png differ diff --git a/go.mod b/go.mod index 1cd729e..fe7a8fd 100644 --- a/go.mod +++ b/go.mod @@ -4,10 +4,9 @@ go 1.25.4 require ( github.com/go-chi/chi/v5 v5.2.3 - github.com/go-chi/cors v1.2.2 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 - github.com/mattn/go-sqlite3 v1.14.22 + github.com/mattn/go-sqlite3 v1.14.32 github.com/spf13/cobra v1.10.1 github.com/spf13/viper v1.21.0 golang.org/x/crypto v0.45.0 @@ -15,10 +14,14 @@ require ( ) require ( - cloud.google.com/go/compute/metadata v0.3.0 // indirect + cloud.google.com/go/compute/metadata v0.5.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/go-chi/cors v1.2.2 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect + github.com/golang-migrate/migrate/v4 v4.19.0 // indirect + github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect diff --git a/go.sum b/go.sum index 4438838..55e91ab 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY= +cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -13,12 +15,19 @@ github.com/go-chi/cors v1.2.2 h1:Jmey33TE+b+rB7fT8MUy1u0I4L+NARQlK6LhzKPSyQE= github.com/go-chi/cors v1.2.2/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/golang-migrate/migrate/v4 v4.19.0 h1:RcjOnCGz3Or6HQYEJ/EEVLfWnmw9KnoigPSjzhCuaSE= +github.com/golang-migrate/migrate/v4 v4.19.0/go.mod h1:9dyEcu+hO+G9hPSw8AIg50yg622pXJsoHItQnDGZkI0= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod 
diff --git a/go.sum b/go.sum
index 4438838..55e91ab 100644
--- a/go.sum
+++ b/go.sum
@@ -1,5 +1,7 @@
 cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
 cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
+cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY=
+cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY=
 github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -13,12 +15,19 @@ github.com/go-chi/cors v1.2.2 h1:Jmey33TE+b+rB7fT8MUy1u0I4L+NARQlK6LhzKPSyQE=
 github.com/go-chi/cors v1.2.2/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58=
 github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
 github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
+github.com/golang-migrate/migrate/v4 v4.19.0 h1:RcjOnCGz3Or6HQYEJ/EEVLfWnmw9KnoigPSjzhCuaSE=
+github.com/golang-migrate/migrate/v4 v4.19.0/go.mod h1:9dyEcu+hO+G9hPSw8AIg50yg622pXJsoHItQnDGZkI0=
 github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
 github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
 github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
 github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
 github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
+github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
+github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
+github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
+github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
+github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
 github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
 github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
@@ -27,6 +36,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
 github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
 github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
 github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
+github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
+github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
 github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
 github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
 github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
diff --git a/internal/api/runners.go b/internal/api/runners.go
deleted file mode 100644
index 97c9cc2..0000000
--- a/internal/api/runners.go
+++ /dev/null
@@ -1,2702 +0,0 @@
-package api
-
-import (
-	"context"
-	"database/sql"
-	"encoding/json"
-	"errors"
-	"fmt"
-	"io"
-	"log"
-	"math/rand"
-	"net/http"
-	"net/url"
-	"path/filepath"
-	"sort"
-	"strconv"
-	"strings"
-	"sync"
-	"time"
-
-	"jiggablend/pkg/types"
-
-	"github.com/go-chi/chi/v5"
-	"github.com/gorilla/websocket"
-)
-
-type contextKey string
-
-const runnerIDContextKey contextKey = "runner_id"
-
-// runnerAuthMiddleware verifies runner requests using API key
-func (s *Server) runnerAuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
-	return func(w http.ResponseWriter, r *http.Request) {
-		// Get API key from header
-		apiKey := r.Header.Get("Authorization")
-		if apiKey == "" {
-			// Try alternative header
-			apiKey = r.Header.Get("X-API-Key")
-		}
-		if apiKey == "" {
-			s.respondError(w, http.StatusUnauthorized, "API key required")
-			return
-		}
-
-		// Remove "Bearer " prefix if present
-		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
-
-		// Validate API key and get its ID
-		apiKeyID, _, err := s.secrets.ValidateRunnerAPIKey(apiKey)
-		if err != nil {
-			log.Printf("API key validation failed: %v", err)
-			s.respondError(w, http.StatusUnauthorized, "invalid API key")
-			return
-		}
-
-		// Get runner ID from query string or find runner by API key
-		runnerIDStr := r.URL.Query().Get("runner_id")
-		var runnerID int64
-
-		if runnerIDStr != "" {
-			// Runner ID provided - verify it belongs to this API key
-			_, err := fmt.Sscanf(runnerIDStr, "%d", &runnerID)
-			if err != nil {
-				s.respondError(w, http.StatusBadRequest, "invalid runner_id")
-				return
-			}
-
-			// For fixed API keys, skip database verification
-			if apiKeyID != -1 {
-				// Verify runner exists and uses this API key
-				var dbAPIKeyID sql.NullInt64
-				err = s.db.With(func(conn *sql.DB) error {
-					return conn.QueryRow("SELECT api_key_id FROM runners WHERE id = ?", runnerID).Scan(&dbAPIKeyID)
-				})
-				if err == sql.ErrNoRows {
-					s.respondError(w, http.StatusNotFound, "runner not found")
-					return
-				}
-				if err != nil {
-					s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to query runner API key: %v", err))
-					return
-				}
-				if !dbAPIKeyID.Valid || dbAPIKeyID.Int64 != apiKeyID {
-					s.respondError(w, http.StatusForbidden, "runner does not belong to this API key")
-					return
-				}
-			}
-		} else {
-			// No runner ID provided - find the runner for this API key
-			// For simplicity, assume each API key has one runner
-			err = s.db.With(func(conn *sql.DB) error {
-				return conn.QueryRow("SELECT id FROM runners WHERE api_key_id = ?", apiKeyID).Scan(&runnerID)
-			})
-			if err == sql.ErrNoRows {
-				s.respondError(w, http.StatusNotFound, "no runner found for this API key")
-				return
-			}
-			if err != nil {
-				s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to query runner by API key: %v", err))
-				return
-			}
-		}
-
-		// Add runner ID to context
-		ctx := r.Context()
-		ctx = context.WithValue(ctx, runnerIDContextKey, runnerID)
-		next(w, r.WithContext(ctx))
-	}
-}
-// handleRegisterRunner registers a new runner using an API key
-func (s *Server) handleRegisterRunner(w http.ResponseWriter, r *http.Request) {
-	var req struct {
-		types.RegisterRunnerRequest
-		APIKey      string `json:"api_key"`
-		Fingerprint string `json:"fingerprint,omitempty"`
-	}
-	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
-		s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
-		return
-	}
-
-	// Lock to prevent concurrent registrations that could create duplicate runners
-	s.secrets.RegistrationMu.Lock()
-	defer s.secrets.RegistrationMu.Unlock()
-
-	// Validate runner name
-	if req.Name == "" {
-		s.respondError(w, http.StatusBadRequest, "Runner name is required")
-		return
-	}
-	if len(req.Name) > 255 {
-		s.respondError(w, http.StatusBadRequest, "Runner name must be 255 characters or less")
-		return
-	}
-
-	// Validate hostname
-	if req.Hostname != "" {
-		// Basic hostname validation (allow IP addresses and domain names)
-		if len(req.Hostname) > 253 {
-			s.respondError(w, http.StatusBadRequest, "Hostname must be 253 characters or less")
-			return
-		}
-	}
-
-	// Validate capabilities JSON if provided
-	if req.Capabilities != "" {
-		var testCapabilities map[string]interface{}
-		if err := json.Unmarshal([]byte(req.Capabilities), &testCapabilities); err != nil {
-			s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid capabilities JSON: %v", err))
-			return
-		}
-	}
-
-	if req.APIKey == "" {
-		s.respondError(w, http.StatusBadRequest, "API key is required")
-		return
-	}
-
-	// Validate API key
-	apiKeyID, apiKeyScope, err := s.secrets.ValidateRunnerAPIKey(req.APIKey)
-	if err != nil {
-		s.respondError(w, http.StatusUnauthorized, fmt.Sprintf("Invalid API key: %v", err))
-		return
-	}
-
-	// For fixed API keys (keyID = -1), skip fingerprint checking
-	// Set default priority if not provided
-	priority := 100
-	if req.Priority != nil {
-		priority = *req.Priority
-	}
-
-	// Register runner
-	var runnerID int64
-	// For fixed API keys, don't store api_key_id in database
-	var dbAPIKeyID interface{}
-	if apiKeyID == -1 {
-		dbAPIKeyID = nil // NULL for fixed API keys
-	} else {
-		dbAPIKeyID = apiKeyID
-	}
-
-	// Determine fingerprint value
-	fingerprint := req.Fingerprint
-	if apiKeyID == -1 || fingerprint == "" {
-		// For fixed API keys or when no fingerprint provided, generate a unique fingerprint
-		// to avoid conflicts while still maintaining some uniqueness
-		fingerprint = fmt.Sprintf("fixed-%s-%d", req.Name, time.Now().UnixNano())
-	}
-
-	// Check fingerprint uniqueness only for non-fixed API keys
-	if apiKeyID != -1 && req.Fingerprint != "" {
-		var existingRunnerID int64
-		var existingAPIKeyID sql.NullInt64
-		err = s.db.With(func(conn *sql.DB) error {
-			return conn.QueryRow(
-				"SELECT id, api_key_id FROM runners WHERE fingerprint = ?",
-				req.Fingerprint,
-			).Scan(&existingRunnerID, &existingAPIKeyID)
-		})
-
-		if err == nil {
-			// Runner already exists with this fingerprint
-			if existingAPIKeyID.Valid && existingAPIKeyID.Int64 == apiKeyID {
-				// Same API key - update and return existing runner
-				log.Printf("Runner with fingerprint %s already exists (ID: %d), updating info", req.Fingerprint, existingRunnerID)
-
-				err = s.db.With(func(conn *sql.DB) error {
-					_, err := conn.Exec(
-						`UPDATE runners SET name = ?, hostname = ?, capabilities = ?, status = ?, last_heartbeat = ? WHERE id = ?`,
-						req.Name, req.Hostname, req.Capabilities, types.RunnerStatusOnline, time.Now(), existingRunnerID,
-					)
-					return err
-				})
-				if err != nil {
-					log.Printf("Warning: Failed to update existing runner info: %v", err)
-				}
-
-				s.respondJSON(w, http.StatusOK, map[string]interface{}{
-					"id":       existingRunnerID,
-					"name":     req.Name,
-					"hostname": req.Hostname,
-					"status":   types.RunnerStatusOnline,
-					"reused":   true, // Indicates this was a re-registration
-				})
-				return
-			} else {
-				// Different API key - reject registration
-				s.respondError(w, http.StatusConflict, "Runner with this fingerprint already registered with different API key")
-				return
-			}
-		}
-		// If err is not nil, it means no existing runner with this fingerprint - proceed with new registration
-	}
-
-	// Insert runner
-	err = s.db.With(func(conn *sql.DB) error {
-		result, err := conn.Exec(
-			`INSERT INTO runners (name, hostname, ip_address, status, last_heartbeat, capabilities,
-				api_key_id, api_key_scope, priority, fingerprint)
-			VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
-			req.Name, req.Hostname, "", types.RunnerStatusOnline, time.Now(), req.Capabilities,
-			dbAPIKeyID, apiKeyScope, priority, fingerprint,
-		)
-		if err != nil {
-			return err
-		}
-		runnerID, err = result.LastInsertId()
-		return err
-	})
-	if err != nil {
-		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to register runner: %v", err))
-		return
-	}
-
-	log.Printf("Registered new runner %s (ID: %d) with API key ID: %d", req.Name, runnerID, apiKeyID)
-
-	// Return runner info
-	s.respondJSON(w, http.StatusCreated, map[string]interface{}{
-		"id":       runnerID,
-		"name":     req.Name,
-		"hostname": req.Hostname,
-		"status":   types.RunnerStatusOnline,
-	})
-}
-
-// handleRunnerPing allows runners to validate their secrets and connection
-func (s *Server) handleRunnerPing(w http.ResponseWriter, r *http.Request) {
-	// This endpoint uses runnerAuthMiddleware, so if we get here, secrets are valid
-	// Get runner ID from context (set by runnerAuthMiddleware)
-	runnerID, ok := r.Context().Value(runnerIDContextKey).(int64)
-	if !ok {
-		s.respondError(w, http.StatusUnauthorized, "runner_id not found in context")
-		return
-	}
-
-	// Update last heartbeat
-	err := s.db.With(func(conn *sql.DB) error {
-		_, err := conn.Exec(
-			`UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`,
-			time.Now(), types.RunnerStatusOnline, runnerID,
-		)
-		return err
-	})
-	if err != nil {
-		log.Printf("Warning: Failed to update runner heartbeat: %v", err)
-	}
-
-	s.respondJSON(w, http.StatusOK, map[string]interface{}{
-		"status":    "ok",
-		"runner_id": runnerID,
-		"timestamp": time.Now().Unix(),
-	})
-}
-
-// handleUpdateTaskProgress updates task progress
-func (s *Server) handleUpdateTaskProgress(w http.ResponseWriter, r *http.Request) {
-	_, err := parseID(r, "id")
-	if err != nil {
-		s.respondError(w, http.StatusBadRequest, err.Error())
-		return
-	}
-
-	var req struct {
-		Progress float64 `json:"progress"`
-	}
-	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
-		s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
-		return
-	}
-
-	// This is mainly for logging/debugging, actual progress is calculated from completed tasks
-	s.respondJSON(w, http.StatusOK, map[string]string{"message": "Progress updated"})
-}
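Reviewer note: re-registration above is keyed on a client-supplied fingerprint, but nothing in this diff shows how runners derive one. One purely illustrative scheme (an assumption, not the project's actual behavior) is hashing stable host properties so a restarted runner maps back to the same row instead of creating a duplicate:

```go
package runner

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"os"
	"runtime"
)

// machineFingerprint returns a stable identifier for this host.
// Hypothetical: hostname + OS + arch is just one possible input set.
func machineFingerprint() string {
	hostname, _ := os.Hostname()
	raw := fmt.Sprintf("%s|%s|%s", hostname, runtime.GOOS, runtime.GOARCH)
	sum := sha256.Sum256([]byte(raw))
	return hex.EncodeToString(sum[:])
}
```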
-// handleUpdateTaskStep handles step start/complete events from runners
-func (s *Server) handleUpdateTaskStep(w http.ResponseWriter, r *http.Request) {
-	// Get runner ID from context (set by runnerAuthMiddleware)
-	runnerID, ok := r.Context().Value(runnerIDContextKey).(int64)
-	if !ok {
-		s.respondError(w, http.StatusUnauthorized, "runner_id not found in context")
-		return
-	}
-
-	taskID, err := parseID(r, "id")
-	if err != nil {
-		s.respondError(w, http.StatusBadRequest, err.Error())
-		return
-	}
-
-	var req struct {
-		StepName     string `json:"step_name"`
-		Status       string `json:"status"` // "pending", "running", "completed", "failed", "skipped"
-		DurationMs   *int   `json:"duration_ms,omitempty"`
-		ErrorMessage string `json:"error_message,omitempty"`
-	}
-	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
-		s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
-		return
-	}
-
-	// Verify task belongs to runner
-	var taskRunnerID sql.NullInt64
-	err = s.db.With(func(conn *sql.DB) error {
-		return conn.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", taskID).Scan(&taskRunnerID)
-	})
-	if err == sql.ErrNoRows {
-		s.respondError(w, http.StatusNotFound, "Task not found")
-		return
-	}
-	if err != nil {
-		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify task: %v", err))
-		return
-	}
-	if !taskRunnerID.Valid || taskRunnerID.Int64 != runnerID {
-		s.respondError(w, http.StatusForbidden, "Task does not belong to this runner")
-		return
-	}
-
-	now := time.Now()
-	var stepID int64
-
-	// Check if step already exists
-	var existingStepID sql.NullInt64
-	err = s.db.With(func(conn *sql.DB) error {
-		return conn.QueryRow(
-			`SELECT id FROM task_steps WHERE task_id = ? AND step_name = ?`,
-			taskID, req.StepName,
-		).Scan(&existingStepID)
-	})
-
-	if err == sql.ErrNoRows || !existingStepID.Valid {
-		// Create new step
-		var startedAt *time.Time
-		var completedAt *time.Time
-		if req.Status == string(types.StepStatusRunning) || req.Status == string(types.StepStatusCompleted) || req.Status == string(types.StepStatusFailed) {
-			startedAt = &now
-		}
-		if req.Status == string(types.StepStatusCompleted) || req.Status == string(types.StepStatusFailed) {
-			completedAt = &now
-		}
-
-		err = s.db.With(func(conn *sql.DB) error {
-			result, err := conn.Exec(
-				`INSERT INTO task_steps (task_id, step_name, status, started_at, completed_at, duration_ms, error_message)
-				VALUES (?, ?, ?, ?, ?, ?, ?)`,
-				taskID, req.StepName, req.Status, startedAt, completedAt, req.DurationMs, req.ErrorMessage,
-			)
-			if err != nil {
-				return err
-			}
-			stepID, err = result.LastInsertId()
-			return err
-		})
-		if err != nil {
-			s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create step: %v", err))
-			return
-		}
-	} else {
-		// Update existing step
-		stepID = existingStepID.Int64
-		var startedAt *time.Time
-		var completedAt *time.Time
-
-		// Get existing started_at if status is running/completed/failed
-		if req.Status == string(types.StepStatusRunning) || req.Status == string(types.StepStatusCompleted) || req.Status == string(types.StepStatusFailed) {
-			var existingStartedAt sql.NullTime
-			s.db.With(func(conn *sql.DB) error {
-				return conn.QueryRow(`SELECT started_at FROM task_steps WHERE id = ?`, stepID).Scan(&existingStartedAt)
-			})
-			if existingStartedAt.Valid {
-				startedAt = &existingStartedAt.Time
-			} else {
-				startedAt = &now
-			}
-		}
-
-		if req.Status == string(types.StepStatusCompleted) || req.Status == string(types.StepStatusFailed) {
-			completedAt = &now
-		}
-
-		err = s.db.With(func(conn *sql.DB) error {
-			_, err := conn.Exec(
-				`UPDATE task_steps SET status = ?, started_at = ?, completed_at = ?, duration_ms = ?, error_message = ?
-				WHERE id = ?`,
-				req.Status, startedAt, completedAt, req.DurationMs, req.ErrorMessage, stepID,
-			)
-			return err
-		})
-		if err != nil {
-			s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to update step: %v", err))
-			return
-		}
-	}
-
-	// Get job ID for broadcasting
-	var jobID int64
-	err = s.db.With(func(conn *sql.DB) error {
-		return conn.QueryRow("SELECT job_id FROM tasks WHERE id = ?", taskID).Scan(&jobID)
-	})
-	if err == nil {
-		// Broadcast step update to frontend
-		s.broadcastTaskUpdate(jobID, taskID, "step_update", map[string]interface{}{
-			"step_id":       stepID,
-			"step_name":     req.StepName,
-			"status":        req.Status,
-			"duration_ms":   req.DurationMs,
-			"error_message": req.ErrorMessage,
-		})
-	}
-
-	s.respondJSON(w, http.StatusOK, map[string]interface{}{
-		"step_id": stepID,
-		"message": "Step updated successfully",
-	})
-}
-
-// handleDownloadJobContext allows runners to download the job context tar
-func (s *Server) handleDownloadJobContext(w http.ResponseWriter, r *http.Request) {
-	jobID, err := parseID(r, "jobId")
-	if err != nil {
-		s.respondError(w, http.StatusBadRequest, err.Error())
-		return
-	}
-
-	// Construct the context file path
-	contextPath := filepath.Join(s.storage.JobPath(jobID), "context.tar")
-
-	// Check if context file exists
-	if !s.storage.FileExists(contextPath) {
-		log.Printf("Context archive not found for job %d", jobID)
-		s.respondError(w, http.StatusNotFound, "Context archive not found. The file may not have been uploaded successfully.")
-		return
-	}
-
-	// Open and serve file
-	file, err := s.storage.GetFile(contextPath)
-	if err != nil {
-		s.respondError(w, http.StatusNotFound, "Context file not found on disk")
-		return
-	}
-	defer file.Close()
-
-	// Set appropriate headers for tar file
-	w.Header().Set("Content-Type", "application/x-tar")
-	w.Header().Set("Content-Disposition", "attachment; filename=context.tar")
-
-	// Stream the file to the response
-	io.Copy(w, file)
-}
-
-// handleUploadFileFromRunner allows runners to upload output files
-func (s *Server) handleUploadFileFromRunner(w http.ResponseWriter, r *http.Request) {
-	jobID, err := parseID(r, "jobId")
-	if err != nil {
-		s.respondError(w, http.StatusBadRequest, err.Error())
-		return
-	}
-
-	err = r.ParseMultipartForm(50 << 30) // 50 GB (for large output files)
-	if err != nil {
-		s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Failed to parse multipart form: %v", err))
-		return
-	}
-
-	file, header, err := r.FormFile("file")
-	if err != nil {
-		s.respondError(w, http.StatusBadRequest, "No file provided")
-		return
-	}
-	defer file.Close()
-
-	// Save file
-	filePath, err := s.storage.SaveOutput(jobID, header.Filename, file)
-	if err != nil {
-		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to save file: %v", err))
-		return
-	}
-
-	// Record in database
-	var fileID int64
-	err = s.db.With(func(conn *sql.DB) error {
-		result, err := conn.Exec(
-			`INSERT INTO job_files (job_id, file_type, file_path, file_name, file_size)
-			VALUES (?, ?, ?, ?, ?)`,
-			jobID, types.JobFileTypeOutput, filePath, header.Filename, header.Size,
-		)
-		if err != nil {
-			return err
-		}
-		fileID, err = result.LastInsertId()
-		return err
-	})
-	if err != nil {
-		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to record file: %v", err))
-		return
-	}
-
-	// Broadcast file addition
-	s.broadcastJobUpdate(jobID, "file_added", map[string]interface{}{
-		"file_id":   fileID,
-		"file_type": types.JobFileTypeOutput,
-		"file_name": header.Filename,
-		"file_size": header.Size,
-	})
-
-	s.respondJSON(w, http.StatusCreated, map[string]interface{}{
-		"file_path": filePath,
-		"file_name": header.Filename,
-	})
-}
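Reviewer note: the deleted step handler implements its upsert as SELECT-then-INSERT-or-UPDATE across separate statements. With SQLite (3.24+) this collapses into one statement, provided task_steps has a `UNIQUE(task_id, step_name)` constraint — which this diff does not show, so treat the sketch as an alternative, not the project's schema:

```go
package api

import "database/sql"

// upsertTaskStep is a single-statement version of the two-query upsert above.
// Assumes UNIQUE(task_id, step_name) on task_steps (not confirmed by this diff).
func upsertTaskStep(db *sql.DB, taskID int64, stepName, status string) error {
	_, err := db.Exec(
		`INSERT INTO task_steps (task_id, step_name, status)
		 VALUES (?, ?, ?)
		 ON CONFLICT(task_id, step_name) DO UPDATE SET status = excluded.status`,
		taskID, stepName, status,
	)
	return err
}
```

Besides being shorter, the single statement is atomic, which removes the race where two reports for the same step arrive between the SELECT and the INSERT.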
-// handleGetJobStatusForRunner allows runners to check job status
-func (s *Server) handleGetJobStatusForRunner(w http.ResponseWriter, r *http.Request) {
-	jobID, err := parseID(r, "jobId")
-	if err != nil {
-		s.respondError(w, http.StatusBadRequest, err.Error())
-		return
-	}
-
-	var job types.Job
-	var startedAt, completedAt sql.NullTime
-	var errorMessage sql.NullString
-
-	var jobType string
-	var frameStart, frameEnd sql.NullInt64
-	var outputFormat sql.NullString
-	var allowParallelRunners sql.NullBool
-	err = s.db.With(func(conn *sql.DB) error {
-		return conn.QueryRow(
-			`SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format,
-				allow_parallel_runners, created_at, started_at, completed_at, error_message
-			FROM jobs WHERE id = ?`,
-			jobID,
-		).Scan(
-			&job.ID, &job.UserID, &jobType, &job.Name, &job.Status, &job.Progress,
-			&frameStart, &frameEnd, &outputFormat, &allowParallelRunners,
-			&job.CreatedAt, &startedAt, &completedAt, &errorMessage,
-		)
-	})
-	if err != nil {
-		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query job: %v", err))
-		return
-	}
-
-	job.JobType = types.JobType(jobType)
-	if frameStart.Valid {
-		fs := int(frameStart.Int64)
-		job.FrameStart = &fs
-	}
-	if frameEnd.Valid {
-		fe := int(frameEnd.Int64)
-		job.FrameEnd = &fe
-	}
-	if outputFormat.Valid {
-		job.OutputFormat = &outputFormat.String
-	}
-	if allowParallelRunners.Valid {
-		job.AllowParallelRunners = &allowParallelRunners.Bool
-	}
-
-	if err == sql.ErrNoRows {
-		s.respondError(w, http.StatusNotFound, "Job not found")
-		return
-	}
-	if err != nil {
-		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query job: %v", err))
-		return
-	}
-
-	if startedAt.Valid {
-		job.StartedAt = &startedAt.Time
-	}
-	if completedAt.Valid {
-		job.CompletedAt = &completedAt.Time
-	}
-	if errorMessage.Valid {
-		job.ErrorMessage = errorMessage.String
-	}
-
-	s.respondJSON(w, http.StatusOK, job)
-}
-
-// handleGetJobFilesForRunner allows runners to get job files
-func (s *Server) handleGetJobFilesForRunner(w http.ResponseWriter, r *http.Request) {
-	jobID, err := parseID(r, "jobId")
-	if err != nil {
-		s.respondError(w, http.StatusBadRequest, err.Error())
-		return
-	}
-
-	var rows *sql.Rows
-	err = s.db.With(func(conn *sql.DB) error {
-		var err error
-		rows, err = conn.Query(
-			`SELECT id, job_id, file_type, file_path, file_name, file_size, created_at
-			FROM job_files WHERE job_id = ? ORDER BY file_name`,
-			jobID,
-		)
-		return err
-	})
-	if err != nil {
-		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query files: %v", err))
-		return
-	}
-	defer rows.Close()
-
-	files := []types.JobFile{}
-	for rows.Next() {
-		var file types.JobFile
-		err := rows.Scan(
-			&file.ID, &file.JobID, &file.FileType, &file.FilePath,
-			&file.FileName, &file.FileSize, &file.CreatedAt,
-		)
-		if err != nil {
-			s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan file: %v", err))
-			return
-		}
-		files = append(files, file)
-	}
-
-	s.respondJSON(w, http.StatusOK, files)
-}
-
-// handleGetJobMetadataForRunner allows runners to get job metadata
-func (s *Server) handleGetJobMetadataForRunner(w http.ResponseWriter, r *http.Request) {
-	jobID, err := parseID(r, "jobId")
-	if err != nil {
-		s.respondError(w, http.StatusBadRequest, err.Error())
-		return
-	}
-
-	var blendMetadataJSON sql.NullString
-	err = s.db.With(func(conn *sql.DB) error {
-		return conn.QueryRow(
-			`SELECT blend_metadata FROM jobs WHERE id = ?`,
-			jobID,
-		).Scan(&blendMetadataJSON)
-	})
-
-	if err == sql.ErrNoRows {
-		s.respondError(w, http.StatusNotFound, "Job not found")
-		return
-	}
-	if err != nil {
-		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query job: %v", err))
-		return
-	}
-
-	if !blendMetadataJSON.Valid || blendMetadataJSON.String == "" {
-		s.respondJSON(w, http.StatusOK, nil)
-		return
-	}
-
-	var metadata types.BlendMetadata
-	if err := json.Unmarshal([]byte(blendMetadataJSON.String), &metadata); err != nil {
-		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to parse metadata JSON: %v", err))
-		return
-	}
-
-	s.respondJSON(w, http.StatusOK, metadata)
-}
-
-// handleDownloadFileForRunner allows runners to download a file by fileName
-func (s *Server) handleDownloadFileForRunner(w http.ResponseWriter, r *http.Request) {
-	jobID, err := parseID(r, "jobId")
-	if err != nil {
-		s.respondError(w, http.StatusBadRequest, err.Error())
-		return
-	}
-
-	// Get fileName from URL path (may need URL decoding)
-	fileName := chi.URLParam(r, "fileName")
-	if fileName == "" {
-		s.respondError(w, http.StatusBadRequest, "fileName is required")
-		return
-	}
-
-	// URL decode the fileName in case it contains encoded characters
-	decodedFileName, err := url.QueryUnescape(fileName)
-	if err != nil {
-		// If decoding fails, use original fileName
-		decodedFileName = fileName
-	}
-
-	// Get file info from database
-	var filePath string
-	err = s.db.With(func(conn *sql.DB) error {
-		return conn.QueryRow(
-			`SELECT file_path FROM job_files WHERE job_id = ? AND file_name = ?`,
-			jobID, decodedFileName,
-		).Scan(&filePath)
-	})
-	if err == sql.ErrNoRows {
-		s.respondError(w, http.StatusNotFound, "File not found")
-		return
-	}
-	if err != nil {
-		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query file: %v", err))
-		return
-	}
-
-	// Open file
-	file, err := s.storage.GetFile(filePath)
-	if err != nil {
-		s.respondError(w, http.StatusNotFound, "File not found on disk")
-		return
-	}
-	defer file.Close()
-
-	// Determine content type based on file extension
-	contentType := "application/octet-stream"
-	fileNameLower := strings.ToLower(decodedFileName)
-	switch {
-	case strings.HasSuffix(fileNameLower, ".png"):
-		contentType = "image/png"
-	case strings.HasSuffix(fileNameLower, ".jpg") || strings.HasSuffix(fileNameLower, ".jpeg"):
-		contentType = "image/jpeg"
-	case strings.HasSuffix(fileNameLower, ".gif"):
-		contentType = "image/gif"
-	case strings.HasSuffix(fileNameLower, ".webp"):
-		contentType = "image/webp"
-	case strings.HasSuffix(fileNameLower, ".exr") || strings.HasSuffix(fileNameLower, ".EXR"):
-		contentType = "image/x-exr"
-	case strings.HasSuffix(fileNameLower, ".mp4"):
-		contentType = "video/mp4"
-	case strings.HasSuffix(fileNameLower, ".webm"):
-		contentType = "video/webm"
-	}
-
-	// Set headers
-	w.Header().Set("Content-Type", contentType)
-	w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", decodedFileName))
-
-	// Stream file
-	io.Copy(w, file)
-}
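Reviewer note: the manual suffix switch above (note the dead `.EXR` branch — `fileNameLower` is already lowercased) can mostly be replaced by the stdlib. A sketch, with EXR registered explicitly since it is not in the default MIME table:

```go
package api

import (
	"mime"
	"path/filepath"
)

func init() {
	// EXR is not in the default table, so register it explicitly.
	mime.AddExtensionType(".exr", "image/x-exr")
}

// contentTypeFor falls back to application/octet-stream exactly like the
// deleted handler did for unknown extensions.
func contentTypeFor(fileName string) string {
	if ct := mime.TypeByExtension(filepath.Ext(fileName)); ct != "" {
		return ct
	}
	return "application/octet-stream"
}
```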
-// WebSocket message types
-type WSMessage struct {
-	Type      string          `json:"type"`
-	Data      json.RawMessage `json:"data"`
-	Timestamp int64           `json:"timestamp"`
-}
-
-type WSTaskAssignment struct {
-	TaskID       int64    `json:"task_id"`
-	JobID        int64    `json:"job_id"`
-	JobName      string   `json:"job_name"`
-	OutputFormat string   `json:"output_format"`
-	FrameStart   int      `json:"frame_start"`
-	FrameEnd     int      `json:"frame_end"`
-	TaskType     string   `json:"task_type"`
-	InputFiles   []string `json:"input_files"`
-}
-
-type WSLogEntry struct {
-	TaskID   int64  `json:"task_id"`
-	LogLevel string `json:"log_level"`
-	Message  string `json:"message"`
-	StepName string `json:"step_name,omitempty"`
-}
-
-type WSTaskUpdate struct {
-	TaskID     int64  `json:"task_id"`
-	Status     string `json:"status"`
-	OutputPath string `json:"output_path,omitempty"`
-	Success    bool   `json:"success"`
-	Error      string `json:"error,omitempty"`
-}
-
-// handleRunnerWebSocket handles WebSocket connections from runners
-func (s *Server) handleRunnerWebSocket(w http.ResponseWriter, r *http.Request) {
-	// Get API key from query params or headers
-	apiKey := r.URL.Query().Get("api_key")
-	if apiKey == "" {
-		apiKey = r.Header.Get("Authorization")
-		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
-	}
-	if apiKey == "" {
-		s.respondError(w, http.StatusBadRequest, "API key required")
-		return
-	}
-
-	// Validate API key
-	apiKeyID, _, err := s.secrets.ValidateRunnerAPIKey(apiKey)
-	if err != nil {
-		s.respondError(w, http.StatusUnauthorized, fmt.Sprintf("Invalid API key: %v", err))
-		return
-	}
-
-	// Get runner ID from query params or find by API key
-	runnerIDStr := r.URL.Query().Get("runner_id")
-	var runnerID int64
-
-	if runnerIDStr != "" {
-		// Runner ID provided - verify it belongs to this API key
-		_, err := fmt.Sscanf(runnerIDStr, "%d", &runnerID)
-		if err != nil {
-			s.respondError(w, http.StatusBadRequest, "invalid runner_id")
-			return
-		}
-
-		// For fixed API keys, skip database verification
-		if apiKeyID != -1 {
-			var dbAPIKeyID sql.NullInt64
-			err = s.db.With(func(conn *sql.DB) error {
-				return conn.QueryRow("SELECT api_key_id FROM runners WHERE id = ?", runnerID).Scan(&dbAPIKeyID)
-			})
-			if err == sql.ErrNoRows {
-				s.respondError(w, http.StatusNotFound, "runner not found")
-				return
-			}
-			if err != nil {
-				s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to query runner API key: %v", err))
-				return
-			}
-			if !dbAPIKeyID.Valid || dbAPIKeyID.Int64 != apiKeyID {
-				s.respondError(w, http.StatusForbidden, "runner does not belong to this API key")
-				return
-			}
-		}
-	} else {
-		// No runner ID provided - find the runner for this API key
-		err = s.db.With(func(conn *sql.DB) error {
-			return conn.QueryRow("SELECT id FROM runners WHERE api_key_id = ?", apiKeyID).Scan(&runnerID)
-		})
-		if err == sql.ErrNoRows {
-			s.respondError(w, http.StatusNotFound, "no runner found for this API key")
-			return
-		}
-		if err != nil {
-			s.respondError(w, http.StatusInternalServerError, "database error")
-			return
-		}
-	}
-
-	// Upgrade to WebSocket
-	conn, err := s.wsUpgrader.Upgrade(w, r, nil)
-	if err != nil {
-		log.Printf("Failed to upgrade WebSocket: %v", err)
-		return
-	}
-	defer conn.Close()
-
-	// Register connection (must be done before any distribution checks)
-	// Fix race condition: Close old connection and create write mutex BEFORE registering new connection
-	var oldConn *websocket.Conn
-	s.runnerConnsMu.Lock()
-	if existingConn, exists := s.runnerConns[runnerID]; exists {
-		oldConn = existingConn
-	}
-	s.runnerConnsMu.Unlock()
-
-	// Close old connection BEFORE registering new one to prevent race conditions
-	if oldConn != nil {
-		log.Printf("Runner %d: closing existing WebSocket connection (reconnection)", runnerID)
-		oldConn.Close()
-	}
-
-	// Create write mutex BEFORE registering connection to prevent race condition
-	s.runnerConnsWriteMuMu.Lock()
-	s.runnerConnsWriteMu[runnerID] = &sync.Mutex{}
-	s.runnerConnsWriteMuMu.Unlock()
-
-	// Now register the new connection
-	s.runnerConnsMu.Lock()
-	s.runnerConns[runnerID] = conn
-	s.runnerConnsMu.Unlock()
-
-	log.Printf("Runner %d: WebSocket connection established successfully", runnerID)
-
-	// Check if runner was offline and had tasks assigned - redistribute them
-	// This handles the case where the manager restarted and marked the runner offline
-	// but tasks were still assigned to it
-	s.db.With(func(conn *sql.DB) error {
-		var count int
-		err := conn.QueryRow(
-			`SELECT COUNT(*) FROM tasks WHERE runner_id = ? AND status = ?`,
-			runnerID, types.TaskStatusRunning,
-		).Scan(&count)
-		if err == nil && count > 0 {
-			log.Printf("Runner %d reconnected with %d running tasks assigned - redistributing them", runnerID, count)
-			s.redistributeRunnerTasks(runnerID)
-		}
-		return nil
-	})
-
-	// Update runner status to online
-	s.db.With(func(conn *sql.DB) error {
-		_, _ = conn.Exec(
-			`UPDATE runners SET status = ?, last_heartbeat = ? WHERE id = ?`,
-			types.RunnerStatusOnline, time.Now(), runnerID,
-		)
-		return nil
-	})
-
-	// Immediately try to distribute pending tasks to this newly connected runner
-	log.Printf("Runner %d connected, distributing pending tasks", runnerID)
-	s.triggerTaskDistribution()
-
-	// Note: We don't log to task logs here because we don't know which tasks will be assigned yet
-	// Task assignment logging happens in distributeTasksToRunners
-
-	// Cleanup on disconnect
-	// Fix race condition: Only cleanup if this is still the current connection (not replaced by reconnection)
-	defer func() {
-		log.Printf("Runner %d: WebSocket connection cleanup started", runnerID)
-
-		// Check if this is still the current connection before cleanup
-		s.runnerConnsMu.Lock()
-		currentConn, stillCurrent := s.runnerConns[runnerID]
-		if !stillCurrent || currentConn != conn {
-			// Connection was replaced by a newer one, don't cleanup
-			s.runnerConnsMu.Unlock()
-			log.Printf("Runner %d: Skipping cleanup - connection was replaced by newer connection", runnerID)
-			return
-		}
-		// Remove connection from map
-		delete(s.runnerConns, runnerID)
-		s.runnerConnsMu.Unlock()
-
-		// Update database status
-		err := s.db.With(func(conn *sql.DB) error {
-			_, err := conn.Exec(
-				`UPDATE runners SET status = ?, last_heartbeat = ? WHERE id = ?`,
-				types.RunnerStatusOffline, time.Now(), runnerID,
-			)
-			return err
-		})
-		if err != nil {
-			log.Printf("Warning: Failed to update runner %d status to offline: %v", runnerID, err)
-		}
-
-		// Clean up write mutex
-		s.runnerConnsWriteMuMu.Lock()
-		delete(s.runnerConnsWriteMu, runnerID)
-		s.runnerConnsWriteMuMu.Unlock()
-
-		// Immediately redistribute tasks that were assigned to this runner
-		log.Printf("Runner %d: WebSocket disconnected, redistributing tasks", runnerID)
-		s.redistributeRunnerTasks(runnerID)
-
-		log.Printf("Runner %d: WebSocket connection cleanup completed", runnerID)
-	}()
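Reviewer note: the per-runner write mutex above exists because gorilla/websocket allows at most one concurrent writer per connection, so every `WriteControl`/`WriteJSON` must be serialized. Keeping the mutex in a separate map (`runnerConnsWriteMu`) is what forced the careful ordering comments; bundling conn and lock in one type enforces the same invariant structurally. A sketch (`safeConn` is a hypothetical name):

```go
package api

import (
	"sync"
	"time"

	"github.com/gorilla/websocket"
)

// safeConn bundles a gorilla/websocket connection with its write lock,
// so the single-writer rule cannot be bypassed by forgetting a map lookup.
type safeConn struct {
	conn *websocket.Conn
	mu   sync.Mutex
}

func (c *safeConn) WriteJSON(v interface{}) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.conn.WriteJSON(v)
}

func (c *safeConn) Ping(deadline time.Time) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.conn.WriteControl(websocket.PingMessage, nil, deadline)
}
```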
-	// Set pong handler to update heartbeat when we receive pong responses from runner
-	// Also reset read deadline to keep connection alive
-	conn.SetPongHandler(func(string) error {
-		conn.SetReadDeadline(time.Now().Add(90 * time.Second)) // Increased to 90 seconds
-		s.db.With(func(conn *sql.DB) error {
-			_, _ = conn.Exec(
-				`UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`,
-				time.Now(), types.RunnerStatusOnline, runnerID,
-			)
-			return nil
-		})
-		return nil
-	})
-
-	// Set read deadline to ensure we process control frames (like pong)
-	conn.SetReadDeadline(time.Now().Add(90 * time.Second)) // Increased to 90 seconds
-
-	// Send ping every 30 seconds to trigger pong responses
-	go func() {
-		ticker := time.NewTicker(30 * time.Second)
-		defer ticker.Stop()
-		for range ticker.C {
-			s.runnerConnsMu.RLock()
-			currentConn, exists := s.runnerConns[runnerID]
-			s.runnerConnsMu.RUnlock()
-			if !exists || currentConn != conn {
-				// Connection was replaced or removed
-				return
-			}
-			// Get write mutex for this connection
-			s.runnerConnsWriteMuMu.RLock()
-			writeMu, hasMu := s.runnerConnsWriteMu[runnerID]
-			s.runnerConnsWriteMuMu.RUnlock()
-			if !hasMu || writeMu == nil {
-				return
-			}
-			// Send ping - runner should respond with pong automatically
-			// Reset read deadline before sending ping to ensure we can receive pong
-			conn.SetReadDeadline(time.Now().Add(90 * time.Second)) // Increased to 90 seconds
-			writeMu.Lock()
-			err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second))
-			writeMu.Unlock()
-			if err != nil {
-				// Write failed - connection is likely dead, read loop will detect and cleanup
-				log.Printf("Failed to send ping to runner %d: %v", runnerID, err)
-				return
-			}
-		}
-	}()
-
-	// Handle incoming messages
-	for {
-		// Reset read deadline for each message - this is critical to keep connection alive
-		conn.SetReadDeadline(time.Now().Add(90 * time.Second)) // Increased to 90 seconds for safety
-
-		var msg WSMessage
-		err := conn.ReadJSON(&msg)
-		if err != nil {
-			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
-				log.Printf("WebSocket error for runner %d: %v", runnerID, err)
-			}
-			break
-		}
-
-		// Reset read deadline after successfully reading a message
-		// This ensures the connection stays alive as long as we're receiving messages
-		conn.SetReadDeadline(time.Now().Add(90 * time.Second))
-
-		switch msg.Type {
-		case "heartbeat":
-			// Heartbeat messages are handled by pong handler (manager-side)
-			// Reset read deadline to keep connection alive
-			conn.SetReadDeadline(time.Now().Add(90 * time.Second))
-			// Note: Heartbeat updates are consolidated to pong handler to avoid race conditions
-			// The pong handler is the single source of truth for heartbeat updates
-
-		case "log_entry":
-			var logEntry WSLogEntry
-			if err := json.Unmarshal(msg.Data, &logEntry); err == nil {
-				s.handleWebSocketLog(runnerID, logEntry)
-			}
-
-		case "task_update":
-			var taskUpdate WSTaskUpdate
-			if err := json.Unmarshal(msg.Data, &taskUpdate); err == nil {
-				s.handleWebSocketTaskUpdate(runnerID, taskUpdate)
-			}
-
-		case "task_complete":
-			var taskUpdate WSTaskUpdate
-			if err := json.Unmarshal(msg.Data, &taskUpdate); err == nil {
-				s.handleWebSocketTaskComplete(runnerID, taskUpdate)
-			}
-		}
-	}
-}
-
-// handleWebSocketLog handles log entries from WebSocket
-func (s *Server) handleWebSocketLog(runnerID int64, logEntry WSLogEntry) {
-	// Store log in database
-	err := s.db.With(func(conn *sql.DB) error {
-		_, err := conn.Exec(
-			`INSERT INTO task_logs (task_id, runner_id, log_level, message, step_name, created_at)
-			VALUES (?, ?, ?, ?, ?, ?)`,
-			logEntry.TaskID, runnerID, logEntry.LogLevel, logEntry.Message, logEntry.StepName, time.Now(),
-		)
-		return err
-	})
-	if err != nil {
-		log.Printf("Failed to store log: %v", err)
-		return
-	}
-
-	// Broadcast to frontend clients
-	s.broadcastLogToFrontend(logEntry.TaskID, logEntry)
-
-	// If this log contains a frame number (Fra:), update progress for single-runner render jobs
-	if strings.Contains(logEntry.Message, "Fra:") {
-		// Get job ID from task
-		var jobID int64
-		err := s.db.With(func(conn *sql.DB) error {
-			return conn.QueryRow("SELECT job_id FROM tasks WHERE id = ?", logEntry.TaskID).Scan(&jobID)
-		})
-		if err == nil {
-			// Throttle progress updates (max once per 2 seconds per job)
-			s.progressUpdateTimesMu.RLock()
-			lastUpdate, exists := s.progressUpdateTimes[jobID]
-			s.progressUpdateTimesMu.RUnlock()
-
-			shouldUpdate := !exists || time.Since(lastUpdate) >= 2*time.Second
-			if shouldUpdate {
-				s.progressUpdateTimesMu.Lock()
-				s.progressUpdateTimes[jobID] = time.Now()
-				s.progressUpdateTimesMu.Unlock()
-
-				// Update progress in background to avoid blocking log processing
-				go s.updateJobStatusFromTasks(jobID)
-			}
-		}
-	}
-}
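Reviewer note: the 2-second throttle above takes an RLock to check and a separate Lock to update, so two goroutines can both pass the check and trigger duplicate updates. The same idea under a single lock is race-free; `jobThrottle` below is a sketch, not code from this repository:

```go
package api

import (
	"sync"
	"time"
)

// jobThrottle generalizes the progressUpdateTimes map: Allow returns true
// at most once per interval for a given job ID, check-and-set atomically.
type jobThrottle struct {
	mu   sync.Mutex
	last map[int64]time.Time
	min  time.Duration
}

func newJobThrottle(min time.Duration) *jobThrottle {
	return &jobThrottle{last: make(map[int64]time.Time), min: min}
}

func (t *jobThrottle) Allow(jobID int64) bool {
	t.mu.Lock()
	defer t.mu.Unlock()
	if now := time.Now(); now.Sub(t.last[jobID]) >= t.min {
		t.last[jobID] = now
		return true
	}
	return false
}
```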
-// handleWebSocketTaskUpdate handles task status updates from WebSocket
-func (s *Server) handleWebSocketTaskUpdate(runnerID int64, taskUpdate WSTaskUpdate) {
-	// This can be used for progress updates
-	// For now, we'll just log it
-	log.Printf("Task %d update from runner %d: %s", taskUpdate.TaskID, runnerID, taskUpdate.Status)
-}
-
-// handleWebSocketTaskComplete handles task completion from WebSocket
-func (s *Server) handleWebSocketTaskComplete(runnerID int64, taskUpdate WSTaskUpdate) {
-	// Verify task belongs to runner
-	var taskRunnerID sql.NullInt64
-	err := s.db.With(func(conn *sql.DB) error {
-		return conn.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", taskUpdate.TaskID).Scan(&taskRunnerID)
-	})
-	if err != nil || !taskRunnerID.Valid || taskRunnerID.Int64 != runnerID {
-		log.Printf("Task %d does not belong to runner %d", taskUpdate.TaskID, runnerID)
-		return
-	}
-
-	status := types.TaskStatusCompleted
-	if !taskUpdate.Success {
-		status = types.TaskStatusFailed
-	}
-
-	// Get job ID first for atomic update
-	var jobID int64
-	err = s.db.With(func(conn *sql.DB) error {
-		return conn.QueryRow(
-			`SELECT job_id FROM tasks WHERE id = ?`,
-			taskUpdate.TaskID,
-		).Scan(&jobID)
-	})
-	if err != nil {
-		log.Printf("Failed to get job ID for task %d: %v", taskUpdate.TaskID, err)
-		return
-	}
-
-	// Use transaction to update task and job status atomically
-	now := time.Now()
-	err = s.db.WithTx(func(tx *sql.Tx) error {
-		// Update columns individually
-		_, err := tx.Exec(`UPDATE tasks SET status = ? WHERE id = ?`, status, taskUpdate.TaskID)
-		if err != nil {
-			log.Printf("Failed to update task status: %v", err)
-			return err
-		}
-
-		if taskUpdate.OutputPath != "" {
-			_, err = tx.Exec(`UPDATE tasks SET output_path = ? WHERE id = ?`, taskUpdate.OutputPath, taskUpdate.TaskID)
-			if err != nil {
-				log.Printf("Failed to update task output_path: %v", err)
-				return err
-			}
-		}
-
-		_, err = tx.Exec(`UPDATE tasks SET completed_at = ? WHERE id = ?`, now, taskUpdate.TaskID)
-		if err != nil {
-			log.Printf("Failed to update task completed_at: %v", err)
-			return err
-		}
-
-		if taskUpdate.Error != "" {
-			_, err = tx.Exec(`UPDATE tasks SET error_message = ? WHERE id = ?`, taskUpdate.Error, taskUpdate.TaskID)
-			if err != nil {
-				log.Printf("Failed to update task error_message: %v", err)
-				return err
-			}
-		}
-
-		return nil // Commit on nil return
-	})
-	if err != nil {
-		log.Printf("Failed to update task %d: %v", taskUpdate.TaskID, err)
-		return
-	}
-
-	// Broadcast task update
-	s.broadcastTaskUpdate(jobID, taskUpdate.TaskID, "task_update", map[string]interface{}{
-		"status":       status,
-		"output_path":  taskUpdate.OutputPath,
-		"completed_at": now,
-		"error":        taskUpdate.Error,
-	})
-	// Update job status and progress (this will query tasks and update job accordingly)
-	s.updateJobStatusFromTasks(jobID)
-}
-
-// parseBlenderFrame extracts the current frame number from Blender log messages
-// Looks for patterns like "Fra:2470" in log messages
-func parseBlenderFrame(logMessage string) (int, bool) {
-	// Look for "Fra:" followed by digits
-	// Pattern: "Fra:2470" or "Fra: 2470" or similar variations
-	fraIndex := strings.Index(logMessage, "Fra:")
-	if fraIndex == -1 {
-		return 0, false
-	}
-
-	// Find the number after "Fra:"
-	start := fraIndex + 4 // Skip "Fra:"
-	// Skip whitespace
-	for start < len(logMessage) && (logMessage[start] == ' ' || logMessage[start] == '\t') {
-		start++
-	}
-
-	// Extract digits
-	end := start
-	for end < len(logMessage) && logMessage[end] >= '0' && logMessage[end] <= '9' {
-		end++
-	}
-
-	if end > start {
-		frame, err := strconv.Atoi(logMessage[start:end])
-		if err == nil {
-			return frame, true
-		}
-	}
-
-	return 0, false
-}
-
-// getCurrentFrameFromLogs gets the highest frame number found in logs for a job's render tasks
-func (s *Server) getCurrentFrameFromLogs(jobID int64) (int, bool) {
-	// Get all render tasks for this job
-	var rows *sql.Rows
-	err := s.db.With(func(conn *sql.DB) error {
-		var err error
-		rows, err = conn.Query(
-			`SELECT id FROM tasks WHERE job_id = ? AND task_type = ? AND status = ?`,
-			jobID, types.TaskTypeRender, types.TaskStatusRunning,
-		)
-		return err
-	})
-	if err != nil {
-		return 0, false
-	}
-	defer rows.Close()
-
-	maxFrame := 0
-	found := false
-
-	for rows.Next() {
-		var taskID int64
-		if err := rows.Scan(&taskID); err != nil {
-			log.Printf("Failed to scan task ID in getCurrentFrameFromLogs: %v", err)
-			continue
-		}
-
-		// Get the most recent log entries for this task (last 100 to avoid scanning all logs)
-		var logRows *sql.Rows
-		err := s.db.With(func(conn *sql.DB) error {
-			var err error
-			logRows, err = conn.Query(
-				`SELECT message FROM task_logs
-				WHERE task_id = ? AND message LIKE '%Fra:%'
-				ORDER BY id DESC LIMIT 100`,
-				taskID,
-			)
-			return err
-		})
-		if err != nil {
-			continue
-		}
-
-		for logRows.Next() {
-			var message string
-			if err := logRows.Scan(&message); err != nil {
-				continue
-			}
-
-			if frame, ok := parseBlenderFrame(message); ok {
-				if frame > maxFrame {
-					maxFrame = frame
-					found = true
-				}
-			}
-		}
-		logRows.Close()
-	}
-
-	return maxFrame, found
-}
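Reviewer note: `parseBlenderFrame` hand-rolls a scanner over Blender progress lines, which look like `Fra:2470 Mem:291.69M ...`. A regexp version is a handy cross-check in tests; this equivalent is illustrative, not code from this repository:

```go
package api

import (
	"regexp"
	"strconv"
)

var fraPattern = regexp.MustCompile(`Fra:\s*(\d+)`)

// parseBlenderFrameRE is a regexp equivalent of parseBlenderFrame:
// it accepts "Fra:2470" and "Fra: 2470" and returns the frame number.
func parseBlenderFrameRE(logMessage string) (int, bool) {
	m := fraPattern.FindStringSubmatch(logMessage)
	if m == nil {
		return 0, false
	}
	frame, err := strconv.Atoi(m[1])
	return frame, err == nil
}
```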
AND status = ?`, - jobID, types.TaskStatusCompleted, - ).Scan(&completedTasks) - }) - if err != nil { - log.Printf("Failed to count completed tasks for job %d: %v", jobID, err) - return - } - - // Calculate progress - var progress float64 - if totalTasks == 0 { - // All tasks cancelled or no tasks, set progress to 0 - progress = 0.0 - } else if useFrameProgress { - // For single-runner render jobs, use frame-based progress from logs - currentFrame, frameFound := s.getCurrentFrameFromLogs(jobID) - frameStartVal := int(frameStart.Int64) - frameEndVal := int(frameEnd.Int64) - totalFrames := frameEndVal - frameStartVal + 1 - - // Count non-render tasks (like video generation) separately - var nonRenderTasks, nonRenderCompleted int - s.db.With(func(conn *sql.DB) error { - conn.QueryRow( - `SELECT COUNT(*) FROM tasks WHERE job_id = ? AND task_type != ? AND status IN (?, ?, ?, ?)`, - jobID, types.TaskTypeRender, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted, types.TaskStatusFailed, - ).Scan(&nonRenderTasks) - conn.QueryRow( - `SELECT COUNT(*) FROM tasks WHERE job_id = ? AND task_type != ? AND status = ?`, - jobID, types.TaskTypeRender, types.TaskStatusCompleted, - ).Scan(&nonRenderCompleted) - return nil - }) - - // Calculate render task progress from frames - var renderProgress float64 - if frameFound && totalFrames > 0 { - // Calculate how many frames have been rendered (current - start + 1) - // But cap at frame_end to handle cases where logs show frames beyond end - renderedFrames := currentFrame - frameStartVal + 1 - if currentFrame > frameEndVal { - renderedFrames = totalFrames - } else if renderedFrames < 0 { - renderedFrames = 0 - } - if renderedFrames > totalFrames { - renderedFrames = totalFrames - } - renderProgress = float64(renderedFrames) / float64(totalFrames) * 100.0 - } else { - // Fall back to task-based progress for render tasks - var renderTasks, renderCompleted int - s.db.With(func(conn *sql.DB) error { - conn.QueryRow( - `SELECT COUNT(*) FROM tasks WHERE job_id = ? AND task_type = ? AND status IN (?, ?, ?, ?)`, - jobID, types.TaskTypeRender, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted, types.TaskStatusFailed, - ).Scan(&renderTasks) - conn.QueryRow( - `SELECT COUNT(*) FROM tasks WHERE job_id = ? AND task_type = ? 
AND status = ?`, - jobID, types.TaskTypeRender, types.TaskStatusCompleted, - ).Scan(&renderCompleted) - return nil - }) - if renderTasks > 0 { - renderProgress = float64(renderCompleted) / float64(renderTasks) * 100.0 - } - } - - // Combine render progress with non-render task progress - // Weight: render tasks contribute 90%, other tasks contribute 10% (adjust as needed) - var nonRenderProgress float64 - if nonRenderTasks > 0 { - nonRenderProgress = float64(nonRenderCompleted) / float64(nonRenderTasks) * 100.0 - } - - // Weighted average: render progress is most important - if totalTasks > 0 { - renderWeight := 0.9 - nonRenderWeight := 0.1 - progress = renderProgress*renderWeight + nonRenderProgress*nonRenderWeight - } else { - progress = renderProgress - } - } else { - // Standard task-based progress - progress = float64(completedTasks) / float64(totalTasks) * 100.0 - } - - var jobStatus string - var outputFormat sql.NullString - s.db.With(func(conn *sql.DB) error { - conn.QueryRow(`SELECT output_format FROM jobs WHERE id = ?`, jobID).Scan(&outputFormat) - return nil - }) - outputFormatStr := "" - if outputFormat.Valid { - outputFormatStr = outputFormat.String - } - - // Check if all non-cancelled tasks are completed - var pendingOrRunningTasks int - err = s.db.With(func(conn *sql.DB) error { - return conn.QueryRow( - `SELECT COUNT(*) FROM tasks - WHERE job_id = ? AND status IN (?, ?)`, - jobID, types.TaskStatusPending, types.TaskStatusRunning, - ).Scan(&pendingOrRunningTasks) - }) - if err != nil { - log.Printf("Failed to count pending/running tasks for job %d: %v", jobID, err) - return - } - - if pendingOrRunningTasks == 0 && totalTasks > 0 { - // All tasks are either completed or failed/cancelled - // Check if any tasks failed - var failedTasks int - s.db.With(func(conn *sql.DB) error { - conn.QueryRow( - `SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`, - jobID, types.TaskStatusFailed, - ).Scan(&failedTasks) - return nil - }) - - if failedTasks > 0 { - // Some tasks failed - check if job has retries left - var retryCount, maxRetries int - err := s.db.With(func(conn *sql.DB) error { - return conn.QueryRow( - `SELECT retry_count, max_retries FROM jobs WHERE id = ?`, - jobID, - ).Scan(&retryCount, &maxRetries) - }) - if err != nil { - log.Printf("Failed to get retry info for job %d: %v", jobID, err) - // Fall back to marking job as failed - jobStatus = string(types.JobStatusFailed) - } else if retryCount < maxRetries { - // Job has retries left - reset failed tasks and redistribute - if err := s.resetFailedTasksAndRedistribute(jobID); err != nil { - log.Printf("Failed to reset failed tasks for job %d: %v", jobID, err) - // If reset fails, mark job as failed - jobStatus = string(types.JobStatusFailed) - } else { - // Tasks reset successfully - job remains in running/pending state - // Don't update job status, just update progress - jobStatus = currentStatus // Keep current status - // Recalculate progress after reset (failed tasks are now pending again) - var newTotalTasks, newCompletedTasks int - s.db.With(func(conn *sql.DB) error { - conn.QueryRow( - `SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status IN (?, ?, ?, ?)`, - jobID, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted, types.TaskStatusFailed, - ).Scan(&newTotalTasks) - conn.QueryRow( - `SELECT COUNT(*) FROM tasks WHERE job_id = ? 
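
The 90/10 blend above reads more clearly as a single formula: progress = 0.9 * renderProgress + 0.1 * nonRenderProgress, with renderProgress derived from the last frame seen in the logs and clamped to the job's frame range. A minimal standalone sketch of that math (the helper name and flat-argument shape are mine, not part of the patch):

package main

import "fmt"

// weightedProgress mirrors the blend above: frame-based render progress,
// clamped to the job's frame range, contributes 90%; non-render tasks
// (e.g. video generation) contribute the remaining 10%.
func weightedProgress(currentFrame, frameStart, frameEnd, nonRenderDone, nonRenderTotal int) float64 {
	totalFrames := frameEnd - frameStart + 1
	rendered := currentFrame - frameStart + 1
	if rendered < 0 {
		rendered = 0
	}
	if rendered > totalFrames {
		rendered = totalFrames
	}
	render := float64(rendered) / float64(totalFrames) * 100.0
	nonRender := 0.0
	if nonRenderTotal > 0 {
		nonRender = float64(nonRenderDone) / float64(nonRenderTotal) * 100.0
	}
	return render*0.9 + nonRender*0.1
}

func main() {
	// Halfway through frames 1..100 with the video task still pending: 45%.
	fmt.Println(weightedProgress(50, 1, 100, 0, 1))
}
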
AND status = ?`, - jobID, types.TaskStatusCompleted, - ).Scan(&newCompletedTasks) - return nil - }) - if newTotalTasks > 0 { - progress = float64(newCompletedTasks) / float64(newTotalTasks) * 100.0 - } - // Update progress only - err := s.db.With(func(conn *sql.DB) error { - _, err := conn.Exec( - `UPDATE jobs SET progress = ? WHERE id = ?`, - progress, jobID, - ) - return err - }) - if err != nil { - log.Printf("Failed to update job %d progress: %v", jobID, err) - } else { - // Broadcast job update via WebSocket - s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{ - "status": jobStatus, - "progress": progress, - }) - } - return // Exit early since we've handled the retry - } - } else { - // No retries left - mark job as failed and cancel active tasks - jobStatus = string(types.JobStatusFailed) - if err := s.cancelActiveTasksForJob(jobID); err != nil { - log.Printf("Failed to cancel active tasks for job %d: %v", jobID, err) - } - } - } else { - // All tasks completed successfully - jobStatus = string(types.JobStatusCompleted) - progress = 100.0 // Ensure progress is 100% when all tasks complete - } - - // Update job status (if we didn't return early from retry logic) - if jobStatus != "" { - err := s.db.With(func(conn *sql.DB) error { - _, err := conn.Exec( - `UPDATE jobs SET status = ?, progress = ?, completed_at = ? WHERE id = ?`, - jobStatus, progress, now, jobID, - ) - return err - }) - if err != nil { - log.Printf("Failed to update job %d status to %s: %v", jobID, jobStatus, err) - } else { - // Only log if status actually changed - if currentStatus != jobStatus { - log.Printf("Updated job %d status from %s to %s (progress: %.1f%%, completed tasks: %d/%d)", jobID, currentStatus, jobStatus, progress, completedTasks, totalTasks) - } - // Broadcast job update via WebSocket - s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{ - "status": jobStatus, - "progress": progress, - "completed_at": now, - }) - } - } - - if outputFormatStr == "EXR_264_MP4" || outputFormatStr == "EXR_AV1_MP4" { - // Check if a video generation task already exists for this job (any status) - var existingVideoTask int - s.db.With(func(conn *sql.DB) error { - conn.QueryRow( - `SELECT COUNT(*) FROM tasks WHERE job_id = ? 
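
Condensed, the status resolution above follows a fixed precedence. The helper below is a hypothetical distillation, not a function in the patch: given a mix of task states, it returns the status the job settles on.

package main

import "fmt"

// jobStatusFrom condenses the branches above: running tasks win, then
// pending; failed tasks either trigger a reset (status unchanged) or, with
// retries exhausted, fail the job; otherwise everything completed.
func jobStatusFrom(pending, running, completed, failed int, retriesLeft bool, current string) string {
	switch {
	case running > 0:
		return "running"
	case pending > 0:
		return "pending"
	case failed > 0 && retriesLeft:
		return current // failed tasks are reset to pending and redistributed
	case failed > 0:
		return "failed"
	case completed > 0:
		return "completed"
	default:
		return "pending" // no tasks at all; only progress is reset to 0
	}
}

func main() {
	// Retry path: two failures with retries left keep the current status.
	fmt.Println(jobStatusFrom(0, 0, 8, 2, true, "running"))
}
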
AND task_type = ?`, - jobID, types.TaskTypeVideoGeneration, - ).Scan(&existingVideoTask) - return nil - }) - - if existingVideoTask == 0 { - // Create a video generation task instead of calling generateMP4Video directly - // This prevents race conditions when multiple runners complete frames simultaneously - videoTaskTimeout := 86400 // 24 hours for video generation - var videoTaskID int64 - err := s.db.With(func(conn *sql.DB) error { - result, err := conn.Exec( - `INSERT INTO tasks (job_id, frame_start, frame_end, task_type, status, timeout_seconds, max_retries) - VALUES (?, ?, ?, ?, ?, ?, ?)`, - jobID, 0, 0, types.TaskTypeVideoGeneration, types.TaskStatusPending, videoTaskTimeout, 1, - ) - if err != nil { - return err - } - videoTaskID, err = result.LastInsertId() - return err - }) - if err != nil { - log.Printf("Failed to create video generation task for job %d: %v", jobID, err) - } else { - // Broadcast that a new task was added - if s.verboseWSLogging { - log.Printf("Broadcasting task_added for job %d: video generation task %d", jobID, videoTaskID) - } - s.broadcastTaskUpdate(jobID, videoTaskID, "task_added", map[string]interface{}{ - "task_id": videoTaskID, - "task_type": types.TaskTypeVideoGeneration, - }) - // Update job status to ensure it's marked as running (has pending video task) - s.updateJobStatusFromTasks(jobID) - // Try to distribute the task immediately - s.triggerTaskDistribution() - } - } else { - log.Printf("Skipping video generation task creation for job %d (video task already exists)", jobID) - } - } - } else { - // Job has pending or running tasks - determine if it's running or still pending - var runningTasks int - s.db.With(func(conn *sql.DB) error { - conn.QueryRow( - `SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`, - jobID, types.TaskStatusRunning, - ).Scan(&runningTasks) - return nil - }) - - if runningTasks > 0 { - // Has running tasks - job is running - jobStatus = string(types.JobStatusRunning) - var startedAt sql.NullTime - s.db.With(func(conn *sql.DB) error { - conn.QueryRow(`SELECT started_at FROM jobs WHERE id = ?`, jobID).Scan(&startedAt) - if !startedAt.Valid { - conn.Exec(`UPDATE jobs SET started_at = ? WHERE id = ?`, now, jobID) - } - return nil - }) - } else { - // All tasks are pending - job is pending - jobStatus = string(types.JobStatusPending) - } - - err := s.db.With(func(conn *sql.DB) error { - _, err := conn.Exec( - `UPDATE jobs SET status = ?, progress = ? 
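
The existence check here is COUNT-then-INSERT, which narrows the race window rather than closing it: two callers can both observe a zero count before either inserts. If a hard guarantee is ever wanted, the check and the insert can be collapsed into one conditional statement. A hedged alternative sketch (the string literals stand in for the types constants and are assumptions):

package manager

import "database/sql"

// createVideoTaskOnce is an alternative sketch, not what the patch does: a
// single conditional INSERT lets SQLite serialize the existence check and
// the insert, so two concurrent callers cannot both create the video task.
func createVideoTaskOnce(conn *sql.DB, jobID int64) (created bool, err error) {
	res, err := conn.Exec(`
		INSERT INTO tasks (job_id, frame_start, frame_end, task_type, status, timeout_seconds, max_retries)
		SELECT ?, 0, 0, 'video_generation', 'pending', 86400, 1
		WHERE NOT EXISTS (SELECT 1 FROM tasks WHERE job_id = ? AND task_type = 'video_generation')`,
		jobID, jobID)
	if err != nil {
		return false, err
	}
	n, err := res.RowsAffected()
	return n == 1, err
}
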
WHERE id = ?`, - jobStatus, progress, jobID, - ) - return err - }) - if err != nil { - log.Printf("Failed to update job %d status to %s: %v", jobID, jobStatus, err) - } else { - // Only log if status actually changed - if currentStatus != jobStatus { - log.Printf("Updated job %d status from %s to %s (progress: %.1f%%, completed: %d/%d, pending: %d, running: %d)", jobID, currentStatus, jobStatus, progress, completedTasks, totalTasks, pendingOrRunningTasks-runningTasks, runningTasks) - } - // Broadcast job update during execution (not just on completion) - s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{ - "status": jobStatus, - "progress": progress, - }) - } - } -} - -// broadcastLogToFrontend broadcasts log to connected frontend clients -func (s *Server) broadcastLogToFrontend(taskID int64, logEntry WSLogEntry) { - // Get job_id, user_id, and task status from task - var jobID, userID int64 - var taskStatus string - var taskRunnerID sql.NullInt64 - var taskStartedAt sql.NullTime - err := s.db.With(func(conn *sql.DB) error { - return conn.QueryRow( - `SELECT t.job_id, j.user_id, t.status, t.runner_id, t.started_at - FROM tasks t - JOIN jobs j ON t.job_id = j.id - WHERE t.id = ?`, - taskID, - ).Scan(&jobID, &userID, &taskStatus, &taskRunnerID, &taskStartedAt) - }) - if err != nil { - return - } - - // Get full log entry from database for consistency - // Use a more reliable query that gets the most recent log with matching message - // This avoids race conditions with concurrent inserts - var taskLog types.TaskLog - var runnerID sql.NullInt64 - var stepName sql.NullString - err = s.db.With(func(conn *sql.DB) error { - return conn.QueryRow( - `SELECT id, task_id, runner_id, log_level, message, step_name, created_at - FROM task_logs WHERE task_id = ? AND message = ? ORDER BY id DESC LIMIT 1`, - taskID, logEntry.Message, - ).Scan(&taskLog.ID, &taskLog.TaskID, &runnerID, &taskLog.LogLevel, &taskLog.Message, &stepName, &taskLog.CreatedAt) - }) - if err != nil { - return - } - if runnerID.Valid { - taskLog.RunnerID = &runnerID.Int64 - } - if stepName.Valid { - taskLog.StepName = stepName.String - } - - msg := map[string]interface{}{ - "type": "log", - "task_id": taskID, - "job_id": jobID, - "data": taskLog, - "timestamp": time.Now().Unix(), - } - - // Broadcast to client WebSocket if subscribed to logs:{jobId}:{taskId} - channel := fmt.Sprintf("logs:%d:%d", jobID, taskID) - if s.verboseWSLogging { - runnerIDStr := "none" - if taskRunnerID.Valid { - runnerIDStr = fmt.Sprintf("%d", taskRunnerID.Int64) - } - log.Printf("broadcastLogToFrontend: Broadcasting log for task %d (job %d, user %d) on channel %s, log_id=%d, task_status=%s, runner_id=%s", taskID, jobID, userID, channel, taskLog.ID, taskStatus, runnerIDStr) - } - s.broadcastToClient(userID, channel, msg) - - // If task status is pending but logs are coming in, log a warning - // This indicates the initial assignment broadcast may have been missed or the database update failed - if taskStatus == string(types.TaskStatusPending) { - log.Printf("broadcastLogToFrontend: ERROR - Task %d has logs but status is 'pending'. 
This indicates the initial task assignment failed or the task_update broadcast was missed.", taskID) - } - - // Also broadcast to old WebSocket connection (for backwards compatibility during migration) - key := fmt.Sprintf("%d:%d", jobID, taskID) - s.frontendConnsMu.RLock() - conn, exists := s.frontendConns[key] - s.frontendConnsMu.RUnlock() - - if exists && conn != nil { - // Serialize writes to prevent concurrent write panics - s.frontendConnsWriteMuMu.RLock() - writeMu, hasMu := s.frontendConnsWriteMu[key] - s.frontendConnsWriteMuMu.RUnlock() - - if hasMu && writeMu != nil { - writeMu.Lock() - conn.WriteJSON(msg) - writeMu.Unlock() - } else { - // Fallback if mutex doesn't exist yet (shouldn't happen, but be safe) - conn.WriteJSON(msg) - } - } -} - -// triggerTaskDistribution triggers task distribution in a serialized manner -func (s *Server) triggerTaskDistribution() { - go func() { - // Try to acquire lock - if already running, log and skip - if !s.taskDistMu.TryLock() { - // Log when distribution is skipped to help with debugging - log.Printf("Task distribution already in progress, skipping trigger") - return // Distribution already in progress - } - defer s.taskDistMu.Unlock() - s.distributeTasksToRunners() - }() -} - -// distributeTasksToRunners pushes available tasks to connected runners -// This function should only be called while holding taskDistMu lock -func (s *Server) distributeTasksToRunners() { - // Quick check: if there are no pending tasks, skip the expensive query - var pendingCount int - err := s.db.With(func(conn *sql.DB) error { - return conn.QueryRow( - `SELECT COUNT(*) FROM tasks t - JOIN jobs j ON t.job_id = j.id - WHERE t.status = ? AND j.status != ?`, - types.TaskStatusPending, types.JobStatusCancelled, - ).Scan(&pendingCount) - }) - if err != nil { - log.Printf("Failed to check pending tasks count: %v", err) - return - } - if pendingCount == 0 { - // No pending tasks, nothing to distribute - return - } - - // Get all pending tasks - var rows *sql.Rows - err = s.db.With(func(conn *sql.DB) error { - var err error - rows, err = conn.Query( - `SELECT t.id, t.job_id, t.frame_start, t.frame_end, t.task_type, j.allow_parallel_runners, j.status as job_status, j.name as job_name, j.user_id - FROM tasks t - JOIN jobs j ON t.job_id = j.id - WHERE t.status = ? AND j.status != ? 
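
triggerTaskDistribution leans on sync.Mutex.TryLock (Go 1.18+) to collapse concurrent triggers into a single distribution pass. A self-contained illustration of that single-flight behavior:

package main

import (
	"fmt"
	"sync"
	"time"
)

// Overlapping triggers collapse: whoever wins TryLock runs the pass, the
// rest become logged no-ops, exactly the shape used above.
func main() {
	var distMu sync.Mutex
	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(n int) {
			defer wg.Done()
			if !distMu.TryLock() {
				fmt.Printf("trigger %d skipped, distribution already in progress\n", n)
				return
			}
			defer distMu.Unlock()
			fmt.Printf("trigger %d running distribution\n", n)
			time.Sleep(10 * time.Millisecond) // stand-in for distributeTasksToRunners
		}(i)
	}
	wg.Wait()
}
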
- ORDER BY t.created_at ASC`, - types.TaskStatusPending, types.JobStatusCancelled, - ) - return err - }) - if err != nil { - log.Printf("Failed to query pending tasks: %v", err) - return - } - defer rows.Close() - - var pendingTasks []struct { - TaskID int64 - JobID int64 - FrameStart int - FrameEnd int - TaskType string - AllowParallelRunners bool - JobName string - JobStatus string - JobUserID int64 - } - - for rows.Next() { - var t struct { - TaskID int64 - JobID int64 - FrameStart int - FrameEnd int - TaskType string - AllowParallelRunners bool - JobName string - JobStatus string - JobUserID int64 - } - var allowParallel sql.NullBool - err := rows.Scan(&t.TaskID, &t.JobID, &t.FrameStart, &t.FrameEnd, &t.TaskType, &allowParallel, &t.JobStatus, &t.JobName, &t.JobUserID) - if err != nil { - log.Printf("Failed to scan pending task: %v", err) - continue - } - // Default to true if NULL (for metadata jobs or legacy data) - if allowParallel.Valid { - t.AllowParallelRunners = allowParallel.Bool - } else { - t.AllowParallelRunners = true - } - pendingTasks = append(pendingTasks, t) - } - - if len(pendingTasks) == 0 { - log.Printf("No pending tasks found for distribution") - return - } - - log.Printf("Found %d pending tasks for distribution", len(pendingTasks)) - - // Get connected runners (WebSocket connection is source of truth) - // Use a read lock to safely read the map - s.runnerConnsMu.RLock() - connectedRunners := make([]int64, 0, len(s.runnerConns)) - for runnerID := range s.runnerConns { - // Verify connection is still valid (not closed) - conn := s.runnerConns[runnerID] - if conn != nil { - connectedRunners = append(connectedRunners, runnerID) - } - } - s.runnerConnsMu.RUnlock() - - // Get runner priorities, capabilities, and API key scopes for all connected runners - runnerPriorities := make(map[int64]int) - runnerCapabilities := make(map[int64]map[string]interface{}) - runnerScopes := make(map[int64]string) - for _, runnerID := range connectedRunners { - var priority int - var capabilitiesJSON sql.NullString - var scope string - err := s.db.With(func(conn *sql.DB) error { - return conn.QueryRow("SELECT priority, capabilities, api_key_scope FROM runners WHERE id = ?", runnerID).Scan(&priority, &capabilitiesJSON, &scope) - }) - if err != nil { - // Default to 100 if priority not found - priority = 100 - capabilitiesJSON = sql.NullString{String: "{}", Valid: true} - } - runnerPriorities[runnerID] = priority - runnerScopes[runnerID] = scope - - // Parse capabilities JSON (can contain both bools and numbers) - capabilitiesStr := "{}" - if capabilitiesJSON.Valid { - capabilitiesStr = capabilitiesJSON.String - } - var capabilities map[string]interface{} - if err := json.Unmarshal([]byte(capabilitiesStr), &capabilities); err != nil { - // If parsing fails, try old format (map[string]bool) for backward compatibility - var oldCapabilities map[string]bool - if err2 := json.Unmarshal([]byte(capabilitiesStr), &oldCapabilities); err2 == nil { - // Convert old format to new format - capabilities = make(map[string]interface{}) - for k, v := range oldCapabilities { - capabilities[k] = v - } - } else { - // Both formats failed, assume no capabilities - capabilities = make(map[string]interface{}) - } - } - runnerCapabilities[runnerID] = capabilities - } - - // Update database status for all connected runners (outside the lock to avoid holding it too long) - for _, runnerID := range connectedRunners { - // Ensure database status matches WebSocket connection - // Update status to online if it's not 
already - s.db.With(func(conn *sql.DB) error { - _, _ = conn.Exec( - `UPDATE runners SET status = ?, last_heartbeat = ? WHERE id = ? AND status != ?`, - types.RunnerStatusOnline, time.Now(), runnerID, types.RunnerStatusOnline, - ) - return nil - }) - } - - if len(connectedRunners) == 0 { - log.Printf("No connected runners available for task distribution (checked WebSocket connections)") - // Log to task logs that no runners are available - for _, task := range pendingTasks { - if task.TaskType == string(types.TaskTypeMetadata) { - s.logTaskEvent(task.TaskID, nil, types.LogLevelWarn, "No connected runners available for task assignment", "") - } - } - return - } - - // Log task types being distributed - taskTypes := make(map[string]int) - for _, task := range pendingTasks { - taskTypes[task.TaskType]++ - } - log.Printf("Distributing %d pending tasks (%v) to %d connected runners: %v", len(pendingTasks), taskTypes, len(connectedRunners), connectedRunners) - - // Distribute tasks to runners - // Sort tasks to prioritize metadata tasks - sort.Slice(pendingTasks, func(i, j int) bool { - // Metadata tasks first - if pendingTasks[i].TaskType == string(types.TaskTypeMetadata) && pendingTasks[j].TaskType != string(types.TaskTypeMetadata) { - return true - } - if pendingTasks[i].TaskType != string(types.TaskTypeMetadata) && pendingTasks[j].TaskType == string(types.TaskTypeMetadata) { - return false - } - return false // Keep original order for same type - }) - - // Track how many tasks each runner has been assigned in this distribution cycle - runnerTaskCounts := make(map[int64]int) - - for _, task := range pendingTasks { - // Determine required capability for this task - var requiredCapability string - switch task.TaskType { - case string(types.TaskTypeRender), string(types.TaskTypeMetadata): - requiredCapability = "blender" - case string(types.TaskTypeVideoGeneration): - requiredCapability = "ffmpeg" - default: - requiredCapability = "" // Unknown task type - } - - // Find available runner - var selectedRunnerID int64 - var bestRunnerID int64 - var bestPriority int = -1 - var bestTaskCount int = -1 - var bestRandom float64 = -1 // Random tie-breaker - - // Try to find the best runner for this task - for _, runnerID := range connectedRunners { - // Check if runner's API key scope allows working on this job - runnerScope := runnerScopes[runnerID] - if runnerScope == "user" && task.JobUserID != 0 { - // User-scoped runner - check if they can work on jobs from this user - // For now, user-scoped runners can only work on jobs from the same user who created their API key - var apiKeyCreatedBy int64 - if runnerScope == "user" { - // Get the user who created this runner's API key - var apiKeyID sql.NullInt64 - err := s.db.With(func(conn *sql.DB) error { - return conn.QueryRow("SELECT api_key_id FROM runners WHERE id = ?", runnerID).Scan(&apiKeyID) - }) - if err == nil && apiKeyID.Valid { - s.db.With(func(conn *sql.DB) error { - return conn.QueryRow("SELECT created_by FROM runner_api_keys WHERE id = ?", apiKeyID.Int64).Scan(&apiKeyCreatedBy) - }) - if err != nil { - continue // Skip this runner if we can't determine API key ownership - } - // Only allow if the job owner matches the API key creator - if apiKeyCreatedBy != task.JobUserID { - continue // This user-scoped runner cannot work on this job - } - } - } - // Manager-scoped runners can work on any job - } - - // Check if runner has required capability - capabilities := runnerCapabilities[runnerID] - hasRequired := false - if reqVal, ok := 
capabilities[requiredCapability]; ok { - if reqBool, ok := reqVal.(bool); ok { - hasRequired = reqBool - } else if reqFloat, ok := reqVal.(float64); ok { - hasRequired = reqFloat > 0 - } else if reqInt, ok := reqVal.(int); ok { - hasRequired = reqInt > 0 - } - } - if !hasRequired && requiredCapability != "" { - continue // Runner doesn't have required capability - } - - // Check if runner has ANY tasks (pending or running) - one task at a time only - // This prevents any runner from doing more than one task at a time - var activeTaskCount int - err := s.db.With(func(conn *sql.DB) error { - return conn.QueryRow( - `SELECT COUNT(*) FROM tasks WHERE runner_id = ? AND status IN (?, ?)`, - runnerID, types.TaskStatusPending, types.TaskStatusRunning, - ).Scan(&activeTaskCount) - }) - if err != nil { - log.Printf("Failed to check active tasks for runner %d: %v", runnerID, err) - continue - } - - if activeTaskCount > 0 { - continue // Runner is busy with another task, cannot run any other tasks - } - - // For non-parallel jobs, check if runner already has tasks from this job - if !task.AllowParallelRunners { - var jobTaskCount int - err := s.db.With(func(conn *sql.DB) error { - return conn.QueryRow( - `SELECT COUNT(*) FROM tasks - WHERE job_id = ? AND runner_id = ? AND status IN (?, ?)`, - task.JobID, runnerID, types.TaskStatusPending, types.TaskStatusRunning, - ).Scan(&jobTaskCount) - }) - if err != nil { - log.Printf("Failed to check job tasks for runner %d: %v", runnerID, err) - continue - } - if jobTaskCount > 0 { - continue // Another runner is working on this job - } - } - - // Get runner priority and task count - priority := runnerPriorities[runnerID] - currentTaskCount := runnerTaskCounts[runnerID] - // Generate a small random value for absolute tie-breaking - randomValue := rand.Float64() - - // Selection priority: - // 1. Priority (higher is better) - // 2. Task count (fewer is better) - // 3. 
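
Spelled out as a standalone comparison, the ordering is: priority descending, then per-cycle task count ascending, then a random draw as the absolute tie-breaker. A hypothetical extraction of the inline logic (names are mine, not part of the patch):

package main

import "fmt"

// candidate and better restate the three-level comparison above.
type candidate struct {
	priority  int     // higher wins
	taskCount int     // fewer tasks assigned this cycle wins
	random    float64 // absolute tie-breaker
}

func better(a, b candidate) bool {
	if a.priority != b.priority {
		return a.priority > b.priority
	}
	if a.taskCount != b.taskCount {
		return a.taskCount < b.taskCount
	}
	return a.random > b.random
}

func main() {
	// Equal priority: the runner with fewer assignments this cycle wins.
	fmt.Println(better(candidate{100, 0, 0.3}, candidate{100, 1, 0.9})) // true
}
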
Random value (absolute tie-breaker) - isBetter := false - if bestRunnerID == 0 { - isBetter = true - } else if priority > bestPriority { - // Higher priority - isBetter = true - } else if priority == bestPriority { - if currentTaskCount < bestTaskCount { - // Same priority, but fewer tasks assigned in this cycle - isBetter = true - } else if currentTaskCount == bestTaskCount { - // Absolute tie - use random value as tie-breaker - if randomValue > bestRandom { - isBetter = true - } - } - } - - if isBetter { - bestRunnerID = runnerID - bestPriority = priority - bestTaskCount = currentTaskCount - bestRandom = randomValue - } - } - - // Use the best runner we found (prioritized by priority, then load balanced) - if bestRunnerID != 0 { - selectedRunnerID = bestRunnerID - } - - if selectedRunnerID == 0 { - if task.TaskType == string(types.TaskTypeMetadata) { - log.Printf("Warning: No available runner for metadata task %d (job %d)", task.TaskID, task.JobID) - // Log that no runner is available - s.logTaskEvent(task.TaskID, nil, types.LogLevelWarn, "No available runner for task assignment", "") - } - continue // No available runner - task stays in queue - } - - // Track assignment for load balancing - runnerTaskCounts[selectedRunnerID]++ - - // Atomically assign task to runner using UPDATE with WHERE runner_id IS NULL - // This prevents race conditions when multiple goroutines try to assign the same task - // Use a transaction to ensure atomicity - now := time.Now() - var rowsAffected int64 - var verifyStatus string - var verifyRunnerID sql.NullInt64 - var verifyStartedAt sql.NullTime - err := s.db.WithTx(func(tx *sql.Tx) error { - result, err := tx.Exec( - `UPDATE tasks SET runner_id = ?, status = ?, started_at = ? - WHERE id = ? AND runner_id IS NULL AND status = ?`, - selectedRunnerID, types.TaskStatusRunning, now, task.TaskID, types.TaskStatusPending, - ) - if err != nil { - return err - } - - // Check if the update actually affected a row (task was successfully assigned) - rowsAffected, err = result.RowsAffected() - if err != nil { - return err - } - - if rowsAffected == 0 { - return sql.ErrNoRows // Task was already assigned - } - - // Verify the update within the transaction before committing - // This ensures we catch any issues before the transaction is committed - err = tx.QueryRow( - `SELECT status, runner_id, started_at FROM tasks WHERE id = ?`, - task.TaskID, - ).Scan(&verifyStatus, &verifyRunnerID, &verifyStartedAt) - if err != nil { - return err - } - if verifyStatus != string(types.TaskStatusRunning) { - return fmt.Errorf("task status is %s after assignment, expected running", verifyStatus) - } - if !verifyRunnerID.Valid || verifyRunnerID.Int64 != selectedRunnerID { - return fmt.Errorf("task runner_id is %v after assignment, expected %d", verifyRunnerID, selectedRunnerID) - } - - return nil // Commit on nil return - }) - if err == sql.ErrNoRows { - // Task was already assigned by another goroutine, skip - continue - } - if err != nil { - log.Printf("Failed to atomically assign task %d: %v", task.TaskID, err) - continue - } - - log.Printf("Verified and committed task %d assignment: status=%s, runner_id=%d, started_at=%v", task.TaskID, verifyStatus, verifyRunnerID.Int64, verifyStartedAt) - - // Broadcast task assignment - include all fields to ensure frontend has complete info - updateData := map[string]interface{}{ - "status": types.TaskStatusRunning, - "runner_id": selectedRunnerID, - "started_at": verifyStartedAt.Time, - } - if !verifyStartedAt.Valid { - updateData["started_at"] = 
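
The heart of the assignment is an optimistic compare-and-set: the WHERE clause only matches while the row is still unclaimed, so among any number of concurrent distributors at most one observes RowsAffected == 1. The pattern in isolation (the status literals are assumptions standing in for the types constants):

package manager

import "database/sql"

// claimTask shows the compare-and-set alone: only the caller whose UPDATE
// matched the still-unclaimed row sees one affected row; every other
// concurrent distributor sees zero and skips the task.
func claimTask(tx *sql.Tx, taskID, runnerID int64) (bool, error) {
	res, err := tx.Exec(
		`UPDATE tasks SET runner_id = ?, status = 'running', started_at = CURRENT_TIMESTAMP
		 WHERE id = ? AND runner_id IS NULL AND status = 'pending'`,
		runnerID, taskID)
	if err != nil {
		return false, err
	}
	n, err := res.RowsAffected()
	return n == 1, err
}
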
now - } - if s.verboseWSLogging { - log.Printf("Broadcasting task_update for task %d (job %d, user %d): status=%s, runner_id=%d, started_at=%v", task.TaskID, task.JobID, task.JobUserID, types.TaskStatusRunning, selectedRunnerID, now) - } - s.broadcastTaskUpdate(task.JobID, task.TaskID, "task_update", updateData) - - // Task was successfully assigned in database, now send via WebSocket - log.Printf("Assigned task %d (type: %s, job: %d) to runner %d", task.TaskID, task.TaskType, task.JobID, selectedRunnerID) - - // Log runner assignment to task logs - s.logTaskEvent(task.TaskID, nil, types.LogLevelInfo, fmt.Sprintf("Task assigned to runner %d", selectedRunnerID), "") - - // Attempt to send task to runner via WebSocket - if err := s.assignTaskToRunner(selectedRunnerID, task.TaskID); err != nil { - log.Printf("Failed to send task %d to runner %d: %v", task.TaskID, selectedRunnerID, err) - // Log assignment failure - s.logTaskEvent(task.TaskID, nil, types.LogLevelError, fmt.Sprintf("Failed to send task to runner %d: %v", selectedRunnerID, err), "") - // Rollback the assignment if WebSocket send fails with retry mechanism - rollbackSuccess := false - for retry := 0; retry < 3; retry++ { - rollbackErr := s.db.WithTx(func(tx *sql.Tx) error { - _, err := tx.Exec( - `UPDATE tasks SET runner_id = NULL, status = ?, started_at = NULL - WHERE id = ? AND runner_id = ?`, - types.TaskStatusPending, task.TaskID, selectedRunnerID, - ) - return err - }) - if rollbackErr != nil { - log.Printf("Failed to rollback task %d assignment (attempt %d/3): %v", task.TaskID, retry+1, rollbackErr) - if retry < 2 { - time.Sleep(time.Duration(retry+1) * 100 * time.Millisecond) // Exponential backoff - continue - } - // Final attempt failed - log.Printf("CRITICAL: Failed to rollback task %d after 3 attempts - task may be in inconsistent state", task.TaskID) - s.triggerTaskDistribution() - break - } - // Rollback succeeded - rollbackSuccess = true - s.logTaskEvent(task.TaskID, nil, types.LogLevelWarn, fmt.Sprintf("Task assignment rolled back - runner %d connection failed", selectedRunnerID), "") - s.updateJobStatusFromTasks(task.JobID) - s.triggerTaskDistribution() - break - } - if !rollbackSuccess { - // Schedule background cleanup for inconsistent state - go func() { - time.Sleep(5 * time.Second) - // Retry rollback one more time in background - err := s.db.WithTx(func(tx *sql.Tx) error { - _, err := tx.Exec( - `UPDATE tasks SET runner_id = NULL, status = ?, started_at = NULL - WHERE id = ? AND runner_id = ? 
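
One note on the rollback retries above: the comment says exponential backoff, but the sleeps grow linearly (100 ms, then 200 ms). If genuinely exponential delays are intended, doubling per attempt is the usual shape; a hypothetical helper:

package manager

import "time"

// backoff doubles per attempt: 100ms, 200ms, 400ms, ...
func backoff(retry int) time.Duration {
	return time.Duration(1<<uint(retry)) * 100 * time.Millisecond
}
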
AND status = ?`, - types.TaskStatusPending, task.TaskID, selectedRunnerID, types.TaskStatusRunning, - ) - return err - }) - if err == nil { - log.Printf("Background cleanup: Successfully rolled back task %d", task.TaskID) - s.updateJobStatusFromTasks(task.JobID) - s.triggerTaskDistribution() - } - }() - } - } else { - // WebSocket send succeeded, update job status - s.updateJobStatusFromTasks(task.JobID) - } - } -} - -// assignTaskToRunner sends a task to a runner via WebSocket -func (s *Server) assignTaskToRunner(runnerID int64, taskID int64) error { - // Hold read lock during entire operation to prevent connection from being replaced - s.runnerConnsMu.RLock() - conn, exists := s.runnerConns[runnerID] - if !exists { - s.runnerConnsMu.RUnlock() - return fmt.Errorf("runner %d not connected", runnerID) - } - // Keep lock held to prevent connection replacement during operation - defer s.runnerConnsMu.RUnlock() - - // Get task details - var task WSTaskAssignment - var jobName string - var outputFormat sql.NullString - var taskType string - err := s.db.With(func(conn *sql.DB) error { - return conn.QueryRow( - `SELECT t.job_id, t.frame_start, t.frame_end, t.task_type, j.name, j.output_format - FROM tasks t JOIN jobs j ON t.job_id = j.id WHERE t.id = ?`, - taskID, - ).Scan(&task.JobID, &task.FrameStart, &task.FrameEnd, &taskType, &jobName, &outputFormat) - }) - if err != nil { - return err - } - - task.TaskID = taskID - task.JobName = jobName - if outputFormat.Valid { - task.OutputFormat = outputFormat.String - log.Printf("Task %d assigned with output_format: '%s' (from job %d)", taskID, outputFormat.String, task.JobID) - } else { - log.Printf("Task %d assigned with no output_format (job %d)", taskID, task.JobID) - } - task.TaskType = taskType - - // Get input files - var rows *sql.Rows - err = s.db.With(func(conn *sql.DB) error { - var err error - rows, err = conn.Query( - `SELECT file_path FROM job_files WHERE job_id = ? AND file_type = ?`, - task.JobID, types.JobFileTypeInput, - ) - return err - }) - if err == nil { - defer rows.Close() - for rows.Next() { - var filePath string - if err := rows.Scan(&filePath); err == nil { - task.InputFiles = append(task.InputFiles, filePath) - } else { - log.Printf("Failed to scan input file path for task %d: %v", taskID, err) - } - } - } else { - log.Printf("Warning: Failed to query input files for task %d (job %d): %v", taskID, task.JobID, err) - } - - if len(task.InputFiles) == 0 { - errMsg := fmt.Sprintf("No input files found for task %d (job %d). 
Cannot assign task without input files.", taskID, task.JobID) - log.Printf("ERROR: %s", errMsg) - // Don't send the task - it will fail anyway - // Rollback the assignment - s.db.With(func(conn *sql.DB) error { - _, _ = conn.Exec( - `UPDATE tasks SET runner_id = NULL, status = ?, started_at = NULL - WHERE id = ?`, - types.TaskStatusPending, taskID, - ) - return nil - }) - s.logTaskEvent(taskID, nil, types.LogLevelError, errMsg, "") - return errors.New(errMsg) - } - - // Note: Task is already assigned in database by the atomic update in distributeTasksToRunners - // We just need to verify it's still assigned to this runner - var assignedRunnerID sql.NullInt64 - err = s.db.With(func(conn *sql.DB) error { - return conn.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", taskID).Scan(&assignedRunnerID) - }) - if err != nil { - return fmt.Errorf("task not found: %w", err) - } - if !assignedRunnerID.Valid || assignedRunnerID.Int64 != runnerID { - return fmt.Errorf("task %d is not assigned to runner %d", taskID, runnerID) - } - - // Send task via WebSocket with write mutex protection - msg := WSMessage{ - Type: "task_assignment", - Timestamp: time.Now().Unix(), - } - msg.Data, _ = json.Marshal(task) - - // Get write mutex for this connection - s.runnerConnsWriteMuMu.RLock() - writeMu, hasMu := s.runnerConnsWriteMu[runnerID] - s.runnerConnsWriteMuMu.RUnlock() - - if !hasMu || writeMu == nil { - return fmt.Errorf("runner %d write mutex not found", runnerID) - } - - // Connection is still valid (we're holding the read lock) - // Write to connection with mutex protection - writeMu.Lock() - err = conn.WriteJSON(msg) - writeMu.Unlock() - return err -} - -// redistributeRunnerTasks resets tasks assigned to a disconnected/dead runner and redistributes them -func (s *Server) redistributeRunnerTasks(runnerID int64) { - log.Printf("Starting task redistribution for disconnected runner %d", runnerID) - - // Get tasks assigned to this runner that are still running - var taskRows *sql.Rows - err := s.db.With(func(conn *sql.DB) error { - var err error - taskRows, err = conn.Query( - `SELECT id, retry_count, max_retries, job_id FROM tasks - WHERE runner_id = ? AND status = ?`, - runnerID, types.TaskStatusRunning, - ) - return err - }) - if err != nil { - log.Printf("Failed to query tasks for runner %d: %v", runnerID, err) - return - } - defer taskRows.Close() - - var tasksToReset []struct { - ID int64 - RetryCount int - MaxRetries int - JobID int64 - } - - for taskRows.Next() { - var t struct { - ID int64 - RetryCount int - MaxRetries int - JobID int64 - } - if err := taskRows.Scan(&t.ID, &t.RetryCount, &t.MaxRetries, &t.JobID); err != nil { - log.Printf("Failed to scan task for runner %d: %v", runnerID, err) - continue - } - tasksToReset = append(tasksToReset, t) - } - - if len(tasksToReset) == 0 { - log.Printf("No running tasks found for runner %d to redistribute", runnerID) - return - } - - log.Printf("Redistributing %d running tasks from disconnected runner %d", len(tasksToReset), runnerID) - - // Reset or fail tasks - resetCount := 0 - failedCount := 0 - - for _, task := range tasksToReset { - if task.RetryCount >= task.MaxRetries { - // Mark as failed - err = s.db.With(func(conn *sql.DB) error { - _, err := conn.Exec(`UPDATE tasks SET status = ? WHERE id = ? AND runner_id = ?`, types.TaskStatusFailed, task.ID, runnerID) - if err != nil { - return err - } - _, err = conn.Exec(`UPDATE tasks SET error_message = ? WHERE id = ? 
AND runner_id = ?`, "Runner disconnected, max retries exceeded", task.ID, runnerID) - if err != nil { - return err - } - _, err = conn.Exec(`UPDATE tasks SET runner_id = NULL WHERE id = ? AND runner_id = ?`, task.ID, runnerID) - if err != nil { - return err - } - _, err = conn.Exec(`UPDATE tasks SET completed_at = ? WHERE id = ?`, time.Now(), task.ID) - return err - }) - if err != nil { - log.Printf("Failed to mark task %d as failed: %v", task.ID, err) - } else { - failedCount++ - // Log task failure - s.logTaskEvent(task.ID, &runnerID, types.LogLevelError, - fmt.Sprintf("Task failed - runner %d disconnected, max retries (%d) exceeded", runnerID, task.MaxRetries), "") - } - } else { - // Reset to pending so it can be redistributed - err = s.db.With(func(conn *sql.DB) error { - _, err := conn.Exec( - `UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL, - retry_count = retry_count + 1, started_at = NULL WHERE id = ? AND runner_id = ?`, - types.TaskStatusPending, task.ID, runnerID, - ) - return err - }) - if err != nil { - log.Printf("Failed to reset task %d: %v", task.ID, err) - } else { - resetCount++ - // Log task reset for redistribution - s.logTaskEvent(task.ID, &runnerID, types.LogLevelWarn, - fmt.Sprintf("Runner %d disconnected, task reset for redistribution (retry %d/%d)", runnerID, task.RetryCount+1, task.MaxRetries), "") - } - } - } - - log.Printf("Task redistribution complete for runner %d: %d tasks reset, %d tasks failed", runnerID, resetCount, failedCount) - - // Update job statuses for affected jobs - jobIDs := make(map[int64]bool) - for _, task := range tasksToReset { - jobIDs[task.JobID] = true - } - - for jobID := range jobIDs { - // Update job status based on remaining tasks - go s.updateJobStatusFromTasks(jobID) - } - - // Immediately redistribute the reset tasks - if resetCount > 0 { - log.Printf("Triggering task distribution for %d reset tasks from runner %d", resetCount, runnerID) - s.triggerTaskDistribution() - } -} - -// logTaskEvent logs an event to a task's log (manager-side logging) -func (s *Server) logTaskEvent(taskID int64, runnerID *int64, logLevel types.LogLevel, message, stepName string) { - var runnerIDValue interface{} - if runnerID != nil { - runnerIDValue = *runnerID - } - - err := s.db.With(func(conn *sql.DB) error { - _, err := conn.Exec( - `INSERT INTO task_logs (task_id, runner_id, log_level, message, step_name, created_at) - VALUES (?, ?, ?, ?, ?, ?)`, - taskID, runnerIDValue, logLevel, message, stepName, time.Now(), - ) - return err - }) - if err != nil { - log.Printf("Failed to log task event for task %d: %v", taskID, err) - return - } - - // Broadcast to frontend if there are connected clients - s.broadcastLogToFrontend(taskID, WSLogEntry{ - TaskID: taskID, - LogLevel: string(logLevel), - Message: message, - StepName: stepName, - }) -} - -// cleanupOldOfflineRunners periodically deletes runners that have been offline for more than 1 month -func (s *Server) cleanupOldOfflineRunners() { - // Run cleanup every 24 hours - ticker := time.NewTicker(24 * time.Hour) - defer ticker.Stop() - - // Run once immediately on startup - s.cleanupOldOfflineRunnersOnce() - - for range ticker.C { - s.cleanupOldOfflineRunnersOnce() - } -} - -// cleanupOldOfflineRunnersOnce finds and deletes runners that have been offline for more than 1 month -func (s *Server) cleanupOldOfflineRunnersOnce() { - defer func() { - if r := recover(); r != nil { - log.Printf("Panic in cleanupOldOfflineRunners: %v", r) - } - }() - - // Find runners that: - // 1. 
Are offline - // 2. Haven't had a heartbeat in over 1 month - // 3. Are not currently connected via WebSocket - var rows *sql.Rows - err := s.db.With(func(conn *sql.DB) error { - var err error - rows, err = conn.Query( - `SELECT id, name FROM runners - WHERE status = ? - AND last_heartbeat < datetime('now', '-1 month')`, - types.RunnerStatusOffline, - ) - return err - }) - if err != nil { - log.Printf("Failed to query old offline runners: %v", err) - return - } - defer rows.Close() - - type runnerInfo struct { - ID int64 - Name string - } - var runnersToDelete []runnerInfo - - s.runnerConnsMu.RLock() - for rows.Next() { - var info runnerInfo - if err := rows.Scan(&info.ID, &info.Name); err == nil { - // Double-check runner is not connected via WebSocket - if _, connected := s.runnerConns[info.ID]; !connected { - runnersToDelete = append(runnersToDelete, info) - } - } - } - s.runnerConnsMu.RUnlock() - rows.Close() - - if len(runnersToDelete) == 0 { - return - } - - log.Printf("Cleaning up %d old offline runners (offline for more than 1 month)", len(runnersToDelete)) - - // Delete each runner - for _, runner := range runnersToDelete { - // First, check if there are any tasks still assigned to this runner - // If so, reset them to pending before deleting the runner - var assignedTaskCount int - err := s.db.With(func(conn *sql.DB) error { - return conn.QueryRow( - `SELECT COUNT(*) FROM tasks WHERE runner_id = ? AND status IN (?, ?)`, - runner.ID, types.TaskStatusRunning, types.TaskStatusPending, - ).Scan(&assignedTaskCount) - }) - if err != nil { - log.Printf("Failed to check assigned tasks for runner %d: %v", runner.ID, err) - continue - } - - if assignedTaskCount > 0 { - // Reset any tasks assigned to this runner - log.Printf("Resetting %d tasks assigned to runner %d before deletion", assignedTaskCount, runner.ID) - err = s.db.With(func(conn *sql.DB) error { - _, err := conn.Exec( - `UPDATE tasks SET runner_id = NULL, status = ? WHERE runner_id = ? 
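
The cutoff above is computed inside SQLite via datetime('now', '-1 month'). If the retention window ever needs to be configurable, the same comparison can be parameterized from Go; a hypothetical variant (the 'offline' literal stands in for types.RunnerStatusOffline):

package manager

import (
	"database/sql"
	"time"
)

// staleRunners computes the cutoff in Go rather than in SQL so the
// retention window is a plain time.Duration.
func staleRunners(conn *sql.DB, window time.Duration) (*sql.Rows, error) {
	cutoff := time.Now().Add(-window)
	return conn.Query(
		`SELECT id, name FROM runners WHERE status = 'offline' AND last_heartbeat < ?`,
		cutoff)
}
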
AND status IN (?, ?)`, - types.TaskStatusPending, runner.ID, types.TaskStatusRunning, types.TaskStatusPending, - ) - return err - }) - if err != nil { - log.Printf("Failed to reset tasks for runner %d: %v", runner.ID, err) - continue - } - } - - // Delete the runner - err = s.db.With(func(conn *sql.DB) error { - _, err := conn.Exec("DELETE FROM runners WHERE id = ?", runner.ID) - return err - }) - if err != nil { - log.Printf("Failed to delete runner %d (%s): %v", runner.ID, runner.Name, err) - continue - } - - log.Printf("Deleted old offline runner: %d (%s)", runner.ID, runner.Name) - } - - // Trigger task distribution if any tasks were reset - s.triggerTaskDistribution() -} diff --git a/internal/auth/jobtoken.go b/internal/auth/jobtoken.go new file mode 100644 index 0000000..2128ba4 --- /dev/null +++ b/internal/auth/jobtoken.go @@ -0,0 +1,115 @@ +package auth + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "time" +) + +// JobTokenDuration is the validity period for job tokens +const JobTokenDuration = 1 * time.Hour + +// JobTokenClaims represents the claims in a job token +type JobTokenClaims struct { + JobID int64 `json:"job_id"` + RunnerID int64 `json:"runner_id"` + TaskID int64 `json:"task_id"` + Exp int64 `json:"exp"` // Unix timestamp +} + +// jobTokenSecret is the secret used to sign job tokens +// Generated once at startup and kept in memory +var jobTokenSecret []byte + +func init() { + // Generate a random secret for signing job tokens + // This means tokens are invalidated on server restart, which is acceptable + // for short-lived job tokens + jobTokenSecret = make([]byte, 32) + if _, err := rand.Read(jobTokenSecret); err != nil { + panic(fmt.Sprintf("failed to generate job token secret: %v", err)) + } +} + +// GenerateJobToken creates a new job token for a specific job/runner/task combination +func GenerateJobToken(jobID, runnerID, taskID int64) (string, error) { + claims := JobTokenClaims{ + JobID: jobID, + RunnerID: runnerID, + TaskID: taskID, + Exp: time.Now().Add(JobTokenDuration).Unix(), + } + + // Encode claims to JSON + claimsJSON, err := json.Marshal(claims) + if err != nil { + return "", fmt.Errorf("failed to marshal claims: %w", err) + } + + // Create HMAC signature + h := hmac.New(sha256.New, jobTokenSecret) + h.Write(claimsJSON) + signature := h.Sum(nil) + + // Combine claims and signature: base64(claims).base64(signature) + token := base64.RawURLEncoding.EncodeToString(claimsJSON) + "." + + base64.RawURLEncoding.EncodeToString(signature) + + return token, nil +} + +// ValidateJobToken validates a job token and returns the claims if valid +func ValidateJobToken(token string) (*JobTokenClaims, error) { + // Split token into claims and signature + var claimsB64, sigB64 string + dotIdx := -1 + for i := len(token) - 1; i >= 0; i-- { + if token[i] == '.' 
{ + dotIdx = i + break + } + } + if dotIdx == -1 { + return nil, fmt.Errorf("invalid token format") + } + claimsB64 = token[:dotIdx] + sigB64 = token[dotIdx+1:] + + // Decode claims + claimsJSON, err := base64.RawURLEncoding.DecodeString(claimsB64) + if err != nil { + return nil, fmt.Errorf("invalid token encoding: %w", err) + } + + // Decode signature + signature, err := base64.RawURLEncoding.DecodeString(sigB64) + if err != nil { + return nil, fmt.Errorf("invalid signature encoding: %w", err) + } + + // Verify signature + h := hmac.New(sha256.New, jobTokenSecret) + h.Write(claimsJSON) + expectedSig := h.Sum(nil) + if !hmac.Equal(signature, expectedSig) { + return nil, fmt.Errorf("invalid signature") + } + + // Parse claims + var claims JobTokenClaims + if err := json.Unmarshal(claimsJSON, &claims); err != nil { + return nil, fmt.Errorf("invalid claims: %w", err) + } + + // Check expiration + if time.Now().Unix() > claims.Exp { + return nil, fmt.Errorf("token expired") + } + + return &claims, nil +} + diff --git a/internal/database/migrations/000001_initial_schema.down.sql b/internal/database/migrations/000001_initial_schema.down.sql new file mode 100644 index 0000000..cdabb0f --- /dev/null +++ b/internal/database/migrations/000001_initial_schema.down.sql @@ -0,0 +1,36 @@ +-- Drop indexes +DROP INDEX IF EXISTS idx_sessions_expires_at; +DROP INDEX IF EXISTS idx_sessions_user_id; +DROP INDEX IF EXISTS idx_sessions_session_id; +DROP INDEX IF EXISTS idx_runners_last_heartbeat; +DROP INDEX IF EXISTS idx_task_steps_task_id; +DROP INDEX IF EXISTS idx_task_logs_runner_id; +DROP INDEX IF EXISTS idx_task_logs_task_id_id; +DROP INDEX IF EXISTS idx_task_logs_task_id_created_at; +DROP INDEX IF EXISTS idx_runners_api_key_id; +DROP INDEX IF EXISTS idx_runner_api_keys_created_by; +DROP INDEX IF EXISTS idx_runner_api_keys_active; +DROP INDEX IF EXISTS idx_runner_api_keys_prefix; +DROP INDEX IF EXISTS idx_job_files_job_id; +DROP INDEX IF EXISTS idx_tasks_started_at; +DROP INDEX IF EXISTS idx_tasks_job_status; +DROP INDEX IF EXISTS idx_tasks_status; +DROP INDEX IF EXISTS idx_tasks_runner_id; +DROP INDEX IF EXISTS idx_tasks_job_id; +DROP INDEX IF EXISTS idx_jobs_user_status_created; +DROP INDEX IF EXISTS idx_jobs_status; +DROP INDEX IF EXISTS idx_jobs_user_id; + +-- Drop tables (order matters due to foreign keys) +DROP TABLE IF EXISTS sessions; +DROP TABLE IF EXISTS settings; +DROP TABLE IF EXISTS task_steps; +DROP TABLE IF EXISTS task_logs; +DROP TABLE IF EXISTS manager_secrets; +DROP TABLE IF EXISTS job_files; +DROP TABLE IF EXISTS tasks; +DROP TABLE IF EXISTS runners; +DROP TABLE IF EXISTS jobs; +DROP TABLE IF EXISTS runner_api_keys; +DROP TABLE IF EXISTS users; + diff --git a/internal/database/migrations/000001_initial_schema.up.sql b/internal/database/migrations/000001_initial_schema.up.sql new file mode 100644 index 0000000..62ba159 --- /dev/null +++ b/internal/database/migrations/000001_initial_schema.up.sql @@ -0,0 +1,184 @@ +-- Enable foreign keys for SQLite +PRAGMA foreign_keys = ON; + +-- Users table +CREATE TABLE users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + email TEXT UNIQUE NOT NULL, + name TEXT NOT NULL, + oauth_provider TEXT NOT NULL, + oauth_id TEXT NOT NULL, + password_hash TEXT, + is_admin INTEGER NOT NULL DEFAULT 0, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE(oauth_provider, oauth_id) +); + +-- Runner API keys table +CREATE TABLE runner_api_keys ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + key_prefix TEXT NOT NULL, + key_hash TEXT NOT NULL, + name TEXT NOT NULL, 
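
For reference, a round-trip through the new internal/auth/jobtoken.go API looks like this; the IDs are made up, and validation fails once a token is older than JobTokenDuration (1 hour) or after any manager restart, since the HMAC secret lives only in memory:

package main

import (
	"fmt"

	"jiggablend/internal/auth"
)

func main() {
	// Mint a token binding job 42, runner 7, and task 1001 together.
	tok, err := auth.GenerateJobToken(42, 7, 1001)
	if err != nil {
		panic(err)
	}
	// Validate it and read back the claims.
	claims, err := auth.ValidateJobToken(tok)
	if err != nil {
		panic(err) // expired, tampered, or minted before a restart
	}
	fmt.Printf("job=%d runner=%d task=%d\n", claims.JobID, claims.RunnerID, claims.TaskID)
}
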
+ description TEXT, + scope TEXT NOT NULL DEFAULT 'user', + is_active INTEGER NOT NULL DEFAULT 1, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_by INTEGER, + FOREIGN KEY (created_by) REFERENCES users(id), + UNIQUE(key_prefix) +); + +-- Jobs table +CREATE TABLE jobs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + job_type TEXT NOT NULL DEFAULT 'render', + name TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'pending', + progress REAL NOT NULL DEFAULT 0.0, + frame_start INTEGER, + frame_end INTEGER, + output_format TEXT, + blend_metadata TEXT, + retry_count INTEGER NOT NULL DEFAULT 0, + max_retries INTEGER NOT NULL DEFAULT 3, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + started_at TIMESTAMP, + completed_at TIMESTAMP, + error_message TEXT, + assigned_runner_id INTEGER, + FOREIGN KEY (user_id) REFERENCES users(id) +); + +-- Runners table +CREATE TABLE runners ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + hostname TEXT NOT NULL, + ip_address TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'offline', + last_heartbeat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + capabilities TEXT, + api_key_id INTEGER, + api_key_scope TEXT NOT NULL DEFAULT 'user', + priority INTEGER NOT NULL DEFAULT 100, + fingerprint TEXT, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (api_key_id) REFERENCES runner_api_keys(id) +); + +-- Tasks table +CREATE TABLE tasks ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + job_id INTEGER NOT NULL, + runner_id INTEGER, + frame INTEGER NOT NULL, + status TEXT NOT NULL DEFAULT 'pending', + output_path TEXT, + task_type TEXT NOT NULL DEFAULT 'render', + current_step TEXT, + retry_count INTEGER NOT NULL DEFAULT 0, + max_retries INTEGER NOT NULL DEFAULT 3, + runner_failure_count INTEGER NOT NULL DEFAULT 0, + timeout_seconds INTEGER, + condition TEXT, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + started_at TIMESTAMP, + completed_at TIMESTAMP, + error_message TEXT, + FOREIGN KEY (job_id) REFERENCES jobs(id), + FOREIGN KEY (runner_id) REFERENCES runners(id) +); + +-- Job files table +CREATE TABLE job_files ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + job_id INTEGER NOT NULL, + file_type TEXT NOT NULL, + file_path TEXT NOT NULL, + file_name TEXT NOT NULL, + file_size INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (job_id) REFERENCES jobs(id) +); + +-- Manager secrets table +CREATE TABLE manager_secrets ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + secret TEXT UNIQUE NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +-- Task logs table +CREATE TABLE task_logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + task_id INTEGER NOT NULL, + runner_id INTEGER, + log_level TEXT NOT NULL, + message TEXT NOT NULL, + step_name TEXT, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (task_id) REFERENCES tasks(id), + FOREIGN KEY (runner_id) REFERENCES runners(id) +); + +-- Task steps table +CREATE TABLE task_steps ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + task_id INTEGER NOT NULL, + step_name TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'pending', + started_at TIMESTAMP, + completed_at TIMESTAMP, + duration_ms INTEGER, + error_message TEXT, + FOREIGN KEY (task_id) REFERENCES tasks(id) +); + +-- Settings table +CREATE TABLE settings ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +-- Sessions table +CREATE TABLE sessions ( + id INTEGER PRIMARY 
KEY AUTOINCREMENT, + session_id TEXT UNIQUE NOT NULL, + user_id INTEGER NOT NULL, + email TEXT NOT NULL, + name TEXT NOT NULL, + is_admin INTEGER NOT NULL DEFAULT 0, + expires_at TIMESTAMP NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users(id) +); + +-- Indexes +CREATE INDEX idx_jobs_user_id ON jobs(user_id); +CREATE INDEX idx_jobs_status ON jobs(status); +CREATE INDEX idx_jobs_user_status_created ON jobs(user_id, status, created_at DESC); +CREATE INDEX idx_tasks_job_id ON tasks(job_id); +CREATE INDEX idx_tasks_runner_id ON tasks(runner_id); +CREATE INDEX idx_tasks_status ON tasks(status); +CREATE INDEX idx_tasks_job_status ON tasks(job_id, status); +CREATE INDEX idx_tasks_started_at ON tasks(started_at); +CREATE INDEX idx_job_files_job_id ON job_files(job_id); +CREATE INDEX idx_runner_api_keys_prefix ON runner_api_keys(key_prefix); +CREATE INDEX idx_runner_api_keys_active ON runner_api_keys(is_active); +CREATE INDEX idx_runner_api_keys_created_by ON runner_api_keys(created_by); +CREATE INDEX idx_runners_api_key_id ON runners(api_key_id); +CREATE INDEX idx_task_logs_task_id_created_at ON task_logs(task_id, created_at); +CREATE INDEX idx_task_logs_task_id_id ON task_logs(task_id, id DESC); +CREATE INDEX idx_task_logs_runner_id ON task_logs(runner_id); +CREATE INDEX idx_task_steps_task_id ON task_steps(task_id); +CREATE INDEX idx_runners_last_heartbeat ON runners(last_heartbeat); +CREATE INDEX idx_sessions_session_id ON sessions(session_id); +CREATE INDEX idx_sessions_user_id ON sessions(user_id); +CREATE INDEX idx_sessions_expires_at ON sessions(expires_at); + +-- Initialize registration_enabled setting +INSERT INTO settings (key, value, updated_at) VALUES ('registration_enabled', 'true', CURRENT_TIMESTAMP); + diff --git a/internal/database/schema.go b/internal/database/schema.go index 5c9c3a6..3598e16 100644 --- a/internal/database/schema.go +++ b/internal/database/schema.go @@ -2,26 +2,44 @@ package database import ( "database/sql" + "embed" "fmt" + "io/fs" "log" - "sync" + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database/sqlite3" + "github.com/golang-migrate/migrate/v4/source/iofs" _ "github.com/mattn/go-sqlite3" ) -// DB wraps the database connection with mutex protection +//go:embed migrations/*.sql +var migrationsFS embed.FS + +// DB wraps the database connection +// Note: No mutex needed - we only have one connection per process and SQLite with WAL mode +// handles concurrent access safely type DB struct { - db *sql.DB - mu sync.Mutex + db *sql.DB } // NewDB creates a new database connection func NewDB(dbPath string) (*DB, error) { - db, err := sql.Open("sqlite3", dbPath) + // Use WAL mode for better concurrency (allows readers and writers simultaneously) + // Add timeout and busy handler for better concurrent access + db, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_busy_timeout=5000") if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } + // Configure connection pool for better concurrency + // SQLite with WAL mode supports multiple concurrent readers and one writer + // Increasing pool size allows multiple HTTP requests to query the database simultaneously + // This prevents blocking when multiple requests come in (e.g., on page refresh) + db.SetMaxOpenConns(10) // Allow up to 10 concurrent connections + db.SetMaxIdleConns(5) // Keep 5 idle connections ready + db.SetConnMaxLifetime(0) // Connections don't expire + if err := db.Ping(); err != nil { 
return nil, fmt.Errorf("failed to ping database: %w", err) } @@ -31,30 +49,37 @@ func NewDB(dbPath string) (*DB, error) { return nil, fmt.Errorf("failed to enable foreign keys: %w", err) } + // Enable WAL mode explicitly (in case the connection string didn't work) + if _, err := db.Exec("PRAGMA journal_mode = WAL"); err != nil { + log.Printf("Warning: Failed to enable WAL mode: %v", err) + } + database := &DB{db: db} if err := database.migrate(); err != nil { return nil, fmt.Errorf("failed to migrate database: %w", err) } + // Verify connection is still open after migration + if err := db.Ping(); err != nil { + return nil, fmt.Errorf("database connection closed after migration: %w", err) + } + return database, nil } -// With executes a function with mutex-protected access to the database +// With executes a function with access to the database // The function receives the underlying *sql.DB connection +// No mutex needed - single connection + WAL mode handles concurrency func (db *DB) With(fn func(*sql.DB) error) error { - db.mu.Lock() - defer db.mu.Unlock() return fn(db.db) } -// WithTx executes a function within a transaction with mutex protection +// WithTx executes a function within a transaction // The function receives a *sql.Tx transaction // If the function returns an error, the transaction is rolled back // If the function returns nil, the transaction is committed +// No mutex needed - single connection + WAL mode handles concurrency func (db *DB) WithTx(fn func(*sql.Tx) error) error { - db.mu.Lock() - defer db.mu.Unlock() - tx, err := db.db.Begin() if err != nil { return fmt.Errorf("failed to begin transaction: %w", err) @@ -74,234 +99,61 @@ func (db *DB) WithTx(fn func(*sql.Tx) error) error { return nil } -// migrate runs database migrations +// migrate runs database migrations using golang-migrate func (db *DB) migrate() error { - // SQLite uses INTEGER PRIMARY KEY AUTOINCREMENT instead of sequences - schema := ` - CREATE TABLE IF NOT EXISTS users ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - email TEXT UNIQUE NOT NULL, - name TEXT NOT NULL, - oauth_provider TEXT NOT NULL, - oauth_id TEXT NOT NULL, - password_hash TEXT, - is_admin INTEGER NOT NULL DEFAULT 0, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - UNIQUE(oauth_provider, oauth_id) - ); - - CREATE TABLE IF NOT EXISTS runner_api_keys ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - key_prefix TEXT NOT NULL, - key_hash TEXT NOT NULL, - name TEXT NOT NULL, - description TEXT, - scope TEXT NOT NULL DEFAULT 'user', - is_active INTEGER NOT NULL DEFAULT 1, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - created_by INTEGER, - FOREIGN KEY (created_by) REFERENCES users(id), - UNIQUE(key_prefix) - ); - - CREATE TABLE IF NOT EXISTS jobs ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id INTEGER NOT NULL, - job_type TEXT NOT NULL DEFAULT 'render', - name TEXT NOT NULL, - status TEXT NOT NULL DEFAULT 'pending', - progress REAL NOT NULL DEFAULT 0.0, - frame_start INTEGER, - frame_end INTEGER, - output_format TEXT, - allow_parallel_runners INTEGER NOT NULL DEFAULT 1, - timeout_seconds INTEGER DEFAULT 86400, - blend_metadata TEXT, - retry_count INTEGER NOT NULL DEFAULT 0, - max_retries INTEGER NOT NULL DEFAULT 3, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - started_at TIMESTAMP, - completed_at TIMESTAMP, - error_message TEXT, - FOREIGN KEY (user_id) REFERENCES users(id) - ); - - CREATE TABLE IF NOT EXISTS runners ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL, - hostname TEXT NOT NULL, - 
ip_address TEXT NOT NULL, - status TEXT NOT NULL DEFAULT 'offline', - last_heartbeat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - capabilities TEXT, - api_key_id INTEGER, - api_key_scope TEXT NOT NULL DEFAULT 'user', - priority INTEGER NOT NULL DEFAULT 100, - fingerprint TEXT, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (api_key_id) REFERENCES runner_api_keys(id) - ); - - CREATE TABLE IF NOT EXISTS tasks ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - job_id INTEGER NOT NULL, - runner_id INTEGER, - frame_start INTEGER NOT NULL, - frame_end INTEGER NOT NULL, - status TEXT NOT NULL DEFAULT 'pending', - output_path TEXT, - task_type TEXT NOT NULL DEFAULT 'render', - current_step TEXT, - retry_count INTEGER NOT NULL DEFAULT 0, - max_retries INTEGER NOT NULL DEFAULT 3, - timeout_seconds INTEGER, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - started_at TIMESTAMP, - completed_at TIMESTAMP, - error_message TEXT, - FOREIGN KEY (job_id) REFERENCES jobs(id), - FOREIGN KEY (runner_id) REFERENCES runners(id) - ); - - CREATE TABLE IF NOT EXISTS job_files ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - job_id INTEGER NOT NULL, - file_type TEXT NOT NULL, - file_path TEXT NOT NULL, - file_name TEXT NOT NULL, - file_size INTEGER NOT NULL, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (job_id) REFERENCES jobs(id) - ); - - CREATE TABLE IF NOT EXISTS manager_secrets ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - secret TEXT UNIQUE NOT NULL, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP - ); - - CREATE TABLE IF NOT EXISTS task_logs ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - task_id INTEGER NOT NULL, - runner_id INTEGER, - log_level TEXT NOT NULL, - message TEXT NOT NULL, - step_name TEXT, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (task_id) REFERENCES tasks(id), - FOREIGN KEY (runner_id) REFERENCES runners(id) - ); - - CREATE TABLE IF NOT EXISTS task_steps ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - task_id INTEGER NOT NULL, - step_name TEXT NOT NULL, - status TEXT NOT NULL DEFAULT 'pending', - started_at TIMESTAMP, - completed_at TIMESTAMP, - duration_ms INTEGER, - error_message TEXT, - FOREIGN KEY (task_id) REFERENCES tasks(id) - ); - - CREATE INDEX IF NOT EXISTS idx_jobs_user_id ON jobs(user_id); - CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status); - CREATE INDEX IF NOT EXISTS idx_jobs_user_status_created ON jobs(user_id, status, created_at DESC); - CREATE INDEX IF NOT EXISTS idx_tasks_job_id ON tasks(job_id); - CREATE INDEX IF NOT EXISTS idx_tasks_runner_id ON tasks(runner_id); - CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status); - CREATE INDEX IF NOT EXISTS idx_tasks_job_status ON tasks(job_id, status); - CREATE INDEX IF NOT EXISTS idx_tasks_started_at ON tasks(started_at); - CREATE INDEX IF NOT EXISTS idx_job_files_job_id ON job_files(job_id); - CREATE INDEX IF NOT EXISTS idx_runner_api_keys_prefix ON runner_api_keys(key_prefix); - CREATE INDEX IF NOT EXISTS idx_runner_api_keys_active ON runner_api_keys(is_active); - CREATE INDEX IF NOT EXISTS idx_runner_api_keys_created_by ON runner_api_keys(created_by); - CREATE INDEX IF NOT EXISTS idx_runners_api_key_id ON runners(api_key_id); - CREATE INDEX IF NOT EXISTS idx_task_logs_task_id_created_at ON task_logs(task_id, created_at); - CREATE INDEX IF NOT EXISTS idx_task_logs_task_id_id ON task_logs(task_id, id DESC); - CREATE INDEX IF NOT EXISTS idx_task_logs_runner_id ON task_logs(runner_id); - CREATE INDEX IF NOT EXISTS 
idx_task_steps_task_id ON task_steps(task_id); - CREATE INDEX IF NOT EXISTS idx_runners_last_heartbeat ON runners(last_heartbeat); - - CREATE TABLE IF NOT EXISTS settings ( - key TEXT PRIMARY KEY, - value TEXT NOT NULL, - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP - ); - - CREATE TABLE IF NOT EXISTS sessions ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - session_id TEXT UNIQUE NOT NULL, - user_id INTEGER NOT NULL, - email TEXT NOT NULL, - name TEXT NOT NULL, - is_admin INTEGER NOT NULL DEFAULT 0, - expires_at TIMESTAMP NOT NULL, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (user_id) REFERENCES users(id) - ); - - CREATE INDEX IF NOT EXISTS idx_sessions_session_id ON sessions(session_id); - CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id); - CREATE INDEX IF NOT EXISTS idx_sessions_expires_at ON sessions(expires_at); - ` - - if err := db.With(func(conn *sql.DB) error { - _, err := conn.Exec(schema) - return err - }); err != nil { - return fmt.Errorf("failed to create schema: %w", err) + // Create SQLite driver instance + // Note: We use db.db directly since we're in the same package and this is called during initialization + driver, err := sqlite3.WithInstance(db.db, &sqlite3.Config{}) + if err != nil { + return fmt.Errorf("failed to create sqlite3 driver: %w", err) } - // Database migrations for schema updates - // NOTE: Migrations are currently disabled since the database is cleared by 'make cleanup-manager' - // before running. All schema changes have been rolled into the main schema above. - // When ready to implement proper migrations for production, uncomment and populate this array. - // TODO: Implement proper database migration system for production use - migrations := []string{ - // Future migrations will go here when we implement proper migration handling + // Create embedded filesystem source + migrationFS, err := fs.Sub(migrationsFS, "migrations") + if err != nil { + return fmt.Errorf("failed to create migration filesystem: %w", err) } - for _, migration := range migrations { - if err := db.With(func(conn *sql.DB) error { - _, err := conn.Exec(migration) - return err - }); err != nil { - // Log but don't fail - column might already exist or table might not exist yet - // This is fine for migrations that run after schema creation - log.Printf("Migration warning: %v", err) + sourceDriver, err := iofs.New(migrationFS, ".") + if err != nil { + return fmt.Errorf("failed to create iofs source driver: %w", err) + } + + // Create migrate instance + m, err := migrate.NewWithInstance("iofs", sourceDriver, "sqlite3", driver) + if err != nil { + return fmt.Errorf("failed to create migrate instance: %w", err) + } + + // Run migrations + if err := m.Up(); err != nil { + // If the error is "no change", that's fine - database is already up to date + if err == migrate.ErrNoChange { + log.Printf("Database is already up to date") + // Don't close migrate instance - it may close the database connection + // The migrate instance will be garbage collected + return nil } + // Don't close migrate instance on error either - it may close the DB + return fmt.Errorf("failed to run migrations: %w", err) } - // Initialize registration_enabled setting (default: true) if it doesn't exist - var settingCount int - err := db.With(func(conn *sql.DB) error { - return conn.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount) - }) - if err == nil && settingCount == 0 { - err = db.With(func(conn *sql.DB) error { - _, 
err := conn.Exec("INSERT INTO settings (key, value) VALUES (?, ?)", "registration_enabled", "true") - return err - }) - if err != nil { - // Log but don't fail - setting might have been created by another process - log.Printf("Note: Could not initialize registration_enabled setting: %v", err) - } - } + // Don't close the migrate instance - with sqlite3.WithInstance, closing it + // may close the underlying database connection. The migrate instance will + // be garbage collected when it goes out of scope. + // If we need to close it later, we can store it in the DB struct and close + // it when DB.Close() is called, but for now we'll let it be GC'd. + log.Printf("Database migrations completed successfully") return nil } // Ping checks the database connection func (db *DB) Ping() error { - db.mu.Lock() - defer db.mu.Unlock() return db.db.Ping() } // Close closes the database connection func (db *DB) Close() error { - db.mu.Lock() - defer db.mu.Unlock() return db.db.Close() } diff --git a/internal/api/admin.go b/internal/manager/admin.go similarity index 83% rename from internal/api/admin.go rename to internal/manager/admin.go index 5b36bd6..0211c55 100644 --- a/internal/api/admin.go +++ b/internal/manager/admin.go @@ -11,7 +11,7 @@ import ( ) // handleGenerateRunnerAPIKey generates a new runner API key -func (s *Server) handleGenerateRunnerAPIKey(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleGenerateRunnerAPIKey(w http.ResponseWriter, r *http.Request) { userID, err := getUserID(r) if err != nil { s.respondError(w, http.StatusUnauthorized, err.Error()) @@ -62,7 +62,7 @@ func (s *Server) handleGenerateRunnerAPIKey(w http.ResponseWriter, r *http.Reque } // handleListRunnerAPIKeys lists all runner API keys -func (s *Server) handleListRunnerAPIKeys(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleListRunnerAPIKeys(w http.ResponseWriter, r *http.Request) { keys, err := s.secrets.ListRunnerAPIKeys() if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to list API keys: %v", err)) @@ -90,7 +90,7 @@ func (s *Server) handleListRunnerAPIKeys(w http.ResponseWriter, r *http.Request) } // handleRevokeRunnerAPIKey revokes a runner API key -func (s *Server) handleRevokeRunnerAPIKey(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleRevokeRunnerAPIKey(w http.ResponseWriter, r *http.Request) { keyID, err := parseID(r, "id") if err != nil { s.respondError(w, http.StatusBadRequest, err.Error()) @@ -106,7 +106,7 @@ func (s *Server) handleRevokeRunnerAPIKey(w http.ResponseWriter, r *http.Request } // handleDeleteRunnerAPIKey deletes a runner API key -func (s *Server) handleDeleteRunnerAPIKey(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleDeleteRunnerAPIKey(w http.ResponseWriter, r *http.Request) { keyID, err := parseID(r, "id") if err != nil { s.respondError(w, http.StatusBadRequest, err.Error()) @@ -122,7 +122,7 @@ func (s *Server) handleDeleteRunnerAPIKey(w http.ResponseWriter, r *http.Request } // handleVerifyRunner manually verifies a runner -func (s *Server) handleVerifyRunner(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleVerifyRunner(w http.ResponseWriter, r *http.Request) { runnerID, err := parseID(r, "id") if err != nil { s.respondError(w, http.StatusBadRequest, err.Error()) @@ -153,7 +153,7 @@ func (s *Server) handleVerifyRunner(w http.ResponseWriter, r *http.Request) { } // handleDeleteRunner removes a runner -func (s *Server) handleDeleteRunner(w http.ResponseWriter, r 
*http.Request) { +func (s *Manager) handleDeleteRunner(w http.ResponseWriter, r *http.Request) { runnerID, err := parseID(r, "id") if err != nil { s.respondError(w, http.StatusBadRequest, err.Error()) @@ -184,15 +184,15 @@ func (s *Server) handleDeleteRunner(w http.ResponseWriter, r *http.Request) { } // handleListRunnersAdmin lists all runners with admin details -func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request) { var rows *sql.Rows err := s.db.With(func(conn *sql.DB) error { var err error rows, err = conn.Query( - `SELECT id, name, hostname, status, last_heartbeat, capabilities, + `SELECT id, name, hostname, status, last_heartbeat, capabilities, api_key_id, api_key_scope, priority, created_at FROM runners ORDER BY created_at DESC`, - ) + ) return err }) if err != nil { @@ -201,15 +201,6 @@ func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request) } defer rows.Close() - // Get the set of currently connected runners via WebSocket - // This is the source of truth for online status - s.runnerConnsMu.RLock() - connectedRunners := make(map[int64]bool) - for runnerID := range s.runnerConns { - connectedRunners[runnerID] = true - } - s.runnerConnsMu.RUnlock() - runners := []map[string]interface{}{} for rows.Next() { var runner types.Runner @@ -226,21 +217,13 @@ func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request) return } - // Override status based on actual WebSocket connection state - // The WebSocket connection is the source of truth for runner status - actualStatus := runner.Status - if connectedRunners[runner.ID] { - actualStatus = types.RunnerStatusOnline - } else if runner.Status == types.RunnerStatusOnline { - // Database says online but not connected via WebSocket - mark as offline - actualStatus = types.RunnerStatusOffline - } - + // In polling model, database status is the source of truth + // Runners update their status when they poll for jobs runners = append(runners, map[string]interface{}{ "id": runner.ID, "name": runner.Name, "hostname": runner.Hostname, - "status": actualStatus, + "status": runner.Status, "last_heartbeat": runner.LastHeartbeat, "capabilities": runner.Capabilities, "api_key_id": apiKeyID.Int64, @@ -254,7 +237,7 @@ func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request) } // handleListUsers lists all users -func (s *Server) handleListUsers(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleListUsers(w http.ResponseWriter, r *http.Request) { // Get first user ID to mark it in the response firstUserID, err := s.auth.GetFirstUserID() if err != nil { @@ -266,9 +249,9 @@ func (s *Server) handleListUsers(w http.ResponseWriter, r *http.Request) { err = s.db.With(func(conn *sql.DB) error { var err error rows, err = conn.Query( - `SELECT id, email, name, oauth_provider, is_admin, created_at + `SELECT id, email, name, oauth_provider, is_admin, created_at FROM users ORDER BY created_at DESC`, - ) + ) return err }) if err != nil { @@ -315,7 +298,7 @@ func (s *Server) handleListUsers(w http.ResponseWriter, r *http.Request) { } // handleGetUserJobs gets all jobs for a specific user -func (s *Server) handleGetUserJobs(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleGetUserJobs(w http.ResponseWriter, r *http.Request) { userID, err := parseID(r, "id") if err != nil { s.respondError(w, http.StatusBadRequest, err.Error()) @@ -336,11 +319,11 @@ func (s *Server) 
handleGetUserJobs(w http.ResponseWriter, r *http.Request) { err = s.db.With(func(conn *sql.DB) error { var err error rows, err = conn.Query( - `SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format, - allow_parallel_runners, timeout_seconds, blend_metadata, created_at, started_at, completed_at, error_message + `SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format, + blend_metadata, created_at, started_at, completed_at, error_message FROM jobs WHERE user_id = ? ORDER BY created_at DESC`, - userID, - ) + userID, + ) return err }) if err != nil { @@ -358,11 +341,9 @@ func (s *Server) handleGetUserJobs(w http.ResponseWriter, r *http.Request) { var errorMessage sql.NullString var frameStart, frameEnd sql.NullInt64 var outputFormat sql.NullString - var allowParallelRunners sql.NullBool - err := rows.Scan( &job.ID, &job.UserID, &jobType, &job.Name, &job.Status, &job.Progress, - &frameStart, &frameEnd, &outputFormat, &allowParallelRunners, &job.TimeoutSeconds, + &frameStart, &frameEnd, &outputFormat, &blendMetadataJSON, &job.CreatedAt, &startedAt, &completedAt, &errorMessage, ) if err != nil { @@ -382,9 +363,6 @@ func (s *Server) handleGetUserJobs(w http.ResponseWriter, r *http.Request) { if outputFormat.Valid { job.OutputFormat = &outputFormat.String } - if allowParallelRunners.Valid { - job.AllowParallelRunners = &allowParallelRunners.Bool - } if startedAt.Valid { job.StartedAt = &startedAt.Time } @@ -408,7 +386,7 @@ func (s *Server) handleGetUserJobs(w http.ResponseWriter, r *http.Request) { } // handleGetRegistrationEnabled gets the registration enabled setting -func (s *Server) handleGetRegistrationEnabled(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleGetRegistrationEnabled(w http.ResponseWriter, r *http.Request) { enabled, err := s.auth.IsRegistrationEnabled() if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to get registration setting: %v", err)) @@ -418,7 +396,7 @@ func (s *Server) handleGetRegistrationEnabled(w http.ResponseWriter, r *http.Req } // handleSetRegistrationEnabled sets the registration enabled setting -func (s *Server) handleSetRegistrationEnabled(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleSetRegistrationEnabled(w http.ResponseWriter, r *http.Request) { var req struct { Enabled bool `json:"enabled"` } @@ -436,7 +414,7 @@ func (s *Server) handleSetRegistrationEnabled(w http.ResponseWriter, r *http.Req } // handleSetUserAdminStatus sets a user's admin status (admin only) -func (s *Server) handleSetUserAdminStatus(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleSetUserAdminStatus(w http.ResponseWriter, r *http.Request) { targetUserID, err := parseID(r, "id") if err != nil { s.respondError(w, http.StatusBadRequest, err.Error()) diff --git a/internal/manager/blender.go b/internal/manager/blender.go new file mode 100644 index 0000000..6fad4e3 --- /dev/null +++ b/internal/manager/blender.go @@ -0,0 +1,831 @@ +package api + +import ( + "archive/tar" + "compress/bzip2" + "compress/gzip" + "fmt" + "io" + "log" + "net/http" + "os" + "os/exec" + "path/filepath" + "regexp" + "sort" + "strings" + "sync" + "time" +) + +const ( + BlenderDownloadBaseURL = "https://download.blender.org/release/" + BlenderVersionCacheTTL = 1 * time.Hour +) + +// BlenderVersion represents a parsed Blender version +type BlenderVersion struct { + Major int `json:"major"` + Minor int `json:"minor"` + Patch int `json:"patch"` + Full string 
`json:"full"` // e.g., "4.2.3" + DirName string `json:"dir_name"` // e.g., "Blender4.2" + Filename string `json:"filename"` // e.g., "blender-4.2.3-linux-x64.tar.xz" + URL string `json:"url"` // Full download URL +} + +// BlenderVersionCache caches available Blender versions +type BlenderVersionCache struct { + versions []BlenderVersion + fetchedAt time.Time + mu sync.RWMutex +} + +var blenderVersionCache = &BlenderVersionCache{} + +// FetchBlenderVersions fetches available Blender versions from download.blender.org +// Returns versions sorted by version number (newest first) +func (s *Manager) FetchBlenderVersions() ([]BlenderVersion, error) { + // Check cache first + blenderVersionCache.mu.RLock() + if time.Since(blenderVersionCache.fetchedAt) < BlenderVersionCacheTTL && len(blenderVersionCache.versions) > 0 { + versions := make([]BlenderVersion, len(blenderVersionCache.versions)) + copy(versions, blenderVersionCache.versions) + blenderVersionCache.mu.RUnlock() + return versions, nil + } + blenderVersionCache.mu.RUnlock() + + // Fetch from website with timeout + client := &http.Client{ + Timeout: WSWriteDeadline, + } + resp, err := client.Get(BlenderDownloadBaseURL) + if err != nil { + return nil, fmt.Errorf("failed to fetch blender releases: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to fetch blender releases: status %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + // Parse directory listing for Blender version folders + // Looking for patterns like href="Blender4.2/" or href="Blender3.6/" + dirPattern := regexp.MustCompile(`href="Blender(\d+)\.(\d+)/"`) + log.Printf("Fetching Blender versions from %s", BlenderDownloadBaseURL) + matches := dirPattern.FindAllStringSubmatch(string(body), -1) + + // Fetch sub-versions concurrently to speed up the process + type versionResult struct { + versions []BlenderVersion + err error + } + results := make(chan versionResult, len(matches)) + var wg sync.WaitGroup + + for _, match := range matches { + if len(match) < 3 { + continue + } + + major := 0 + minor := 0 + fmt.Sscanf(match[1], "%d", &major) + fmt.Sscanf(match[2], "%d", &minor) + + // Skip very old versions (pre-2.80) + if major < 2 || (major == 2 && minor < 80) { + continue + } + + dirName := fmt.Sprintf("Blender%d.%d", major, minor) + + // Fetch the specific version directory concurrently + wg.Add(1) + go func(dn string, maj, min int) { + defer wg.Done() + subVersions, err := fetchSubVersions(dn, maj, min) + results <- versionResult{versions: subVersions, err: err} + }(dirName, major, minor) + } + + // Close results channel when all goroutines complete + go func() { + wg.Wait() + close(results) + }() + + var versions []BlenderVersion + for result := range results { + if result.err != nil { + log.Printf("Warning: failed to fetch sub-versions: %v", result.err) + continue + } + versions = append(versions, result.versions...) 
+ } + + // Sort by version (newest first) + sort.Slice(versions, func(i, j int) bool { + if versions[i].Major != versions[j].Major { + return versions[i].Major > versions[j].Major + } + if versions[i].Minor != versions[j].Minor { + return versions[i].Minor > versions[j].Minor + } + return versions[i].Patch > versions[j].Patch + }) + + // Update cache + blenderVersionCache.mu.Lock() + blenderVersionCache.versions = versions + blenderVersionCache.fetchedAt = time.Now() + blenderVersionCache.mu.Unlock() + + return versions, nil +} + +// fetchSubVersions fetches specific version files from a Blender release directory +func fetchSubVersions(dirName string, major, minor int) ([]BlenderVersion, error) { + url := BlenderDownloadBaseURL + dirName + "/" + client := &http.Client{ + Timeout: WSWriteDeadline, + } + resp, err := client.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("status %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + // Look for linux 64-bit tar.xz/bz2 files + // Various naming conventions across versions: + // - Modern (2.93+): blender-4.2.3-linux-x64.tar.xz + // - 2.83 early: blender-2.83.0-linux64.tar.xz + // - 2.80-2.82: blender-2.80-linux-glibc217-x86_64.tar.bz2 + // Skip: rc versions, alpha/beta, i686 (32-bit) + filePatterns := []*regexp.Regexp{ + // Modern format: blender-X.Y.Z-linux-x64.tar.xz + regexp.MustCompile(`blender-(\d+)\.(\d+)\.(\d+)-linux-x64\.tar\.(xz|bz2)`), + // Older format: blender-X.Y.Z-linux64.tar.xz + regexp.MustCompile(`blender-(\d+)\.(\d+)\.(\d+)-linux64\.tar\.(xz|bz2)`), + // glibc format: blender-X.Y.Z-linux-glibc217-x86_64.tar.bz2 (prefer glibc217 for compatibility) + regexp.MustCompile(`blender-(\d+)\.(\d+)\.(\d+)-linux-glibc217-x86_64\.tar\.(xz|bz2)`), + } + + var versions []BlenderVersion + seen := make(map[string]bool) + + for _, filePattern := range filePatterns { + matches := filePattern.FindAllStringSubmatch(string(body), -1) + + for _, match := range matches { + if len(match) < 5 { + continue + } + + patch := 0 + fmt.Sscanf(match[3], "%d", &patch) + + full := fmt.Sprintf("%d.%d.%d", major, minor, patch) + if seen[full] { + continue + } + seen[full] = true + + filename := match[0] + versions = append(versions, BlenderVersion{ + Major: major, + Minor: minor, + Patch: patch, + Full: full, + DirName: dirName, + Filename: filename, + URL: url + filename, + }) + } + } + + return versions, nil +} + +// GetLatestBlenderForMajorMinor returns the latest patch version for a given major.minor +// If exact match not found, uses fuzzy matching to find the closest available version +func (s *Manager) GetLatestBlenderForMajorMinor(major, minor int) (*BlenderVersion, error) { + versions, err := s.FetchBlenderVersions() + if err != nil { + return nil, err + } + + if len(versions) == 0 { + return nil, fmt.Errorf("no blender versions available") + } + + // Try exact match first - find the highest patch for this major.minor + var exactMatch *BlenderVersion + for i := range versions { + v := &versions[i] + if v.Major == major && v.Minor == minor { + if exactMatch == nil || v.Patch > exactMatch.Patch { + exactMatch = v + } + } + } + if exactMatch != nil { + log.Printf("Found Blender %d.%d.%d for requested %d.%d", exactMatch.Major, exactMatch.Minor, exactMatch.Patch, major, minor) + return exactMatch, nil + } + + // Fuzzy matching: find closest version + // Priority: same major with closest minor > closest major + 
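+	// Worked example (hypothetical availability, not a real listing): if a file
+	// targets 4.5 but only 4.2.3, 4.3.2, and 5.0.1 are published, the scoring
+	// below gives:
+	//   4.3.2 -> 10000 + (1000 - 2*10) + 2 = 10982  (same major, closest older minor)
+	//   4.2.3 -> 10000 + (1000 - 3*10) + 3 = 10973
+	//   5.0.1 ->  2000 + (-1)*1000         =  1000  (newer major, penalized)
+	// so 4.3.2 is selected.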
log.Printf("No exact match for Blender %d.%d, using fuzzy matching", major, minor) + + var bestMatch *BlenderVersion + bestScore := -1000000 // Large negative number + + for i := range versions { + v := &versions[i] + score := 0 + + if v.Major == major { + // Same major version - prefer this + score = 10000 + + // Prefer lower minor versions (more stable/compatible) + // but not too far back + minorDiff := minor - v.Minor + if minorDiff >= 0 { + // v.Minor <= minor (older or same) - prefer closer + score += 1000 - minorDiff*10 + } else { + // v.Minor > minor (newer) - less preferred but acceptable + score += 500 + minorDiff*10 + } + + // Higher patch is better + score += v.Patch + } else { + // Different major - less preferred + majorDiff := major - v.Major + if majorDiff > 0 { + // v.Major < major (older major) - acceptable fallback + score = 5000 - majorDiff*1000 + v.Minor*10 + v.Patch + } else { + // v.Major > major (newer major) - avoid if possible + score = -majorDiff * 1000 + } + } + + if score > bestScore { + bestScore = score + bestMatch = v + } + } + + if bestMatch != nil { + log.Printf("Fuzzy match: requested %d.%d, using %d.%d.%d", major, minor, bestMatch.Major, bestMatch.Minor, bestMatch.Patch) + return bestMatch, nil + } + + return nil, fmt.Errorf("no blender version found for %d.%d", major, minor) +} + +// GetBlenderArchivePath returns the path to the cached blender archive for a specific version +// Downloads from blender.org and decompresses to .tar if not already cached +// The manager caches as uncompressed .tar to save decompression time on runners +func (s *Manager) GetBlenderArchivePath(version *BlenderVersion) (string, error) { + // Base directory for blender archives + blenderDir := filepath.Join(s.storage.BasePath(), "blender-versions") + if err := os.MkdirAll(blenderDir, 0755); err != nil { + return "", fmt.Errorf("failed to create blender directory: %w", err) + } + + // Cache as uncompressed .tar for faster runner downloads + // Convert filename like "blender-4.2.3-linux-x64.tar.xz" to "blender-4.2.3-linux-x64.tar" + tarFilename := version.Filename + tarFilename = strings.TrimSuffix(tarFilename, ".xz") + tarFilename = strings.TrimSuffix(tarFilename, ".bz2") + archivePath := filepath.Join(blenderDir, tarFilename) + + // Check if already cached as .tar + if _, err := os.Stat(archivePath); err == nil { + log.Printf("Using cached Blender %s at %s", version.Full, archivePath) + // Clean up any extracted folders that might exist + s.cleanupExtractedBlenderFolders(blenderDir, version) + return archivePath, nil + } + + // Need to download and decompress + log.Printf("Downloading Blender %s from %s", version.Full, version.URL) + + client := &http.Client{ + Timeout: 0, // No timeout for large downloads + } + resp, err := client.Get(version.URL) + if err != nil { + return "", fmt.Errorf("failed to download blender: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("failed to download blender: status %d", resp.StatusCode) + } + + // Download to temp file first + compressedPath := filepath.Join(blenderDir, "download-"+version.Filename) + compressedFile, err := os.Create(compressedPath) + if err != nil { + return "", fmt.Errorf("failed to create temp file: %w", err) + } + + if _, err := io.Copy(compressedFile, resp.Body); err != nil { + compressedFile.Close() + os.Remove(compressedPath) + return "", fmt.Errorf("failed to download blender: %w", err) + } + compressedFile.Close() + + log.Printf("Downloaded Blender %s, 
decompressing to .tar...", version.Full) + + // Decompress to .tar + if err := decompressToTar(compressedPath, archivePath); err != nil { + os.Remove(compressedPath) + os.Remove(archivePath) + return "", fmt.Errorf("failed to decompress blender archive: %w", err) + } + + // Remove compressed file + os.Remove(compressedPath) + + // Clean up any extracted folders for this version (if they exist) + s.cleanupExtractedBlenderFolders(blenderDir, version) + + log.Printf("Blender %s cached at %s", version.Full, archivePath) + return archivePath, nil +} + +// decompressToTar decompresses a .tar.xz or .tar.bz2 file to a plain .tar file +func decompressToTar(compressedPath, tarPath string) error { + if strings.HasSuffix(compressedPath, ".tar.xz") { + // Use xz command for decompression + cmd := exec.Command("xz", "-d", "-k", "-c", compressedPath) + outFile, err := os.Create(tarPath) + if err != nil { + return err + } + defer outFile.Close() + + cmd.Stdout = outFile + if err := cmd.Run(); err != nil { + return fmt.Errorf("xz decompression failed: %w", err) + } + return nil + } else if strings.HasSuffix(compressedPath, ".tar.bz2") { + // Use bzip2 for decompression + inFile, err := os.Open(compressedPath) + if err != nil { + return err + } + defer inFile.Close() + + bzReader := bzip2.NewReader(inFile) + outFile, err := os.Create(tarPath) + if err != nil { + return err + } + defer outFile.Close() + + if _, err := io.Copy(outFile, bzReader); err != nil { + return fmt.Errorf("bzip2 decompression failed: %w", err) + } + return nil + } + + return fmt.Errorf("unsupported compression format: %s", compressedPath) +} + +// cleanupExtractedBlenderFolders removes any extracted Blender folders for the given version +// This ensures we only keep the .tar file and not extracted folders +func (s *Manager) cleanupExtractedBlenderFolders(blenderDir string, version *BlenderVersion) { + // Look for folders matching the version (e.g., "4.2.3", "2.83.20") + versionDirs := []string{ + filepath.Join(blenderDir, version.Full), // e.g., "4.2.3" + filepath.Join(blenderDir, fmt.Sprintf("%d.%d", version.Major, version.Minor)), // e.g., "4.2" + } + + for _, dir := range versionDirs { + if info, err := os.Stat(dir); err == nil && info.IsDir() { + log.Printf("Removing extracted Blender folder: %s", dir) + if err := os.RemoveAll(dir); err != nil { + log.Printf("Warning: failed to remove extracted folder %s: %v", dir, err) + } else { + log.Printf("Removed extracted Blender folder: %s", dir) + } + } + } +} + +// ParseBlenderVersionFromFile parses the Blender version that a .blend file was saved with +// This reads the file header to determine the version +func ParseBlenderVersionFromFile(blendPath string) (major, minor int, err error) { + file, err := os.Open(blendPath) + if err != nil { + return 0, 0, fmt.Errorf("failed to open blend file: %w", err) + } + defer file.Close() + + return ParseBlenderVersionFromReader(file) +} + +// ParseBlenderVersionFromReader parses the Blender version from a reader +// Useful for reading from uploaded files without saving to disk first +func ParseBlenderVersionFromReader(r io.ReadSeeker) (major, minor int, err error) { + // Read the first 12 bytes of the blend file header + // Format: BLENDER-v or BLENDER_v + // The header is: "BLENDER" (7 bytes) + pointer size (1 byte: '-' for 64-bit, '_' for 32-bit) + // + endianness (1 byte: 'v' for little-endian, 'V' for big-endian) + // + version (3 bytes: e.g., "402" for 4.02) + header := make([]byte, 12) + n, err := r.Read(header) + if err != nil || n < 12 { + 
return 0, 0, fmt.Errorf("failed to read blend file header: %w", err) + } + + // Check for BLENDER magic + if string(header[:7]) != "BLENDER" { + // Might be compressed - try to decompress + r.Seek(0, 0) + return parseCompressedBlendVersion(r) + } + + // Parse version from bytes 9-11 (3 digits) + versionStr := string(header[9:12]) + var vMajor, vMinor int + + // Version format changed in Blender 3.0 + // Pre-3.0: "279" = 2.79, "280" = 2.80 + // 3.0+: "300" = 3.0, "402" = 4.02, "410" = 4.10 + if len(versionStr) == 3 { + // First digit is major version + fmt.Sscanf(string(versionStr[0]), "%d", &vMajor) + // Next two digits are minor version + fmt.Sscanf(versionStr[1:3], "%d", &vMinor) + } + + return vMajor, vMinor, nil +} + +// parseCompressedBlendVersion handles gzip and zstd compressed blend files +func parseCompressedBlendVersion(r io.ReadSeeker) (major, minor int, err error) { + // Check for compression magic bytes + magic := make([]byte, 4) + if _, err := r.Read(magic); err != nil { + return 0, 0, err + } + r.Seek(0, 0) + + if magic[0] == 0x1f && magic[1] == 0x8b { + // gzip compressed + gzReader, err := gzip.NewReader(r) + if err != nil { + return 0, 0, fmt.Errorf("failed to create gzip reader: %w", err) + } + defer gzReader.Close() + + header := make([]byte, 12) + n, err := gzReader.Read(header) + if err != nil || n < 12 { + return 0, 0, fmt.Errorf("failed to read compressed blend header: %w", err) + } + + if string(header[:7]) != "BLENDER" { + return 0, 0, fmt.Errorf("invalid blend file format") + } + + versionStr := string(header[9:12]) + var vMajor, vMinor int + if len(versionStr) == 3 { + fmt.Sscanf(string(versionStr[0]), "%d", &vMajor) + fmt.Sscanf(versionStr[1:3], "%d", &vMinor) + } + + return vMajor, vMinor, nil + } + + // Check for zstd magic (Blender 3.0+): 0x28 0xB5 0x2F 0xFD + if magic[0] == 0x28 && magic[1] == 0xb5 && magic[2] == 0x2f && magic[3] == 0xfd { + return parseZstdBlendVersion(r) + } + + return 0, 0, fmt.Errorf("unknown blend file format") +} + +// parseZstdBlendVersion handles zstd-compressed blend files (Blender 3.0+) +// Uses zstd command line tool since Go doesn't have native zstd support +func parseZstdBlendVersion(r io.ReadSeeker) (major, minor int, err error) { + r.Seek(0, 0) + + // We need to decompress just enough to read the header + // Use zstd command to decompress from stdin + cmd := exec.Command("zstd", "-d", "-c") + cmd.Stdin = r + + stdout, err := cmd.StdoutPipe() + if err != nil { + return 0, 0, fmt.Errorf("failed to create zstd stdout pipe: %w", err) + } + + if err := cmd.Start(); err != nil { + return 0, 0, fmt.Errorf("failed to start zstd decompression: %w", err) + } + + // Read just the header (12 bytes) + header := make([]byte, 12) + n, readErr := io.ReadFull(stdout, header) + + // Kill the process early - we only need the header + cmd.Process.Kill() + cmd.Wait() + + if readErr != nil || n < 12 { + return 0, 0, fmt.Errorf("failed to read zstd compressed blend header: %v", readErr) + } + + if string(header[:7]) != "BLENDER" { + return 0, 0, fmt.Errorf("invalid blend file format in zstd archive") + } + + versionStr := string(header[9:12]) + var vMajor, vMinor int + if len(versionStr) == 3 { + fmt.Sscanf(string(versionStr[0]), "%d", &vMajor) + fmt.Sscanf(versionStr[1:3], "%d", &vMinor) + } + + return vMajor, vMinor, nil +} + +// handleGetBlenderVersions returns available Blender versions +func (s *Manager) handleGetBlenderVersions(w http.ResponseWriter, r *http.Request) { + versions, err := s.FetchBlenderVersions() + if err != nil { + 
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to fetch blender versions: %v", err)) + return + } + + // Group by major.minor for easier frontend display + type VersionGroup struct { + MajorMinor string `json:"major_minor"` + Latest BlenderVersion `json:"latest"` + All []BlenderVersion `json:"all"` + } + + groups := make(map[string]*VersionGroup) + for _, v := range versions { + key := fmt.Sprintf("%d.%d", v.Major, v.Minor) + if groups[key] == nil { + groups[key] = &VersionGroup{ + MajorMinor: key, + Latest: v, // First one is latest due to sorting + All: []BlenderVersion{v}, + } + } else { + groups[key].All = append(groups[key].All, v) + } + } + + // Convert to slice and sort by version + var groupedResult []VersionGroup + for _, g := range groups { + groupedResult = append(groupedResult, *g) + } + sort.Slice(groupedResult, func(i, j int) bool { + // Parse major.minor for comparison + var iMaj, iMin, jMaj, jMin int + fmt.Sscanf(groupedResult[i].MajorMinor, "%d.%d", &iMaj, &iMin) + fmt.Sscanf(groupedResult[j].MajorMinor, "%d.%d", &jMaj, &jMin) + if iMaj != jMaj { + return iMaj > jMaj + } + return iMin > jMin + }) + + // Return both flat list and grouped for flexibility + response := map[string]interface{}{ + "versions": versions, // Flat list of all versions (newest first) + "grouped": groupedResult, // Grouped by major.minor + } + + s.respondJSON(w, http.StatusOK, response) +} + +// handleDownloadBlender serves a cached Blender archive to runners +func (s *Manager) handleDownloadBlender(w http.ResponseWriter, r *http.Request) { + version := r.URL.Query().Get("version") + if version == "" { + s.respondError(w, http.StatusBadRequest, "version parameter required") + return + } + + // Parse version string (e.g., "4.2.3" or "4.2") + var major, minor, patch int + parts := strings.Split(version, ".") + if len(parts) < 2 { + s.respondError(w, http.StatusBadRequest, "invalid version format, expected major.minor or major.minor.patch") + return + } + + fmt.Sscanf(parts[0], "%d", &major) + fmt.Sscanf(parts[1], "%d", &minor) + if len(parts) >= 3 { + fmt.Sscanf(parts[2], "%d", &patch) + } + + // Find the version + var blenderVersion *BlenderVersion + if len(parts) >= 3 { + // Exact patch version requested - find it + versions, err := s.FetchBlenderVersions() + if err != nil { + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to fetch versions: %v", err)) + return + } + + for _, v := range versions { + if v.Major == major && v.Minor == minor && v.Patch == patch { + blenderVersion = &v + break + } + } + + if blenderVersion == nil { + s.respondError(w, http.StatusNotFound, fmt.Sprintf("blender version %s not found", version)) + return + } + } else { + // Major.minor only - use helper to get latest patch version + var err error + blenderVersion, err = s.GetLatestBlenderForMajorMinor(major, minor) + if err != nil { + s.respondError(w, http.StatusNotFound, fmt.Sprintf("blender version %s not found: %v", version, err)) + return + } + } + + // Get or download the archive + archivePath, err := s.GetBlenderArchivePath(blenderVersion) + if err != nil { + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to get blender archive: %v", err)) + return + } + + // Serve the file + file, err := os.Open(archivePath) + if err != nil { + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to open archive: %v", err)) + return + } + defer file.Close() + + stat, err := file.Stat() + if err != nil { + s.respondError(w, 
http.StatusInternalServerError, fmt.Sprintf("failed to stat archive: %v", err)) + return + } + + // Filename is now .tar (decompressed) + tarFilename := blenderVersion.Filename + tarFilename = strings.TrimSuffix(tarFilename, ".xz") + tarFilename = strings.TrimSuffix(tarFilename, ".bz2") + + w.Header().Set("Content-Type", "application/x-tar") + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", tarFilename)) + w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size())) + w.Header().Set("X-Blender-Version", blenderVersion.Full) + + io.Copy(w, file) +} + +// Unused functions from extraction - keeping for reference but not needed on manager +var _ = extractBlenderArchive +var _ = extractTarXz +var _ = extractTar + +// extractBlenderArchive extracts a blender archive (already decompressed to .tar by GetBlenderArchivePath) +func extractBlenderArchive(archivePath string, version *BlenderVersion, destDir string) error { + file, err := os.Open(archivePath) + if err != nil { + return err + } + defer file.Close() + + // The archive is already decompressed to .tar by GetBlenderArchivePath + // Just extract it directly + if strings.HasSuffix(archivePath, ".tar") { + tarReader := tar.NewReader(file) + return extractTar(tarReader, version, destDir) + } + + // Fallback for any other format (shouldn't happen with current flow) + if strings.HasSuffix(archivePath, ".tar.xz") { + return extractTarXz(archivePath, version, destDir) + } else if strings.HasSuffix(archivePath, ".tar.bz2") { + bzReader := bzip2.NewReader(file) + tarReader := tar.NewReader(bzReader) + return extractTar(tarReader, version, destDir) + } + + return fmt.Errorf("unsupported archive format: %s", archivePath) +} + +// extractTarXz extracts a tar.xz archive using the xz command +func extractTarXz(archivePath string, version *BlenderVersion, destDir string) error { + versionDir := filepath.Join(destDir, version.Full) + if err := os.MkdirAll(versionDir, 0755); err != nil { + return err + } + + cmd := exec.Command("tar", "-xJf", archivePath, "-C", versionDir, "--strip-components=1") + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("tar extraction failed: %v, output: %s", err, string(output)) + } + + return nil +} + +// extractTar extracts files from a tar reader +func extractTar(tarReader *tar.Reader, version *BlenderVersion, destDir string) error { + versionDir := filepath.Join(destDir, version.Full) + if err := os.MkdirAll(versionDir, 0755); err != nil { + return err + } + + stripPrefix := "" + + for { + header, err := tarReader.Next() + if err == io.EOF { + break + } + if err != nil { + return err + } + + if stripPrefix == "" { + parts := strings.SplitN(header.Name, "/", 2) + if len(parts) > 0 { + stripPrefix = parts[0] + "/" + } + } + + name := strings.TrimPrefix(header.Name, stripPrefix) + if name == "" { + continue + } + + targetPath := filepath.Join(versionDir, name) + + switch header.Typeflag { + case tar.TypeDir: + if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil { + return err + } + case tar.TypeReg: + if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil { + return err + } + outFile, err := os.OpenFile(targetPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode)) + if err != nil { + return err + } + if _, err := io.Copy(outFile, tarReader); err != nil { + outFile.Close() + return err + } + outFile.Close() + case tar.TypeSymlink: + if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil { + return err + } + if err := 
os.Symlink(header.Linkname, targetPath); err != nil { + return err + } + } + } + + return nil +} diff --git a/internal/api/jobs.go b/internal/manager/jobs.go similarity index 81% rename from internal/api/jobs.go rename to internal/manager/jobs.go index 5b08eda..9bd8e8e 100644 --- a/internal/api/jobs.go +++ b/internal/manager/jobs.go @@ -11,12 +11,14 @@ import ( "fmt" "io" "log" + "mime/multipart" "net/http" "os" "path/filepath" "strconv" "strings" "sync" + "sync/atomic" "time" authpkg "jiggablend/internal/auth" @@ -49,7 +51,7 @@ func isAdminUser(r *http.Request) bool { } // handleCreateJob creates a new job -func (s *Server) handleCreateJob(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleCreateJob(w http.ResponseWriter, r *http.Request) { userID, err := getUserID(r) if err != nil { s.respondError(w, http.StatusUnauthorized, err.Error()) @@ -103,32 +105,29 @@ func (s *Server) handleCreateJob(w http.ResponseWriter, r *http.Request) { } } - // Default allow_parallel_runners to true for render jobs if not provided - var allowParallelRunners *bool - if req.JobType == types.JobTypeRender { - allowParallelRunners = new(bool) - *allowParallelRunners = true - if req.AllowParallelRunners != nil { - *allowParallelRunners = *req.AllowParallelRunners - } - } - - // Set job timeout to 24 hours (86400 seconds) - jobTimeout := 86400 - - // Store render settings, unhide_objects, and enable_execution flags in blend_metadata if provided + // Store render settings, unhide_objects, enable_execution, blender_version, preserve_hdr, and preserve_alpha flags in blend_metadata if provided + // Always include output_format in metadata so tasks can access it var blendMetadataJSON *string - if req.RenderSettings != nil || req.UnhideObjects != nil || req.EnableExecution != nil { + if req.RenderSettings != nil || req.UnhideObjects != nil || req.EnableExecution != nil || req.BlenderVersion != nil || req.OutputFormat != nil || req.PreserveHDR != nil || req.PreserveAlpha != nil { metadata := types.BlendMetadata{ FrameStart: *req.FrameStart, FrameEnd: *req.FrameEnd, RenderSettings: types.RenderSettings{}, UnhideObjects: req.UnhideObjects, EnableExecution: req.EnableExecution, + PreserveHDR: req.PreserveHDR, + PreserveAlpha: req.PreserveAlpha, } if req.RenderSettings != nil { metadata.RenderSettings = *req.RenderSettings } + // Always set output_format in metadata from job's output_format field + if req.OutputFormat != nil { + metadata.RenderSettings.OutputFormat = *req.OutputFormat + } + if req.BlenderVersion != nil { + metadata.BlenderVersion = *req.BlenderVersion + } metadataBytes, err := json.Marshal(metadata) if err == nil { metadataStr := string(metadataBytes) @@ -140,9 +139,9 @@ func (s *Server) handleCreateJob(w http.ResponseWriter, r *http.Request) { var jobID int64 err = s.db.With(func(conn *sql.DB) error { result, err := conn.Exec( - `INSERT INTO jobs (user_id, job_type, name, status, progress, frame_start, frame_end, output_format, allow_parallel_runners, timeout_seconds, blend_metadata) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, - userID, req.JobType, req.Name, types.JobStatusPending, 0.0, *req.FrameStart, *req.FrameEnd, *req.OutputFormat, allowParallelRunners, jobTimeout, blendMetadataJSON, + `INSERT INTO jobs (user_id, job_type, name, status, progress, frame_start, frame_end, output_format, blend_metadata) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, + userID, req.JobType, req.Name, types.JobStatusPending, 0.0, *req.FrameStart, *req.FrameEnd, *req.OutputFormat, blendMetadataJSON, ) if err != nil 
{ return err @@ -272,25 +271,21 @@ func (s *Server) handleCreateJob(w http.ResponseWriter, r *http.Request) { // Only create render tasks for render jobs if req.JobType == types.JobTypeRender { // Determine task timeout based on output format - taskTimeout := 300 // Default: 5 minutes for frame rendering - if *req.OutputFormat == "EXR_264_MP4" || *req.OutputFormat == "EXR_AV1_MP4" { - // For MP4, we'll create frame tasks with 5 min timeout - // Video generation tasks will be created later with 24h timeout - taskTimeout = 300 + taskTimeout := RenderTimeout // 1 hour for render jobs + if *req.OutputFormat == "EXR_264_MP4" || *req.OutputFormat == "EXR_AV1_MP4" || *req.OutputFormat == "EXR_VP9_WEBM" { + taskTimeout = VideoEncodeTimeout // 24 hours for encoding } // Create tasks for the job - // If allow_parallel_runners is false, create a single task for all frames - // Otherwise, create one task per frame for parallel processing + // Create one task per frame (all tasks are single-frame) var createdTaskIDs []int64 - if allowParallelRunners != nil && !*allowParallelRunners { - // Single task for entire frame range + for frame := *req.FrameStart; frame <= *req.FrameEnd; frame++ { var taskID int64 err = s.db.With(func(conn *sql.DB) error { result, err := conn.Exec( - `INSERT INTO tasks (job_id, frame_start, frame_end, task_type, status, timeout_seconds, max_retries) - VALUES (?, ?, ?, ?, ?, ?, ?)`, - jobID, *req.FrameStart, *req.FrameEnd, types.TaskTypeRender, types.TaskStatusPending, taskTimeout, 3, + `INSERT INTO tasks (job_id, frame, task_type, status, timeout_seconds, max_retries) + VALUES (?, ?, ?, ?, ?, ?)`, + jobID, frame, types.TaskTypeRender, types.TaskStatusPending, taskTimeout, 3, ) if err != nil { return err @@ -299,34 +294,38 @@ func (s *Server) handleCreateJob(w http.ResponseWriter, r *http.Request) { return err }) if err != nil { - s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create task: %v", err)) + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create tasks: %v", err)) return } createdTaskIDs = append(createdTaskIDs, taskID) - log.Printf("Created 1 render task for job %d (frames %d-%d, single runner)", jobID, *req.FrameStart, *req.FrameEnd) - } else { - // One task per frame for parallel processing - for frame := *req.FrameStart; frame <= *req.FrameEnd; frame++ { - var taskID int64 - err = s.db.With(func(conn *sql.DB) error { - result, err := conn.Exec( - `INSERT INTO tasks (job_id, frame_start, frame_end, task_type, status, timeout_seconds, max_retries) - VALUES (?, ?, ?, ?, ?, ?, ?)`, - jobID, frame, frame, types.TaskTypeRender, types.TaskStatusPending, taskTimeout, 3, - ) - if err != nil { - return err - } - taskID, err = result.LastInsertId() - return err - }) + } + log.Printf("Created %d render tasks for job %d (frames %d-%d)", *req.FrameEnd-*req.FrameStart+1, jobID, *req.FrameStart, *req.FrameEnd) + + // Create encode task immediately if output format requires it + // The task will have a condition that prevents it from being assigned until all render tasks are completed + if *req.OutputFormat == "EXR_264_MP4" || *req.OutputFormat == "EXR_AV1_MP4" || *req.OutputFormat == "EXR_VP9_WEBM" { + encodeTaskTimeout := VideoEncodeTimeout // 24 hours for encoding + conditionJSON := `{"type": "all_render_tasks_completed"}` + var encodeTaskID int64 + err = s.db.With(func(conn *sql.DB) error { + result, err := conn.Exec( + `INSERT INTO tasks (job_id, frame, task_type, status, timeout_seconds, max_retries, condition) + VALUES (?, ?, 
?, ?, ?, ?, ?)`, + jobID, 0, types.TaskTypeEncode, types.TaskStatusPending, encodeTaskTimeout, 1, conditionJSON, + ) if err != nil { - s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create tasks: %v", err)) - return + return err } - createdTaskIDs = append(createdTaskIDs, taskID) + encodeTaskID, err = result.LastInsertId() + return err + }) + if err != nil { + log.Printf("Failed to create encode task for job %d: %v", jobID, err) + // Don't fail the job creation if encode task creation fails + } else { + createdTaskIDs = append(createdTaskIDs, encodeTaskID) + log.Printf("Created encode task %d for job %d (with condition: all render tasks must be completed)", encodeTaskID, jobID) } - log.Printf("Created %d render tasks for job %d (frames %d-%d, parallel)", *req.FrameEnd-*req.FrameStart+1, jobID, *req.FrameStart, *req.FrameEnd) } // Update job status (should be pending since tasks are pending) s.updateJobStatusFromTasks(jobID) @@ -343,20 +342,18 @@ func (s *Server) handleCreateJob(w http.ResponseWriter, r *http.Request) { // Build response job object job := types.Job{ - ID: jobID, - UserID: userID, - JobType: req.JobType, - Name: req.Name, - Status: types.JobStatusPending, - Progress: 0.0, - TimeoutSeconds: jobTimeout, - CreatedAt: time.Now(), + ID: jobID, + UserID: userID, + JobType: req.JobType, + Name: req.Name, + Status: types.JobStatusPending, + Progress: 0.0, + CreatedAt: time.Now(), } if req.JobType == types.JobTypeRender { job.FrameStart = req.FrameStart job.FrameEnd = req.FrameEnd job.OutputFormat = req.OutputFormat - job.AllowParallelRunners = allowParallelRunners } // Broadcast job_created to all clients via jobs channel @@ -376,14 +373,11 @@ func (s *Server) handleCreateJob(w http.ResponseWriter, r *http.Request) { "timestamp": time.Now().Unix(), }) - // Immediately try to distribute tasks to connected runners - s.triggerTaskDistribution() - s.respondJSON(w, http.StatusCreated, job) } // handleListJobs lists jobs for the current user with pagination and filtering -func (s *Server) handleListJobs(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleListJobs(w http.ResponseWriter, r *http.Request) { userID, err := getUserID(r) if err != nil { s.respondError(w, http.StatusUnauthorized, err.Error()) @@ -433,7 +427,7 @@ func (s *Server) handleListJobs(w http.ResponseWriter, r *http.Request) { // Build query with filters query := `SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format, - allow_parallel_runners, timeout_seconds, blend_metadata, created_at, started_at, completed_at, error_message + blend_metadata, created_at, started_at, completed_at, error_message FROM jobs WHERE user_id = ?` args := []interface{}{userID} @@ -494,11 +488,9 @@ func (s *Server) handleListJobs(w http.ResponseWriter, r *http.Request) { var errorMessage sql.NullString var frameStart, frameEnd sql.NullInt64 var outputFormat sql.NullString - var allowParallelRunners sql.NullBool - err := rows.Scan( &job.ID, &job.UserID, &jobType, &job.Name, &job.Status, &job.Progress, - &frameStart, &frameEnd, &outputFormat, &allowParallelRunners, &job.TimeoutSeconds, + &frameStart, &frameEnd, &outputFormat, &blendMetadataJSON, &job.CreatedAt, &startedAt, &completedAt, &errorMessage, ) if err != nil { @@ -518,9 +510,6 @@ func (s *Server) handleListJobs(w http.ResponseWriter, r *http.Request) { if outputFormat.Valid { job.OutputFormat = &outputFormat.String } - if allowParallelRunners.Valid { - job.AllowParallelRunners = &allowParallelRunners.Bool - } 
if startedAt.Valid { job.StartedAt = &startedAt.Time } @@ -559,7 +548,7 @@ func (s *Server) handleListJobs(w http.ResponseWriter, r *http.Request) { } // handleListJobsSummary lists lightweight job summaries for the current user -func (s *Server) handleListJobsSummary(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleListJobsSummary(w http.ResponseWriter, r *http.Request) { userID, err := getUserID(r) if err != nil { s.respondError(w, http.StatusUnauthorized, err.Error()) @@ -702,7 +691,7 @@ func (s *Server) handleListJobsSummary(w http.ResponseWriter, r *http.Request) { } // handleBatchGetJobs fetches multiple jobs by IDs -func (s *Server) handleBatchGetJobs(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleBatchGetJobs(w http.ResponseWriter, r *http.Request) { userID, err := getUserID(r) if err != nil { s.respondError(w, http.StatusUnauthorized, err.Error()) @@ -737,7 +726,7 @@ func (s *Server) handleBatchGetJobs(w http.ResponseWriter, r *http.Request) { } query := fmt.Sprintf(`SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format, - allow_parallel_runners, timeout_seconds, blend_metadata, created_at, started_at, completed_at, error_message + blend_metadata, created_at, started_at, completed_at, error_message FROM jobs WHERE user_id = ? AND id IN (%s) ORDER BY created_at DESC`, strings.Join(placeholders, ",")) var rows *sql.Rows @@ -761,11 +750,9 @@ func (s *Server) handleBatchGetJobs(w http.ResponseWriter, r *http.Request) { var errorMessage sql.NullString var frameStart, frameEnd sql.NullInt64 var outputFormat sql.NullString - var allowParallelRunners sql.NullBool - err := rows.Scan( &job.ID, &job.UserID, &jobType, &job.Name, &job.Status, &job.Progress, - &frameStart, &frameEnd, &outputFormat, &allowParallelRunners, &job.TimeoutSeconds, + &frameStart, &frameEnd, &outputFormat, &blendMetadataJSON, &job.CreatedAt, &startedAt, &completedAt, &errorMessage, ) if err != nil { @@ -785,9 +772,6 @@ func (s *Server) handleBatchGetJobs(w http.ResponseWriter, r *http.Request) { if outputFormat.Valid { job.OutputFormat = &outputFormat.String } - if allowParallelRunners.Valid { - job.AllowParallelRunners = &allowParallelRunners.Bool - } if startedAt.Valid { job.StartedAt = &startedAt.Time } @@ -811,7 +795,7 @@ func (s *Server) handleBatchGetJobs(w http.ResponseWriter, r *http.Request) { } // handleGetJob gets a specific job -func (s *Server) handleGetJob(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleGetJob(w http.ResponseWriter, r *http.Request) { userID, err := getUserID(r) if err != nil { s.respondError(w, http.StatusUnauthorized, err.Error()) @@ -831,8 +815,6 @@ func (s *Server) handleGetJob(w http.ResponseWriter, r *http.Request) { var errorMessage sql.NullString var frameStart, frameEnd sql.NullInt64 var outputFormat sql.NullString - var allowParallelRunners sql.NullBool - // Allow admins to view any job, regular users can only view their own isAdmin := isAdminUser(r) var err2 error @@ -840,23 +822,23 @@ func (s *Server) handleGetJob(w http.ResponseWriter, r *http.Request) { if isAdmin { return conn.QueryRow( `SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format, - allow_parallel_runners, timeout_seconds, blend_metadata, created_at, started_at, completed_at, error_message + blend_metadata, created_at, started_at, completed_at, error_message FROM jobs WHERE id = ?`, jobID, ).Scan( &job.ID, &job.UserID, &jobType, &job.Name, &job.Status, &job.Progress, - &frameStart, 
&frameEnd, &outputFormat, &allowParallelRunners, &job.TimeoutSeconds, + &frameStart, &frameEnd, &outputFormat, &blendMetadataJSON, &job.CreatedAt, &startedAt, &completedAt, &errorMessage, ) } else { return conn.QueryRow( `SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format, - allow_parallel_runners, timeout_seconds, blend_metadata, created_at, started_at, completed_at, error_message + blend_metadata, created_at, started_at, completed_at, error_message FROM jobs WHERE id = ? AND user_id = ?`, jobID, userID, ).Scan( &job.ID, &job.UserID, &jobType, &job.Name, &job.Status, &job.Progress, - &frameStart, &frameEnd, &outputFormat, &allowParallelRunners, &job.TimeoutSeconds, + &frameStart, &frameEnd, &outputFormat, &blendMetadataJSON, &job.CreatedAt, &startedAt, &completedAt, &errorMessage, ) } @@ -883,9 +865,6 @@ func (s *Server) handleGetJob(w http.ResponseWriter, r *http.Request) { if outputFormat.Valid { job.OutputFormat = &outputFormat.String } - if allowParallelRunners.Valid { - job.AllowParallelRunners = &allowParallelRunners.Bool - } if startedAt.Valid { job.StartedAt = &startedAt.Time } @@ -915,7 +894,7 @@ func (s *Server) handleGetJob(w http.ResponseWriter, r *http.Request) { } // handleCancelJob cancels a job -func (s *Server) handleCancelJob(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleCancelJob(w http.ResponseWriter, r *http.Request) { userID, err := getUserID(r) if err != nil { s.respondError(w, http.StatusUnauthorized, err.Error()) @@ -985,7 +964,7 @@ func (s *Server) handleCancelJob(w http.ResponseWriter, r *http.Request) { } // handleDeleteJob permanently deletes a job and all its associated data -func (s *Server) handleDeleteJob(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleDeleteJob(w http.ResponseWriter, r *http.Request) { userID, err := getUserID(r) if err != nil { s.respondError(w, http.StatusUnauthorized, err.Error()) @@ -1079,7 +1058,7 @@ func (s *Server) handleDeleteJob(w http.ResponseWriter, r *http.Request) { } // cleanupOldRenderJobs periodically deletes render jobs older than 1 month -func (s *Server) cleanupOldRenderJobs() { +func (s *Manager) cleanupOldRenderJobs() { // Run cleanup every hour ticker := time.NewTicker(1 * time.Hour) defer ticker.Stop() @@ -1093,7 +1072,7 @@ func (s *Server) cleanupOldRenderJobs() { } // cleanupOldRenderJobsOnce finds and deletes render jobs older than 1 month that are completed, failed, or cancelled -func (s *Server) cleanupOldRenderJobsOnce() { +func (s *Manager) cleanupOldRenderJobsOnce() { defer func() { if r := recover(); r != nil { log.Printf("Panic in cleanupOldRenderJobs: %v", r) @@ -1191,7 +1170,7 @@ func (s *Server) cleanupOldRenderJobsOnce() { } // handleUploadJobFile handles file upload for a job -func (s *Server) handleUploadJobFile(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleUploadJobFile(w http.ResponseWriter, r *http.Request) { userID, err := getUserID(r) if err != nil { s.respondError(w, http.StatusUnauthorized, err.Error()) @@ -1470,40 +1449,41 @@ func (s *Server) handleUploadJobFile(w http.ResponseWriter, r *http.Request) { // handleUploadFileForJobCreation handles file upload before job creation // Creates context archive and extracts metadata, returns metadata and upload session ID -func (s *Server) handleUploadFileForJobCreation(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleUploadFileForJobCreation(w http.ResponseWriter, r *http.Request) { userID, err := getUserID(r) if err != nil { 
s.respondError(w, http.StatusUnauthorized, err.Error()) return } - // Parse multipart form with large limit for big files - err = r.ParseMultipartForm(20 << 30) // 20 GB + // Use MultipartReader to stream the file instead of loading it all into memory + // This allows us to report progress during upload + reader, err := r.MultipartReader() if err != nil { - log.Printf("Error parsing multipart form: %v", err) - s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Failed to parse form: %v", err)) + log.Printf("Error creating multipart reader: %v", err) + s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Failed to parse multipart form: %v", err)) return } - file, header, err := r.FormFile("file") - if err != nil { - log.Printf("Error getting file from form: %v", err) - s.respondError(w, http.StatusBadRequest, fmt.Sprintf("No file provided: %v", err)) - return - } - defer file.Close() + // Find the file part and collect form values + // IMPORTANT: With MultipartReader, we must read the file part's data immediately + // before calling NextPart() again, otherwise the data becomes unavailable + var header *multipart.FileHeader + var filePath string + formValues := make(map[string]string) + var tmpDir string + var sessionID string + var mainBlendFile string - log.Printf("Uploading file '%s' (size: %d bytes) for user %d (pre-job creation)", header.Filename, header.Size, userID) - - // Create temporary directory for processing upload (user-specific) - tmpDir, err := s.storage.TempDir(fmt.Sprintf("jiggablend-upload-user-%d-*", userID)) + // Create temporary directory first (before reading parts) + tmpDir, err = s.storage.TempDir(fmt.Sprintf("jiggablend-upload-user-%d-*", userID)) if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create temporary directory: %v", err)) return } // Generate session ID (use temp directory path as session ID) - sessionID := tmpDir + sessionID = tmpDir // Create upload session s.uploadSessionsMu.Lock() @@ -1517,58 +1497,118 @@ func (s *Server) handleUploadFileForJobCreation(w http.ResponseWriter, r *http.R } s.uploadSessionsMu.Unlock() - // Broadcast initial upload status - s.broadcastUploadProgress(sessionID, 0.0, "uploading", "Uploading file...") + // Client tracks upload progress via XHR - no need to broadcast here + // We only broadcast processing status changes (extracting, creating context, etc.) 
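For reference, the streaming idiom the reworked handler depends on looks like this in isolation. This is a minimal sketch, not code from this patch: the `saveUpload` helper, its arguments, and the destination directory are invented for illustration. The key constraint, which the patch's own comments call out, is that each *multipart.Part is a forward-only reader invalidated by the next NextPart() call, so file data must be drained immediately.

// Illustrative sketch (not part of this patch): streaming multipart parts one
// at a time with mime/multipart instead of buffering the whole body in memory.
package sketch

import (
	"io"
	"net/http"
	"os"
	"path/filepath"
)

// saveUpload is a hypothetical helper; dir is where file parts are written.
// It returns the ordinary form fields collected along the way.
func saveUpload(r *http.Request, dir string) (map[string]string, error) {
	mr, err := r.MultipartReader() // streams the body instead of buffering it
	if err != nil {
		return nil, err
	}
	fields := make(map[string]string)
	for {
		part, err := mr.NextPart()
		if err == io.EOF {
			return fields, nil // all parts consumed
		}
		if err != nil {
			return nil, err
		}
		if part.FileName() != "" {
			// File part: copy it to disk now, before the next NextPart() call
			// invalidates the reader.
			dst, err := os.Create(filepath.Join(dir, filepath.Base(part.FileName())))
			if err == nil {
				_, err = io.Copy(dst, part)
				dst.Close()
			}
			part.Close()
			if err != nil {
				return nil, err
			}
		} else {
			// Ordinary form value: small enough to read whole.
			v, _ := io.ReadAll(part)
			if name := part.FormName(); name != "" {
				fields[name] = string(v)
			}
			part.Close()
		}
	}
}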
-	// Note: We'll clean this up after job creation or after timeout
-	// For now, we rely on the session cleanup mechanism, but also add defer for safety
-	defer func() {
-		// Only clean up if there's an error - otherwise let session cleanup handle it
-		// This is a safety net in case of early returns
-	}()
-
-	var mainBlendFile string
-	var extractedFiles []string
-
-	// Check if this is a ZIP file
-	if strings.HasSuffix(strings.ToLower(header.Filename), ".zip") {
-		log.Printf("Processing ZIP file '%s'", header.Filename)
-		// Save ZIP to temporary directory
-		zipPath := filepath.Join(tmpDir, header.Filename)
-		zipFile, err := os.Create(zipPath)
+	fileFound := false
+	for {
+		part, err := reader.NextPart()
+		if err == io.EOF {
+			break
+		}
 		if err != nil {
+			log.Printf("Error reading multipart: %v", err)
 			os.RemoveAll(tmpDir)
-			s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create ZIP file: %v", err))
+			s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Failed to read multipart: %v", err))
 			return
 		}
-		copied, err := io.Copy(zipFile, file)
-		zipFile.Close()
-		if err != nil {
-			os.RemoveAll(tmpDir)
-			s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to save ZIP file: %v", err))
-			return
+
+		formName := part.FormName()
+		if formName == "file" {
+			// Read the file part immediately - it cannot be stored for later with MultipartReader
+			header = &multipart.FileHeader{
+				Filename: part.FileName(),
+			}
+			// Try to get Content-Length from the part header if available
+			if cl := part.Header.Get("Content-Length"); cl != "" {
+				if size, err := strconv.ParseInt(cl, 10, 64); err == nil {
+					header.Size = size
+				}
+			}
+
+			// Determine the file path; a bare .blend upload doubles as the main blend file
+			filePath = filepath.Join(tmpDir, header.Filename)
+			if strings.HasSuffix(strings.ToLower(header.Filename), ".blend") {
+				mainBlendFile = filePath
+			}
+
+			// Create the file and copy data immediately
+			outFile, err := os.Create(filePath)
+			if err != nil {
+				part.Close()
+				os.RemoveAll(tmpDir)
+				s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create file: %v", err))
+				return
+			}
+
+			// Copy file data - must happen before calling NextPart() again
+			copied, err := io.Copy(outFile, part)
+			outFile.Close()
+			part.Close()
+
+			if err != nil {
+				os.RemoveAll(tmpDir)
+				s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to save file: %v", err))
+				return
+			}
+
+			// Fall back to the actual byte count when no Content-Length was provided
+			if header.Size == 0 {
+				header.Size = copied
+			}
+
+			fileFound = true
+			log.Printf("Received file '%s' (size: %d bytes, copied: %d bytes) for user %d (pre-job creation)", header.Filename, header.Size, copied, userID)
+		} else if formName != "" {
+			// Read form value
+			valueBytes, err := io.ReadAll(part)
+			if err == nil {
+				formValues[formName] = string(valueBytes)
+			}
+			part.Close()
+		} else {
+			part.Close()
 		}
-		log.Printf("Successfully copied %d bytes to ZIP file", copied)
+	}

-		// Broadcast upload complete, processing starts
-		s.broadcastUploadProgress(sessionID, 100.0, "processing", "Upload complete, processing file...")
+	if !fileFound {
+		os.RemoveAll(tmpDir)
+		s.respondError(w, http.StatusBadRequest, "No file provided")
+		return
+	}

-		// Extract ZIP file to temporary directory
-		s.broadcastUploadProgress(sessionID, 25.0, "extracting_zip", "Extracting ZIP file...")
-		extractedFiles, err = s.storage.ExtractZip(zipPath, tmpDir)
+	//
Process everything synchronously and return metadata in HTTP response + // Client will show upload progress during upload, then processing progress while waiting + filename := header.Filename + fileSize := header.Size + mainBlendParam := formValues["main_blend_file"] + + var processedMainBlendFile string + var processedExtractedFiles []string + var processedMetadata *types.BlendMetadata + + // Process ZIP extraction if needed + if strings.HasSuffix(strings.ToLower(filename), ".zip") { + zipPath := filepath.Join(tmpDir, filename) + log.Printf("Extracting ZIP file: %s", zipPath) + processedExtractedFiles, err = s.storage.ExtractZip(zipPath, tmpDir) if err != nil { + log.Printf("ERROR: Failed to extract ZIP file: %v", err) os.RemoveAll(tmpDir) s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to extract ZIP file: %v", err)) return } - log.Printf("Successfully extracted %d files from ZIP", len(extractedFiles)) - s.broadcastUploadProgress(sessionID, 50.0, "extracting_zip", "ZIP extraction complete") + log.Printf("Successfully extracted %d files from ZIP", len(processedExtractedFiles)) // Find main blend file - mainBlendParam := r.FormValue("main_blend_file") if mainBlendParam != "" { - mainBlendFile = filepath.Join(tmpDir, mainBlendParam) - if _, err := os.Stat(mainBlendFile); err != nil { + processedMainBlendFile = filepath.Join(tmpDir, mainBlendParam) + if _, err := os.Stat(processedMainBlendFile); err != nil { + log.Printf("ERROR: Specified main blend file not found: %s", mainBlendParam) os.RemoveAll(tmpDir) s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Specified main blend file not found: %s", mainBlendParam)) return @@ -1589,128 +1629,100 @@ func (s *Server) handleUploadFileForJobCreation(w http.ResponseWriter, r *http.R return nil }) if err == nil && len(blendFiles) == 1 { - mainBlendFile = blendFiles[0] + processedMainBlendFile = blendFiles[0] } else if len(blendFiles) > 1 { - // Multiple blend files - return list for user to choose + // Multiple blend files - return response with list for user to select blendFileNames := []string{} for _, f := range blendFiles { rel, _ := filepath.Rel(tmpDir, f) blendFileNames = append(blendFileNames, rel) } - os.RemoveAll(tmpDir) - s.respondJSON(w, http.StatusOK, map[string]interface{}{ + // Return response indicating multiple blend files found + response := map[string]interface{}{ + "session_id": sessionID, + "file_name": filename, + "file_size": fileSize, + "status": "select_blend", "zip_extracted": true, "blend_files": blendFileNames, - "message": "Multiple blend files found. 
Please specify the main blend file.", - }) + } + s.respondJSON(w, http.StatusOK, response) return } } } else { - // Regular file upload (not ZIP) - save to temporary directory - filePath := filepath.Join(tmpDir, header.Filename) - outFile, err := os.Create(filePath) - if err != nil { - os.RemoveAll(tmpDir) - s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create file: %v", err)) - return - } - - fileReader, _, err := r.FormFile("file") - if err != nil { - outFile.Close() - os.RemoveAll(tmpDir) - s.respondError(w, http.StatusBadRequest, fmt.Sprintf("No file provided: %v", err)) - return - } - - if _, err := io.Copy(outFile, fileReader); err != nil { - fileReader.Close() - outFile.Close() - os.RemoveAll(tmpDir) - s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to save file: %v", err)) - return - } - fileReader.Close() - outFile.Close() - - // Broadcast upload complete for non-ZIP files - s.broadcastUploadProgress(sessionID, 100.0, "processing", "Upload complete, processing file...") - - if strings.HasSuffix(strings.ToLower(header.Filename), ".blend") { - mainBlendFile = filePath - } + processedMainBlendFile = mainBlendFile } - // Create context archive from temporary directory + // Create context archive var excludeFiles []string - if strings.HasSuffix(strings.ToLower(header.Filename), ".zip") { - excludeFiles = append(excludeFiles, header.Filename) + if strings.HasSuffix(strings.ToLower(filename), ".zip") { + excludeFiles = append(excludeFiles, filename) } - // Create context in temp directory (we'll move it to job directory later) - s.broadcastUploadProgress(sessionID, 75.0, "creating_context", "Creating context archive...") + log.Printf("Creating context archive for session %s", sessionID) contextPath := filepath.Join(tmpDir, "context.tar") contextPath, err = s.createContextFromDir(tmpDir, contextPath, excludeFiles...) 
if err != nil { - os.RemoveAll(tmpDir) log.Printf("ERROR: Failed to create context archive: %v", err) - s.broadcastUploadProgress(sessionID, 0.0, "error", fmt.Sprintf("Failed to create context archive: %v", err)) + os.RemoveAll(tmpDir) s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create context archive: %v", err)) return } // Extract metadata from context archive - s.broadcastUploadProgress(sessionID, 85.0, "extracting_metadata", "Extracting metadata from blend file...") - metadata, err := s.extractMetadataFromTempContext(contextPath) + log.Printf("Extracting metadata from context archive for session %s", sessionID) + processedMetadata, err = s.extractMetadataFromTempContext(contextPath) if err != nil { log.Printf("Warning: Failed to extract metadata: %v", err) // Continue anyway - user can fill in manually - metadata = nil + processedMetadata = nil } + // Build response with all results response := map[string]interface{}{ - "session_id": sessionID, // Full temp directory path - "file_name": header.Filename, - "file_size": header.Size, + "session_id": sessionID, + "file_name": filename, + "file_size": fileSize, "context_archive": filepath.Base(contextPath), + "status": "completed", } - if strings.HasSuffix(strings.ToLower(header.Filename), ".zip") { + if strings.HasSuffix(strings.ToLower(filename), ".zip") { response["zip_extracted"] = true - response["extracted_files_count"] = len(extractedFiles) - if mainBlendFile != "" { - relPath, _ := filepath.Rel(tmpDir, mainBlendFile) + response["extracted_files_count"] = len(processedExtractedFiles) + if processedMainBlendFile != "" { + relPath, _ := filepath.Rel(tmpDir, processedMainBlendFile) response["main_blend_file"] = relPath } - } else if mainBlendFile != "" { - relPath, _ := filepath.Rel(tmpDir, mainBlendFile) + } else if processedMainBlendFile != "" { + relPath, _ := filepath.Rel(tmpDir, processedMainBlendFile) response["main_blend_file"] = relPath } - if metadata != nil { - response["metadata"] = metadata + if processedMetadata != nil { + response["metadata"] = processedMetadata response["metadata_extracted"] = true } else { response["metadata_extracted"] = false } - // Broadcast processing complete - s.broadcastUploadProgress(sessionID, 100.0, "completed", "Processing complete") - - // Clean up upload session after a delay (client may still be subscribed) - go func() { - time.Sleep(5 * time.Minute) - s.uploadSessionsMu.Lock() - delete(s.uploadSessions, sessionID) - s.uploadSessionsMu.Unlock() - }() + // Clean up upload session immediately (no longer needed for WebSocket) + s.uploadSessionsMu.Lock() + delete(s.uploadSessions, sessionID) + s.uploadSessionsMu.Unlock() + // Return response with metadata s.respondJSON(w, http.StatusOK, response) } // extractMetadataFromTempContext extracts metadata from a context archive in a temporary location -func (s *Server) extractMetadataFromTempContext(contextPath string) (*types.BlendMetadata, error) { +func (s *Manager) extractMetadataFromTempContext(contextPath string) (*types.BlendMetadata, error) { + return s.extractMetadataFromTempContextWithProgress(contextPath, nil) +} + +// extractMetadataFromTempContextWithProgress extracts metadata with progress callbacks +func (s *Manager) extractMetadataFromTempContextWithProgress(contextPath string, progressCallback func(float64, string)) (*types.BlendMetadata, error) { // Create temporary directory for extraction under storage base path tmpDir, err := s.storage.TempDir("jiggablend-metadata-temp-*") if err != nil { @@ -1761,19 
+1773,45 @@ func (s *Server) extractMetadataFromTempContext(contextPath string) (*types.Blen return nil, fmt.Errorf("no .blend file found in context - the uploaded context archive must contain at least one .blend file to render") } - // Use the same extraction script and process as extractMetadataFromContext - // (Copy the logic from extractMetadataFromContext but use tmpDir and blendFile) - // Log stderr for debugging (not shown to user) - stderrCallback := func(line string) { - log.Printf("Blender stderr during metadata extraction: %s", line) + // Detect Blender version from blend file header BEFORE running Blender + // This allows us to use the correct Blender version for metadata extraction + detectedVersion := "" + major, minor, versionErr := ParseBlenderVersionFromFile(blendFile) + if versionErr == nil { + detectedVersion = fmt.Sprintf("%d.%d", major, minor) + log.Printf("Detected Blender version %s from blend file header", detectedVersion) + } else { + log.Printf("Warning: Could not detect Blender version from blend file: %v", versionErr) } - return s.runBlenderMetadataExtraction(blendFile, tmpDir, stderrCallback) + // Use the same extraction script and process as extractMetadataFromContext + // (Copy the logic from extractMetadataFromContext but use tmpDir and blendFile) + metadata, err := s.runBlenderMetadataExtraction(blendFile, tmpDir, detectedVersion, nil, progressCallback) + if err != nil { + return nil, err + } + + // Set the detected/resolved Blender version in metadata + if metadata != nil && detectedVersion != "" { + // Get the latest patch version for this major.minor + version, verr := s.GetLatestBlenderForMajorMinor(major, minor) + if verr == nil { + metadata.BlenderVersion = version.Full + log.Printf("Resolved Blender version to %s", version.Full) + } else { + metadata.BlenderVersion = detectedVersion + log.Printf("Using detected version %s (could not resolve latest: %v)", detectedVersion, verr) + } + } + + return metadata, nil } // runBlenderMetadataExtraction runs Blender to extract metadata from a blend file +// blenderVersion is optional - if provided, will use versioned blender from cache // stderrCallback is optional and will be called for each stderr line (note: with RunCommand, this is called after completion) -func (s *Server) runBlenderMetadataExtraction(blendFile, workDir string, stderrCallback func(string)) (*types.BlendMetadata, error) { +// progressCallback is optional and will be called with progress updates (0.0-1.0) +func (s *Manager) runBlenderMetadataExtraction(blendFile, workDir, blenderVersion string, stderrCallback func(string), progressCallback func(float64, string)) (*types.BlendMetadata, error) { // Use embedded Python script scriptPath := filepath.Join(workDir, "extract_metadata.py") if err := os.WriteFile(scriptPath, []byte(scripts.ExtractMetadata), 0644); err != nil { @@ -1786,9 +1824,56 @@ func (s *Server) runBlenderMetadataExtraction(blendFile, workDir string, stderrC return nil, fmt.Errorf("failed to get relative path for blend file: %w", err) } + // Determine which blender binary to use + blenderBinary := "blender" // Default to system blender + var version *BlenderVersion // Track version for cleanup + if blenderVersion != "" { + // Try to get versioned blender from cache + var major, minor int + fmt.Sscanf(blenderVersion, "%d.%d", &major, &minor) + version, err = s.GetLatestBlenderForMajorMinor(major, minor) + if err == nil { + archivePath, err := s.GetBlenderArchivePath(version) + if err == nil { + // Extract to temp location for 
manager-side metadata extraction + blenderDir := filepath.Join(s.storage.BasePath(), "blender-versions") + binaryPath := filepath.Join(blenderDir, version.Full, "blender") + // Make path absolute to avoid working directory issues + if absBinaryPath, absErr := filepath.Abs(binaryPath); absErr == nil { + binaryPath = absBinaryPath + } + if _, err := os.Stat(binaryPath); os.IsNotExist(err) { + // Need to extract + if progressCallback != nil { + progressCallback(0.5, "Extracting Blender binary...") + } + if err := extractBlenderArchive(archivePath, version, blenderDir); err == nil { + blenderBinary = binaryPath + log.Printf("Using Blender %s at %s for metadata extraction", version.Full, binaryPath) + if progressCallback != nil { + progressCallback(0.7, "Blender extracted, extracting metadata...") + } + } else { + log.Printf("Warning: Failed to extract Blender %s: %v, using system blender", version.Full, err) + } + } else { + blenderBinary = binaryPath + log.Printf("Using cached Blender %s at %s for metadata extraction", version.Full, binaryPath) + if progressCallback != nil { + progressCallback(0.7, "Extracting metadata from blend file...") + } + } + } else { + log.Printf("Warning: Failed to get Blender archive for %s: %v, using system blender", version.Full, err) + } + } else { + log.Printf("Warning: Failed to find Blender version %s: %v, using system blender", blenderVersion, err) + } + } + // Execute Blender using executils result, err := executils.RunCommand( - "blender", + blenderBinary, []string{"-b", blendFileRel, "--python", "extract_metadata.py"}, workDir, nil, // inherit environment @@ -1843,7 +1928,7 @@ func (s *Server) runBlenderMetadataExtraction(blendFile, workDir string, stderrC } // createContextFromDir creates a context archive from a source directory to a specific destination path -func (s *Server) createContextFromDir(sourceDir, destPath string, excludeFiles ...string) (string, error) { +func (s *Manager) createContextFromDir(sourceDir, destPath string, excludeFiles ...string) (string, error) { // Build set of files to exclude excludeSet := make(map[string]bool) for _, excludeFile := range excludeFiles { @@ -2022,7 +2107,7 @@ func (s *Server) createContextFromDir(sourceDir, destPath string, excludeFiles . 
}

 // handleListJobFiles lists files for a job with pagination
-func (s *Server) handleListJobFiles(w http.ResponseWriter, r *http.Request) {
+func (s *Manager) handleListJobFiles(w http.ResponseWriter, r *http.Request) {
 	userID, err := getUserID(r)
 	if err != nil {
 		s.respondError(w, http.StatusUnauthorized, err.Error())
@@ -2158,7 +2243,7 @@
 }

 // handleGetJobFilesCount returns the count of files for a job
-func (s *Server) handleGetJobFilesCount(w http.ResponseWriter, r *http.Request) {
+func (s *Manager) handleGetJobFilesCount(w http.ResponseWriter, r *http.Request) {
 	userID, err := getUserID(r)
 	if err != nil {
 		s.respondError(w, http.StatusUnauthorized, err.Error())
@@ -2225,7 +2310,7 @@

 // handleListContextArchive lists files inside the context archive
 // Optimized to only read tar headers, skipping file data for fast directory listing
-func (s *Server) handleListContextArchive(w http.ResponseWriter, r *http.Request) {
+func (s *Manager) handleListContextArchive(w http.ResponseWriter, r *http.Request) {
 	userID, err := getUserID(r)
 	if err != nil {
 		s.respondError(w, http.StatusUnauthorized, err.Error())
@@ -2412,7 +2497,7 @@
 }

 // handleDownloadJobFile downloads a job file
-func (s *Server) handleDownloadJobFile(w http.ResponseWriter, r *http.Request) {
+func (s *Manager) handleDownloadJobFile(w http.ResponseWriter, r *http.Request) {
 	userID, err := getUserID(r)
 	if err != nil {
 		s.respondError(w, http.StatusUnauthorized, err.Error())
@@ -2521,8 +2606,159 @@
 	io.Copy(w, file)
 }

+// handlePreviewEXR converts an EXR file to PNG for browser preview
+// Uses ImageMagick to convert with an sRGB transfer and alpha preservation
+func (s *Manager) handlePreviewEXR(w http.ResponseWriter, r *http.Request) {
+	userID, err := getUserID(r)
+	if err != nil {
+		s.respondError(w, http.StatusUnauthorized, err.Error())
+		return
+	}
+
+	jobID, err := parseID(r, "id")
+	if err != nil {
+		s.respondError(w, http.StatusBadRequest, err.Error())
+		return
+	}
+
+	fileID, err := parseID(r, "fileId")
+	if err != nil {
+		s.respondError(w, http.StatusBadRequest, err.Error())
+		return
+	}
+
+	// Verify job belongs to user (unless admin)
+	isAdmin := isAdminUser(r)
+	if !isAdmin {
+		var jobUserID int64
+		err = s.db.With(func(conn *sql.DB) error {
+			return conn.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID)
+		})
+		if err == sql.ErrNoRows {
+			s.respondError(w, http.StatusNotFound, "Job not found")
+			return
+		}
+		if err != nil {
+			s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify job: %v", err))
+			return
+		}
+		if jobUserID != userID {
+			s.respondError(w, http.StatusForbidden, "Access denied")
+			return
+		}
+	} else {
+		// Admin: verify job exists (report a DB failure as a server error, not a 404)
+		var exists bool
+		err = s.db.With(func(conn *sql.DB) error {
+			return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM jobs WHERE id = ?)", jobID).Scan(&exists)
+		})
+		if err != nil {
+			s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify job: %v", err))
+			return
+		}
+		if !exists {
+			s.respondError(w, http.StatusNotFound, "Job not found")
+			return
+		}
+	}
+
+	// Get file info
+	var filePath, fileName string
+	err = s.db.With(func(conn *sql.DB) error {
+		return conn.QueryRow(
+			`SELECT file_path, file_name FROM job_files WHERE id = ? AND job_id = ?`,
+			fileID, jobID,
+		).Scan(&filePath, &fileName)
+	})
+	if err == sql.ErrNoRows {
+		s.respondError(w, http.StatusNotFound, "File not found")
+		return
+	}
+	if err != nil {
+		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query file: %v", err))
+		return
+	}
+
+	// Verify it's an EXR file
+	if !strings.HasSuffix(strings.ToLower(fileName), ".exr") {
+		s.respondError(w, http.StatusBadRequest, "File is not an EXR file")
+		return
+	}
+
+	// Check if source file exists
+	if !s.storage.FileExists(filePath) {
+		s.respondError(w, http.StatusNotFound, "File not found on disk")
+		return
+	}
+
+	// Create temp file for PNG output
+	tmpFile, err := os.CreateTemp("", "exr-preview-*.png")
+	if err != nil {
+		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create temp file: %v", err))
+		return
+	}
+	tmpPath := tmpFile.Name()
+	tmpFile.Close()
+	defer os.Remove(tmpPath)
+
+	// Convert EXR to PNG using ImageMagick
+	// -colorspace sRGB: Convert from linear RGB to sRGB (matches SDR encoding pipeline)
+	// -depth 16: Use 16-bit depth for better quality
+	// -alpha on: Preserve alpha channel
+	// Note: -auto-level is deliberately omitted to avoid automatic tone mapping that changes colors
+	result, err := executils.RunCommand(
+		"magick",
+		[]string{
+			filePath,
+			"-colorspace", "sRGB",
+			"-depth", "16",
+			"-alpha", "on",
+			tmpPath,
+		},
+		"",  // dir
+		nil, // env
+		0,   // taskID
+		nil, // tracker
+	)
+
+	if err != nil {
+		// Try with 'convert' command (older ImageMagick)
+		result, err = executils.RunCommand(
+			"convert",
+			[]string{
+				filePath,
+				"-colorspace", "sRGB",
+				"-depth", "16",
+				"-alpha", "on",
+				tmpPath,
+			},
+			"",  // dir
+			nil, // env
+			0,   // taskID
+			nil, // tracker
+		)
		if err != nil {
+			log.Printf("EXR conversion failed: %v, output: %s %s", err, result.Stdout, result.Stderr)
+			s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to convert EXR: %v", err))
+			return
+		}
+	}
+
+	// Read the converted PNG
+	pngData, err := os.ReadFile(tmpPath)
+	if err != nil {
+		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to read converted file: %v", err))
+		return
+	}
+
+	// Set headers (quote the filename so names containing spaces stay valid)
+	pngFileName := strings.TrimSuffix(fileName, filepath.Ext(fileName)) + ".png"
+	w.Header().Set("Content-Disposition", fmt.Sprintf("inline; filename=%q", pngFileName))
+	w.Header().Set("Content-Type", "image/png")
+	w.Header().Set("Content-Length", strconv.Itoa(len(pngData)))
+
+	// Write response
+	w.Write(pngData)
+}
+
 // handleStreamVideo streams MP4 video file with range support
-func (s *Server) handleStreamVideo(w http.ResponseWriter, r *http.Request) {
+func (s *Manager) handleStreamVideo(w http.ResponseWriter, r *http.Request) {
 	userID, err := getUserID(r)
 	if err != nil {
 		s.respondError(w, http.StatusUnauthorized, err.Error())
@@ -2627,7 +2863,7 @@
 }

 // handleListJobTasks lists all tasks for a job with pagination and filtering
-func (s *Server) handleListJobTasks(w http.ResponseWriter, r *http.Request) {
+func (s *Manager) handleListJobTasks(w http.ResponseWriter, r *http.Request) {
 	userID, err := getUserID(r)
 	if err != nil {
 		s.respondError(w, http.StatusUnauthorized, err.Error())
@@ -2691,12 +2927,12 @@
 	frameEndFilter := r.URL.Query().Get("frame_end")
 	sortBy := r.URL.Query().Get("sort")
 	if sortBy == "" {
-		sortBy = "frame_start:asc"
+		sortBy = "frame:asc"
 	}

 	// Parse sort 
parameter sortParts := strings.Split(sortBy, ":") - sortField := "frame_start" + sortField := "frame" sortDir := "ASC" if len(sortParts) == 2 { sortField = sortParts[0] @@ -2705,16 +2941,16 @@ func (s *Server) handleListJobTasks(w http.ResponseWriter, r *http.Request) { sortDir = "ASC" } validFields := map[string]bool{ - "frame_start": true, "frame_end": true, "status": true, + "frame": true, "status": true, "created_at": true, "started_at": true, "completed_at": true, } if !validFields[sortField] { - sortField = "frame_start" + sortField = "frame" } } // Build query with filters - query := `SELECT id, job_id, runner_id, frame_start, frame_end, status, task_type, + query := `SELECT id, job_id, runner_id, frame, status, task_type, current_step, retry_count, max_retries, output_path, created_at, started_at, completed_at, error_message, timeout_seconds FROM tasks WHERE job_id = ?` @@ -2732,14 +2968,14 @@ func (s *Server) handleListJobTasks(w http.ResponseWriter, r *http.Request) { if frameStartFilter != "" { if fs, err := strconv.Atoi(frameStartFilter); err == nil { - query += " AND frame_start >= ?" + query += " AND frame >= ?" args = append(args, fs) } } if frameEndFilter != "" { if fe, err := strconv.Atoi(frameEndFilter); err == nil { - query += " AND frame_end <= ?" + query += " AND frame <= ?" args = append(args, fe) } } @@ -2770,13 +3006,13 @@ func (s *Server) handleListJobTasks(w http.ResponseWriter, r *http.Request) { } if frameStartFilter != "" { if fs, err := strconv.Atoi(frameStartFilter); err == nil { - countQuery += " AND frame_start >= ?" + countQuery += " AND frame >= ?" countArgs = append(countArgs, fs) } } if frameEndFilter != "" { if fe, err := strconv.Atoi(frameEndFilter); err == nil { - countQuery += " AND frame_end <= ?" + countQuery += " AND frame <= ?" 
countArgs = append(countArgs, fe) } } @@ -2803,7 +3039,7 @@ func (s *Server) handleListJobTasks(w http.ResponseWriter, r *http.Request) { var outputPath sql.NullString err := rows.Scan( - &task.ID, &task.JobID, &runnerID, &task.FrameStart, &task.FrameEnd, + &task.ID, &task.JobID, &runnerID, &task.Frame, &task.Status, &task.TaskType, ¤tStep, &task.RetryCount, &task.MaxRetries, &outputPath, &task.CreatedAt, &startedAt, &completedAt, &errorMessage, &timeoutSeconds, @@ -2859,7 +3095,7 @@ func (s *Server) handleListJobTasks(w http.ResponseWriter, r *http.Request) { } // handleListJobTasksSummary lists lightweight task summaries for a job -func (s *Server) handleListJobTasksSummary(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleListJobTasksSummary(w http.ResponseWriter, r *http.Request) { userID, err := getUserID(r) if err != nil { s.respondError(w, http.StatusUnauthorized, err.Error()) @@ -2920,11 +3156,11 @@ func (s *Server) handleListJobTasksSummary(w http.ResponseWriter, r *http.Reques statusFilter := r.URL.Query().Get("status") sortBy := r.URL.Query().Get("sort") if sortBy == "" { - sortBy = "frame_start:asc" + sortBy = "frame:asc" } sortParts := strings.Split(sortBy, ":") - sortField := "frame_start" + sortField := "frame" sortDir := "ASC" if len(sortParts) == 2 { sortField = sortParts[0] @@ -2933,15 +3169,15 @@ func (s *Server) handleListJobTasksSummary(w http.ResponseWriter, r *http.Reques sortDir = "ASC" } validFields := map[string]bool{ - "frame_start": true, "frame_end": true, "status": true, + "frame": true, "status": true, } if !validFields[sortField] { - sortField = "frame_start" + sortField = "frame" } } // Build query - only select summary fields - query := `SELECT id, frame_start, frame_end, status, task_type, runner_id + query := `SELECT id, frame, status, task_type, runner_id FROM tasks WHERE job_id = ?` args := []interface{}{jobID} @@ -3001,12 +3237,11 @@ func (s *Server) handleListJobTasksSummary(w http.ResponseWriter, r *http.Reques defer rows.Close() type TaskSummary struct { - ID int64 `json:"id"` - FrameStart int `json:"frame_start"` - FrameEnd int `json:"frame_end"` - Status string `json:"status"` - TaskType string `json:"task_type"` - RunnerID *int64 `json:"runner_id,omitempty"` + ID int64 `json:"id"` + Frame int `json:"frame"` + Status string `json:"status"` + TaskType string `json:"task_type"` + RunnerID *int64 `json:"runner_id,omitempty"` } summaries := []TaskSummary{} @@ -3015,7 +3250,7 @@ func (s *Server) handleListJobTasksSummary(w http.ResponseWriter, r *http.Reques var runnerID sql.NullInt64 err := rows.Scan( - &summary.ID, &summary.FrameStart, &summary.FrameEnd, + &summary.ID, &summary.Frame, &summary.Status, &summary.TaskType, &runnerID, ) if err != nil { @@ -3040,7 +3275,7 @@ func (s *Server) handleListJobTasksSummary(w http.ResponseWriter, r *http.Reques } // handleBatchGetTasks fetches multiple tasks by IDs for a job -func (s *Server) handleBatchGetTasks(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleBatchGetTasks(w http.ResponseWriter, r *http.Request) { userID, err := getUserID(r) if err != nil { s.respondError(w, http.StatusUnauthorized, err.Error()) @@ -3110,10 +3345,10 @@ func (s *Server) handleBatchGetTasks(w http.ResponseWriter, r *http.Request) { args[i+1] = taskID } - query := fmt.Sprintf(`SELECT id, job_id, runner_id, frame_start, frame_end, status, task_type, + query := fmt.Sprintf(`SELECT id, job_id, runner_id, frame, status, task_type, current_step, retry_count, max_retries, output_path, created_at, 
started_at, completed_at, error_message, timeout_seconds - FROM tasks WHERE job_id = ? AND id IN (%s) ORDER BY frame_start ASC`, strings.Join(placeholders, ",")) + FROM tasks WHERE job_id = ? AND id IN (%s) ORDER BY frame ASC`, strings.Join(placeholders, ",")) var rows *sql.Rows err = s.db.With(func(conn *sql.DB) error { @@ -3138,7 +3373,7 @@ func (s *Server) handleBatchGetTasks(w http.ResponseWriter, r *http.Request) { var outputPath sql.NullString err := rows.Scan( - &task.ID, &task.JobID, &runnerID, &task.FrameStart, &task.FrameEnd, + &task.ID, &task.JobID, &runnerID, &task.Frame, &task.Status, &task.TaskType, ¤tStep, &task.RetryCount, &task.MaxRetries, &outputPath, &task.CreatedAt, &startedAt, &completedAt, &errorMessage, &timeoutSeconds, @@ -3178,7 +3413,7 @@ func (s *Server) handleBatchGetTasks(w http.ResponseWriter, r *http.Request) { } // handleGetTaskLogs retrieves logs for a specific task -func (s *Server) handleGetTaskLogs(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleGetTaskLogs(w http.ResponseWriter, r *http.Request) { userID, err := getUserID(r) if err != nil { s.respondError(w, http.StatusUnauthorized, err.Error()) @@ -3327,7 +3562,7 @@ func (s *Server) handleGetTaskLogs(w http.ResponseWriter, r *http.Request) { } // handleGetTaskSteps retrieves step timeline for a specific task -func (s *Server) handleGetTaskSteps(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleGetTaskSteps(w http.ResponseWriter, r *http.Request) { userID, err := getUserID(r) if err != nil { s.respondError(w, http.StatusUnauthorized, err.Error()) @@ -3445,7 +3680,7 @@ func (s *Server) handleGetTaskSteps(w http.ResponseWriter, r *http.Request) { } // handleRetryTask retries a failed task -func (s *Server) handleRetryTask(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleRetryTask(w http.ResponseWriter, r *http.Request) { userID, err := getUserID(r) if err != nil { s.respondError(w, http.StatusUnauthorized, err.Error()) @@ -3523,6 +3758,15 @@ func (s *Server) handleRetryTask(w http.ResponseWriter, r *http.Request) { WHERE id = ?`, types.TaskStatusPending, taskID, ) + if err != nil { + return err + } + // Clear steps and logs for fresh retry + _, err = conn.Exec(`DELETE FROM task_steps WHERE task_id = ?`, taskID) + if err != nil { + return err + } + _, err = conn.Exec(`DELETE FROM task_logs WHERE task_id = ?`, taskID) return err }) if err != nil { @@ -3530,12 +3774,14 @@ func (s *Server) handleRetryTask(w http.ResponseWriter, r *http.Request) { return } - // Broadcast task update - s.broadcastTaskUpdate(jobID, taskID, "task_update", map[string]interface{}{ + // Broadcast task reset to clients (includes steps_cleared and logs_cleared flags) + s.broadcastTaskUpdate(jobID, taskID, "task_reset", map[string]interface{}{ "status": types.TaskStatusPending, "runner_id": nil, "current_step": nil, "error_message": nil, + "steps_cleared": true, + "logs_cleared": true, }) s.respondJSON(w, http.StatusOK, map[string]string{"message": "Task queued for retry"}) @@ -3543,7 +3789,7 @@ func (s *Server) handleRetryTask(w http.ResponseWriter, r *http.Request) { // handleStreamTaskLogsWebSocket streams task logs via WebSocket // Note: This is called after auth middleware, so userID is already verified -func (s *Server) handleStreamTaskLogsWebSocket(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleStreamTaskLogsWebSocket(w http.ResponseWriter, r *http.Request) { userID, err := getUserID(r) if err != nil { http.Error(w, "Unauthorized", http.StatusUnauthorized) 
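A note on the retry reset above: the status UPDATE and the two DELETEs are issued as separate Execs inside one connection callback, so a runner polling between them could briefly observe a pending task that still carries stale steps or logs. Below is a minimal sketch of an atomic variant, assuming database/sql transactions; the helper name is invented, and the real handler also clears runner_id, current_step, and error_message in its SET list.

// Illustrative sketch (not part of this patch): grouping the retry reset into
// one transaction so the pending status and the cleared steps/logs become
// visible to runners at the same instant.
package sketch

import "database/sql"

func resetTaskForRetry(db *sql.DB, taskID int64, pendingStatus string) error {
	tx, err := db.Begin()
	if err != nil {
		return err
	}
	defer tx.Rollback() // no-op after a successful Commit

	// The real handler's UPDATE also nulls runner_id, current_step, error_message.
	if _, err := tx.Exec(`UPDATE tasks SET status = ? WHERE id = ?`, pendingStatus, taskID); err != nil {
		return err
	}
	if _, err := tx.Exec(`DELETE FROM task_steps WHERE task_id = ?`, taskID); err != nil {
		return err
	}
	if _, err := tx.Exec(`DELETE FROM task_logs WHERE task_id = ?`, taskID); err != nil {
		return err
	}
	return tx.Commit()
}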
@@ -3758,7 +4004,7 @@ func (s *Server) handleStreamTaskLogsWebSocket(w http.ResponseWriter, r *http.Re } // handleClientWebSocket handles the unified client WebSocket connection with subscription protocol -func (s *Server) handleClientWebSocket(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleClientWebSocket(w http.ResponseWriter, r *http.Request) { userID, err := getUserID(r) if err != nil { http.Error(w, "Unauthorized", http.StatusUnauthorized) @@ -3776,46 +4022,31 @@ func (s *Server) handleClientWebSocket(w http.ResponseWriter, r *http.Request) { } defer conn.Close() + // Generate unique connection ID for this tab/connection + connNum := atomic.AddUint64(&s.connIDCounter, 1) + connID := fmt.Sprintf("%d:%d", userID, connNum) + // Create client connection clientConn := &ClientConnection{ Conn: conn, UserID: userID, + ConnID: connID, IsAdmin: isAdmin, Subscriptions: make(map[string]bool), WriteMu: &sync.Mutex{}, } - // Register connection - // Fix race condition: Close old connection BEFORE registering new one - var oldConn *ClientConnection + // Register connection (no need to close old - multiple connections per user are allowed) s.clientConnsMu.Lock() - if existingConn, exists := s.clientConns[userID]; exists && existingConn != nil { - oldConn = existingConn - } + s.clientConns[connID] = clientConn s.clientConnsMu.Unlock() - - // Close old connection BEFORE registering new one to prevent race conditions - if oldConn != nil { - log.Printf("handleClientWebSocket: Closing existing connection for user %d", userID) - oldConn.Conn.Close() - } - - // Now register the new connection - s.clientConnsMu.Lock() - s.clientConns[userID] = clientConn - s.clientConnsMu.Unlock() - log.Printf("handleClientWebSocket: Registered client connection for user %d", userID) + log.Printf("handleClientWebSocket: Registered client connection %s for user %d", connID, userID) defer func() { s.clientConnsMu.Lock() - // Only remove if this is still the current connection (not replaced by a newer one) - if existingConn, exists := s.clientConns[userID]; exists && existingConn == clientConn { - delete(s.clientConns, userID) - log.Printf("handleClientWebSocket: Removed client connection for user %d", userID) - } else { - log.Printf("handleClientWebSocket: Skipping removal for user %d (connection was replaced)", userID) - } + delete(s.clientConns, connID) s.clientConnsMu.Unlock() + log.Printf("handleClientWebSocket: Removed client connection %s for user %d", connID, userID) }() // Send initial connection message @@ -3831,14 +4062,14 @@ func (s *Server) handleClientWebSocket(w http.ResponseWriter, r *http.Request) { } // Set up ping/pong - conn.SetReadDeadline(time.Now().Add(90 * time.Second)) // Increased timeout + conn.SetReadDeadline(time.Now().Add(WSReadDeadline)) conn.SetPongHandler(func(string) error { - conn.SetReadDeadline(time.Now().Add(90 * time.Second)) // Reset deadline on pong + conn.SetReadDeadline(time.Now().Add(WSReadDeadline)) // Reset deadline on pong return nil }) - // Start ping ticker (send ping every 30 seconds) - ticker := time.NewTicker(30 * time.Second) + // Start ping ticker + ticker := time.NewTicker(WSPingInterval) defer ticker.Stop() // Message handling channel - increased buffer size to prevent blocking @@ -3849,7 +4080,7 @@ func (s *Server) handleClientWebSocket(w http.ResponseWriter, r *http.Request) { go func() { defer close(readDone) for { - conn.SetReadDeadline(time.Now().Add(90 * time.Second)) // Increased timeout + conn.SetReadDeadline(time.Now().Add(WSReadDeadline)) // 
Increased timeout messageType, message, err := conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { @@ -3863,14 +4094,14 @@ func (s *Server) handleClientWebSocket(w http.ResponseWriter, r *http.Request) { // Handle control frames (pong, ping, close) if messageType == websocket.PongMessage { // Pong received - connection is alive, reset deadline - conn.SetReadDeadline(time.Now().Add(90 * time.Second)) + conn.SetReadDeadline(time.Now().Add(WSReadDeadline)) continue } if messageType == websocket.PingMessage { // Respond to ping with pong - conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + conn.SetWriteDeadline(time.Now().Add(WSWriteDeadline)) conn.WriteMessage(websocket.PongMessage, message) - conn.SetReadDeadline(time.Now().Add(90 * time.Second)) + conn.SetReadDeadline(time.Now().Add(WSReadDeadline)) continue } if messageType != websocket.TextMessage { @@ -3885,7 +4116,7 @@ func (s *Server) handleClientWebSocket(w http.ResponseWriter, r *http.Request) { continue } messageChan <- msg - conn.SetReadDeadline(time.Now().Add(90 * time.Second)) + conn.SetReadDeadline(time.Now().Add(WSReadDeadline)) } }() @@ -3902,10 +4133,10 @@ func (s *Server) handleClientWebSocket(w http.ResponseWriter, r *http.Request) { s.handleClientMessage(clientConn, msg) case <-ticker.C: // Reset read deadline before sending ping to ensure we can receive pong - conn.SetReadDeadline(time.Now().Add(90 * time.Second)) + conn.SetReadDeadline(time.Now().Add(WSReadDeadline)) clientConn.WriteMu.Lock() // Use WriteControl for ping frames (control frames) - if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil { + if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(WSWriteDeadline)); err != nil { log.Printf("handleClientWebSocket: Ping failed for user %d: %v", userID, err) clientConn.WriteMu.Unlock() return @@ -3916,7 +4147,7 @@ func (s *Server) handleClientWebSocket(w http.ResponseWriter, r *http.Request) { } // handleClientMessage processes messages from client WebSocket -func (s *Server) handleClientMessage(clientConn *ClientConnection, msg map[string]interface{}) { +func (s *Manager) handleClientMessage(clientConn *ClientConnection, msg map[string]interface{}) { msgType, ok := msg["type"].(string) if !ok { return @@ -4008,7 +4239,7 @@ func (s *Server) handleClientMessage(clientConn *ClientConnection, msg map[strin } // canSubscribe checks if a client can subscribe to a channel -func (s *Server) canSubscribe(clientConn *ClientConnection, channel string) bool { +func (s *Manager) canSubscribe(clientConn *ClientConnection, channel string) bool { // Always allow jobs channel (always broadcasted, but subscription doesn't hurt) if channel == "jobs" { return true @@ -4081,7 +4312,7 @@ func (s *Server) canSubscribe(clientConn *ClientConnection, channel string) bool } // sendInitialState sends the current state when a client subscribes to a channel -func (s *Server) sendInitialState(clientConn *ClientConnection, channel string) { +func (s *Manager) sendInitialState(clientConn *ClientConnection, channel string) { // Use a shorter write deadline for initial state to avoid blocking too long // If the connection is slow/dead, we want to fail fast writeTimeout := 5 * time.Second @@ -4108,9 +4339,7 @@ func (s *Server) sendInitialState(clientConn *ClientConnection, channel string) var errorMessage sql.NullString var frameStart, frameEnd sql.NullInt64 var outputFormat 
sql.NullString - var allowParallelRunners sql.NullBool - - query := "SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format, allow_parallel_runners, blend_metadata, created_at, started_at, completed_at, error_message FROM jobs WHERE id = ?" + query := "SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format, blend_metadata, created_at, started_at, completed_at, error_message FROM jobs WHERE id = ?" if !clientConn.IsAdmin { query += " AND user_id = ?" } @@ -4120,13 +4349,13 @@ func (s *Server) sendInitialState(clientConn *ClientConnection, channel string) if clientConn.IsAdmin { return conn.QueryRow(query, jobID).Scan( &job.ID, &job.UserID, &jobType, &job.Name, &job.Status, &job.Progress, - &frameStart, &frameEnd, &outputFormat, &allowParallelRunners, + &frameStart, &frameEnd, &outputFormat, &blendMetadataJSON, &job.CreatedAt, &startedAt, &completedAt, &errorMessage, ) } else { return conn.QueryRow(query, jobID, clientConn.UserID).Scan( &job.ID, &job.UserID, &jobType, &job.Name, &job.Status, &job.Progress, - &frameStart, &frameEnd, &outputFormat, &allowParallelRunners, + &frameStart, &frameEnd, &outputFormat, &blendMetadataJSON, &job.CreatedAt, &startedAt, &completedAt, &errorMessage, ) } @@ -4148,10 +4377,6 @@ func (s *Server) sendInitialState(clientConn *ClientConnection, channel string) of := outputFormat.String job.OutputFormat = &of } - if allowParallelRunners.Valid { - apr := allowParallelRunners.Bool - job.AllowParallelRunners = &apr - } if startedAt.Valid { job.StartedAt = &startedAt.Time } @@ -4181,10 +4406,10 @@ func (s *Server) sendInitialState(clientConn *ClientConnection, channel string) // Get and send tasks (no limit - send all) err = s.db.With(func(conn *sql.DB) error { rows, err2 := conn.Query( - `SELECT id, job_id, runner_id, frame_start, frame_end, status, task_type, + `SELECT id, job_id, runner_id, frame, status, task_type, current_step, retry_count, max_retries, output_path, created_at, started_at, completed_at, error_message, timeout_seconds - FROM tasks WHERE job_id = ? ORDER BY frame_start ASC`, + FROM tasks WHERE job_id = ? 
ORDER BY frame ASC`, jobID, ) if err2 != nil { @@ -4201,7 +4426,7 @@ func (s *Server) sendInitialState(clientConn *ClientConnection, channel string) var outputPath sql.NullString err := rows.Scan( - &task.ID, &task.JobID, &runnerID, &task.FrameStart, &task.FrameEnd, + &task.ID, &task.JobID, &runnerID, &task.Frame, &task.Status, &task.TaskType, ¤tStep, &task.RetryCount, &task.MaxRetries, &outputPath, &task.CreatedAt, &startedAt, &completedAt, &errorMessage, &timeoutSeconds, @@ -4397,7 +4622,7 @@ func (s *Server) sendInitialState(clientConn *ClientConnection, channel string) } // handleJobsWebSocket handles WebSocket connection for job list updates -func (s *Server) handleJobsWebSocket(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleJobsWebSocket(w http.ResponseWriter, r *http.Request) { userID, err := getUserID(r) if err != nil { http.Error(w, "Unauthorized", http.StatusUnauthorized) @@ -4445,7 +4670,7 @@ func (s *Server) handleJobsWebSocket(w http.ResponseWriter, r *http.Request) { }) // Start ping ticker - ticker := time.NewTicker(30 * time.Second) + ticker := time.NewTicker(WSPingInterval) defer ticker.Stop() // Read messages in background to keep connection alive and handle pongs @@ -4479,7 +4704,7 @@ func (s *Server) handleJobsWebSocket(w http.ResponseWriter, r *http.Request) { // Reset read deadline before sending ping to ensure we can receive pong conn.SetReadDeadline(time.Now().Add(60 * time.Second)) // Use WriteControl for ping frames (control frames) - if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil { + if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(WSWriteDeadline)); err != nil { return } } @@ -4487,7 +4712,7 @@ func (s *Server) handleJobsWebSocket(w http.ResponseWriter, r *http.Request) { } // handleJobWebSocket handles WebSocket connection for single job updates -func (s *Server) handleJobWebSocket(w http.ResponseWriter, r *http.Request) { +func (s *Manager) handleJobWebSocket(w http.ResponseWriter, r *http.Request) { userID, err := getUserID(r) if err != nil { http.Error(w, "Unauthorized", http.StatusUnauthorized) @@ -4582,7 +4807,7 @@ func (s *Server) handleJobWebSocket(w http.ResponseWriter, r *http.Request) { }) // Start ping ticker - ticker := time.NewTicker(30 * time.Second) + ticker := time.NewTicker(WSPingInterval) defer ticker.Stop() // Read messages in background to keep connection alive and handle pongs @@ -4616,7 +4841,7 @@ func (s *Server) handleJobWebSocket(w http.ResponseWriter, r *http.Request) { // Reset read deadline before sending ping to ensure we can receive pong conn.SetReadDeadline(time.Now().Add(60 * time.Second)) // Use WriteControl for ping frames (control frames) - if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil { + if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(WSWriteDeadline)); err != nil { return } } @@ -4624,7 +4849,7 @@ func (s *Server) handleJobWebSocket(w http.ResponseWriter, r *http.Request) { } // broadcastJobUpdate broadcasts job update to connected clients -func (s *Server) broadcastJobUpdate(jobID int64, updateType string, data interface{}) { +func (s *Manager) broadcastJobUpdate(jobID int64, updateType string, data interface{}) { // Get user_id from job var userID int64 err := s.db.With(func(conn *sql.DB) error { @@ -4668,15 +4893,18 @@ func (s *Server) broadcastJobUpdate(jobID int64, updateType string, data interfa } } - // Broadcast to client 
WebSocket if subscribed to job:{id} - channel := fmt.Sprintf("job:%d", jobID) - s.broadcastToClient(userID, channel, msg) + // Only broadcast if client is connected + if s.isClientConnected(userID) { + // Broadcast to client WebSocket if subscribed to job:{id} + channel := fmt.Sprintf("job:%d", jobID) + s.broadcastToClient(userID, channel, msg) + } // Also broadcast to old WebSocket connections (for backwards compatibility during migration) s.jobListConnsMu.RLock() if conn, exists := s.jobListConns[userID]; exists && conn != nil { s.jobListConnsMu.RUnlock() - conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + conn.SetWriteDeadline(time.Now().Add(WSWriteDeadline)) conn.WriteJSON(msg) } else { s.jobListConnsMu.RUnlock() @@ -4695,21 +4923,21 @@ func (s *Server) broadcastJobUpdate(jobID int64, updateType string, data interfa if hasMu && writeMu != nil { writeMu.Lock() - conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + conn.SetWriteDeadline(time.Now().Add(WSWriteDeadline)) err := conn.WriteJSON(msg) writeMu.Unlock() if err != nil { log.Printf("Failed to broadcast %s to job %d WebSocket: %v", updateType, jobID, err) } } else { - conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + conn.SetWriteDeadline(time.Now().Add(WSWriteDeadline)) conn.WriteJSON(msg) } } } // broadcastTaskUpdate broadcasts task update to connected clients -func (s *Server) broadcastTaskUpdate(jobID int64, taskID int64, updateType string, data interface{}) { +func (s *Manager) broadcastTaskUpdate(jobID int64, taskID int64, updateType string, data interface{}) { // Get user_id from job var userID int64 err := s.db.With(func(conn *sql.DB) error { @@ -4736,12 +4964,20 @@ func (s *Server) broadcastTaskUpdate(jobID int64, taskID int64, updateType strin } } - // Broadcast to client WebSocket if subscribed to job:{id} - channel := fmt.Sprintf("job:%d", jobID) - if s.verboseWSLogging { - log.Printf("broadcastTaskUpdate: Broadcasting %s for task %d (job %d, user %d) on channel %s, data=%+v", updateType, taskID, jobID, userID, channel, data) + // Only broadcast if client is connected + if !s.isClientConnected(userID) { + if s.verboseWSLogging { + log.Printf("broadcastTaskUpdate: Client %d not connected, skipping broadcast for task %d (job %d)", userID, taskID, jobID) + } + // Still broadcast to old WebSocket connections for backwards compatibility + } else { + // Broadcast to client WebSocket if subscribed to job:{id} + channel := fmt.Sprintf("job:%d", jobID) + if s.verboseWSLogging { + log.Printf("broadcastTaskUpdate: Broadcasting %s for task %d (job %d, user %d) on channel %s, data=%+v", updateType, taskID, jobID, userID, channel, data) + } + s.broadcastToClient(userID, channel, msg) } - s.broadcastToClient(userID, channel, msg) // Also broadcast to old WebSocket connection (for backwards compatibility during migration) key := fmt.Sprintf("%d:%d", userID, jobID) @@ -4756,71 +4992,101 @@ func (s *Server) broadcastTaskUpdate(jobID int64, taskID int64, updateType strin if hasMu && writeMu != nil { writeMu.Lock() - conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + conn.SetWriteDeadline(time.Now().Add(WSWriteDeadline)) conn.WriteJSON(msg) writeMu.Unlock() } else { - conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + conn.SetWriteDeadline(time.Now().Add(WSWriteDeadline)) conn.WriteJSON(msg) } } } -// broadcastToClient sends a message to a specific client connection -func (s *Server) broadcastToClient(userID int64, channel string, msg map[string]interface{}) { +// isClientConnected checks if a user has at 
least one active connection
+func (s *Manager) isClientConnected(userID int64) bool {
 	s.clientConnsMu.RLock()
-	clientConn, exists := s.clientConns[userID]
-	s.clientConnsMu.RUnlock()
-
-	if !exists || clientConn == nil {
-		log.Printf("broadcastToClient: Client %d not connected (channel: %s)", userID, channel)
-		return
+	defer s.clientConnsMu.RUnlock()
+	for _, clientConn := range s.clientConns {
+		if clientConn != nil && clientConn.UserID == userID {
+			return true
+		}
 	}
+	return false
+}
-
-	// Check if client is subscribed to this channel (jobs channel is always sent)
-	if channel != "jobs" {
-		clientConn.SubsMu.RLock()
-		subscribed := clientConn.Subscriptions[channel]
-		allSubs := make([]string, 0, len(clientConn.Subscriptions))
-		for ch := range clientConn.Subscriptions {
-			allSubs = append(allSubs, ch)
-		}
-		clientConn.SubsMu.RUnlock()
-		if !subscribed {
-			if s.verboseWSLogging {
-				log.Printf("broadcastToClient: Client %d not subscribed to channel %s (subscribed to: %v)", userID, channel, allSubs)
-			}
-			return
+// getClientConnections returns all connections for a specific user
+func (s *Manager) getClientConnections(userID int64) []*ClientConnection {
+	s.clientConnsMu.RLock()
+	defer s.clientConnsMu.RUnlock()
+	var conns []*ClientConnection
+	for _, clientConn := range s.clientConns {
+		if clientConn != nil && clientConn.UserID == userID {
+			conns = append(conns, clientConn)
 		}
+	}
+	return conns
+}
+
+// broadcastToClient sends a message to all connections for a specific user
+func (s *Manager) broadcastToClient(userID int64, channel string, msg map[string]interface{}) {
+	conns := s.getClientConnections(userID)
+
+	if len(conns) == 0 {
+		// Client not connected - this is normal, don't log it (only log at verbose level)
 		if s.verboseWSLogging {
-			log.Printf("broadcastToClient: Client %d is subscribed to channel %s", userID, channel)
+			log.Printf("broadcastToClient: Client %d not connected (channel: %s)", userID, channel)
 		}
+		return
 	}

 	// Add channel to message
 	msg["channel"] = channel

-	// Log what we're sending
-	if s.verboseWSLogging {
-		log.Printf("broadcastToClient: Sending to client %d on channel %s: type=%v, job_id=%v, task_id=%v",
-			userID, channel, msg["type"], msg["job_id"], msg["task_id"])
+	sentCount := 0
+	var deadConns []string
+	for _, clientConn := range conns {
+		// Check if client is subscribed to this channel (jobs channel is always sent)
+		if channel != "jobs" {
+			clientConn.SubsMu.RLock()
+			subscribed := clientConn.Subscriptions[channel]
+			clientConn.SubsMu.RUnlock()
+			if !subscribed {
+				continue
+			}
+		}
+
+		clientConn.WriteMu.Lock()
+		clientConn.Conn.SetWriteDeadline(time.Now().Add(WSWriteDeadline))
+		if err := clientConn.Conn.WriteJSON(msg); err != nil {
+			// Mark connection for removal - don't spam logs, just remove dead connections
+			deadConns = append(deadConns, clientConn.ConnID)
+		} else {
+			sentCount++
+		}
+		clientConn.WriteMu.Unlock()
 	}

-	clientConn.WriteMu.Lock()
-	defer clientConn.WriteMu.Unlock()
-
-	clientConn.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
-	if err := clientConn.Conn.WriteJSON(msg); err != nil {
-		log.Printf("Failed to send message to client %d on channel %s: %v", userID, channel, err)
-	} else {
-		if s.verboseWSLogging {
-			log.Printf("broadcastToClient: Successfully sent message to client %d on channel %s", userID, channel)
+	// Remove dead connections
+	if len(deadConns) > 0 {
+		s.clientConnsMu.Lock()
+		for _, connID := range deadConns {
+			if conn, exists := s.clientConns[connID]; exists {
+				log.Printf("Removing dead connection %s for user %d (write failed)", connID, conn.UserID)
+				conn.Conn.Close()
+				delete(s.clientConns, connID)
+			}
 		}
+		s.clientConnsMu.Unlock()
+	}
+
+	if s.verboseWSLogging {
+		log.Printf("broadcastToClient: Sent to %d/%d connections for user %d on channel %s: type=%v",
+			sentCount, len(conns), userID, channel, msg["type"])
 	}
 }

 // broadcastToAllClients sends a message to all connected clients (for jobs channel)
-func (s *Server) broadcastToAllClients(channel string, msg map[string]interface{}) {
+func (s *Manager) broadcastToAllClients(channel string, msg map[string]interface{}) {
 	msg["channel"] = channel

 	s.clientConnsMu.RLock()
@@ -4830,18 +5096,34 @@ func (s *Server) broadcastToAllClients(channel string, msg map[string]interface{
 	}
 	s.clientConnsMu.RUnlock()

+	var deadConns []string
 	for _, clientConn := range clients {
 		clientConn.WriteMu.Lock()
-		clientConn.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
+		clientConn.Conn.SetWriteDeadline(time.Now().Add(WSWriteDeadline))
 		if err := clientConn.Conn.WriteJSON(msg); err != nil {
-			log.Printf("Failed to broadcast to client %d: %v", clientConn.UserID, err)
+			deadConns = append(deadConns, clientConn.ConnID)
 		}
 		clientConn.WriteMu.Unlock()
 	}
+
+	// Remove dead connections
+	if len(deadConns) > 0 {
+		s.clientConnsMu.Lock()
+		for _, connID := range deadConns {
+			if conn, exists := s.clientConns[connID]; exists {
+				log.Printf("Removing dead connection %s for user %d (write failed)", connID, conn.UserID)
+				conn.Conn.Close()
+				delete(s.clientConns, connID)
+			}
+		}
+		s.clientConnsMu.Unlock()
+	}
 }

 // broadcastUploadProgress broadcasts upload/processing progress to subscribed clients
-func (s *Server) broadcastUploadProgress(sessionID string, progress float64, status, message string) {
+// This function updates the session synchronously (quick operation) but broadcasts
+// asynchronously to avoid blocking the upload handler on slow WebSocket writes.
+func (s *Manager) broadcastUploadProgress(sessionID string, progress float64, status, message string) {
 	s.uploadSessionsMu.RLock()
 	session, exists := s.uploadSessions[sessionID]
 	s.uploadSessionsMu.RUnlock()
@@ -4850,11 +5132,52 @@ func (s *Server) broadcastUploadProgress(sessionID string, progress float64, sta
 		return
 	}

-	// Update session
+	// Update session synchronously (quick operation - just updating struct fields)
 	s.uploadSessionsMu.Lock()
 	session.Progress = progress
 	session.Status = status
 	session.Message = message
+	userID := session.UserID // Capture userID before releasing lock
+	s.uploadSessionsMu.Unlock()
+
+	// Broadcast asynchronously to avoid blocking upload handler on slow WebSocket writes
+	// This prevents the entire HTTP server from freezing during large file uploads
+	go func() {
+		// Determine message type
+		msgType := "upload_progress"
+		if status != "uploading" {
+			msgType = "processing_status"
+		}
+
+		msg := map[string]interface{}{
+			"type":       msgType,
+			"session_id": sessionID,
+			"data": map[string]interface{}{
+				"progress": progress,
+				"status":   status,
+				"message":  message,
+			},
+			"timestamp": time.Now().Unix(),
+		}
+
+		// Only broadcast if client is connected
+		if s.isClientConnected(userID) {
+			channel := fmt.Sprintf("upload:%s", sessionID)
+			s.broadcastToClient(userID, channel, msg)
+		}
+	}()
+}
+
+// broadcastUploadProgressSync sends upload progress synchronously (for completion messages)
+// This ensures the message is sent immediately and not lost
+func (s *Manager) broadcastUploadProgressSync(userID int64, sessionID string, progress float64, status, message string) {
+	// Update session synchronously
+	s.uploadSessionsMu.Lock()
+	if session, exists := s.uploadSessions[sessionID]; exists {
+		session.Progress = progress
+		session.Status = status
+		session.Message = message
+	}
 	s.uploadSessionsMu.Unlock()

 	// Determine message type
@@ -4874,8 +5197,11 @@ func (s *Server) broadcastUploadProgress(sessionID string, progress float64, sta
 		"timestamp": time.Now().Unix(),
 	}

-	channel := fmt.Sprintf("upload:%s", sessionID)
-	s.broadcastToClient(session.UserID, channel, msg)
+	// Send synchronously to ensure delivery
+	if s.isClientConnected(userID) {
+		channel := fmt.Sprintf("upload:%s", sessionID)
+		s.broadcastToClient(userID, channel, msg)
+	}
 }

 // truncateString truncates a string to a maximum length, appending "..." if truncated
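
After this change a single user may hold several live WebSocket connections (one per browser tab), so clientConns is keyed by a composite connection ID instead of the user ID. The helper that mints those keys is not shown in this hunk; a minimal sketch of what it could look like, assuming the connIDCounter field added to the Manager struct below:

    // Hypothetical helper (not shown in this diff): build a unique clientConns key
    // so multiple tabs per user can coexist. Assumes Manager.connIDCounter exists.
    func (s *Manager) nextClientConnID(userID int64) string {
    	n := atomic.AddUint64(&s.connIDCounter, 1) // needs "sync/atomic"
    	return fmt.Sprintf("%d:%d", userID, n)     // matches the "userID:connID" key format
    }
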
diff --git a/internal/api/server.go b/internal/manager/manager.go
similarity index 77%
rename from internal/api/server.go
rename to internal/manager/manager.go
index 06639cf..f9658ea 100644
--- a/internal/api/server.go
+++ b/internal/manager/manager.go
@@ -9,6 +9,7 @@ import (
 	"log"
 	"net/http"
 	"os"
+	"os/exec"
 	"path/filepath"
 	"runtime"
 	"strconv"
@@ -37,12 +38,10 @@ const (
 	WSWriteDeadline = 10 * time.Second

 	// Task timeouts
-	DefaultTaskTimeout     = 300   // 5 minutes for frame rendering
-	VideoGenerationTimeout = 86400 // 24 hours for video generation
-	DefaultJobTimeout      = 86400 // 24 hours
+	RenderTimeout      = 60 * 60      // 1 hour for frame rendering
+	VideoEncodeTimeout = 60 * 60 * 24 // 24 hours for encoding

 	// Limits
-	MaxFrameRange            = 10000
 	MaxUploadSize            = 50 << 30 // 50 GB
 	RunnerHeartbeatTimeout   = 90 * time.Second
 	TaskDistributionInterval = 10 * time.Second
@@ -52,8 +51,8 @@ const (
 	SessionCookieMaxAge = 86400 // 24 hours
 )

-// Server represents the API server
-type Server struct {
+// Manager represents the manager server
+type Manager struct {
 	db      *database.DB
 	cfg     *config.Config
 	auth    *authpkg.Auth
@@ -62,14 +61,9 @@ type Server struct {
 	router  *chi.Mux

 	// WebSocket connections
-	wsUpgrader           websocket.Upgrader
-	runnerConns          map[int64]*websocket.Conn
-	runnerConnsMu        sync.RWMutex
-	// Mutexes for each runner connection to serialize writes
-	runnerConnsWriteMu   map[int64]*sync.Mutex
-	runnerConnsWriteMuMu sync.RWMutex
+	wsUpgrader websocket.Upgrader

-	// DEPRECATED: Old WebSocket connection maps (kept for backwards compatibility)
+	// DEPRECATED: Old frontend WebSocket connection maps (kept for backwards compatibility)
 	// These will be removed in a future release. Use clientConns instead.
 	frontendConns   map[string]*websocket.Conn // key: "jobId:taskId"
 	frontendConnsMu sync.RWMutex
@@ -82,18 +76,25 @@ type Server struct {
 	jobConnsWriteMu   map[string]*sync.Mutex
 	jobConnsWriteMuMu sync.RWMutex

+	// Per-job runner WebSocket connections (polling-based flow)
+	// Key is "job-{jobId}-task-{taskId}"
+	runnerJobConns          map[string]*websocket.Conn
+	runnerJobConnsMu        sync.RWMutex
+	runnerJobConnsWriteMu   map[string]*sync.Mutex
+	runnerJobConnsWriteMuMu sync.RWMutex
+
 	// Throttling for progress updates (per job)
 	progressUpdateTimes   map[int64]time.Time // key: jobID
 	progressUpdateTimesMu sync.RWMutex

 	// Throttling for task status updates (per task)
 	taskUpdateTimes   map[int64]time.Time // key: taskID
 	taskUpdateTimesMu sync.RWMutex
-	// Task distribution serialization
-	taskDistMu sync.Mutex // Mutex to prevent concurrent distribution

 	// Client WebSocket connections (new unified WebSocket)
-	clientConns   map[int64]*ClientConnection // key: userID
+	// Key is "userID:connID" to support multiple tabs per user
+	clientConns   map[string]*ClientConnection
 	clientConnsMu sync.RWMutex
+	connIDCounter uint64 // Atomic counter for generating unique connection IDs

 	// Upload session tracking
 	uploadSessions map[string]*UploadSession // sessionId -> session info
@@ -110,6 +111,7 @@ type Server struct {
 type ClientConnection struct {
 	Conn          *websocket.Conn
 	UserID        int64
+	ConnID        string // Unique connection ID (userID:connID)
 	IsAdmin       bool
 	Subscriptions map[string]bool // channel -> subscribed
 	SubsMu        sync.RWMutex    // Protects Subscriptions map
@@ -126,14 +128,14 @@ type UploadSession struct {
 	CreatedAt time.Time
 }

-// NewServer creates a new API server
-func NewServer(db *database.DB, cfg *config.Config, auth *authpkg.Auth, storage *storage.Storage) (*Server, error) {
+// NewManager creates a new manager server
+func NewManager(db *database.DB, cfg *config.Config, auth *authpkg.Auth, storage *storage.Storage) (*Manager, error) {
 	secrets, err := authpkg.NewSecrets(db, cfg)
 	if err != nil {
 		return nil, fmt.Errorf("failed to initialize secrets: %w", err)
 	}

-	s := &Server{
+	s := &Manager{
 		db:      db,
 		cfg:     cfg,
 		auth:    auth,
@@ -146,9 +148,7 @@ func NewServer(db *database.DB, cfg *config.Config, auth *authpkg.Auth, storage
 			ReadBufferSize:  1024,
 			WriteBufferSize: 1024,
 		},
-		runnerConns:        make(map[int64]*websocket.Conn),
-		runnerConnsWriteMu: make(map[int64]*sync.Mutex),
-		// DEPRECATED: Initialize old WebSocket maps for backward compatibility
+		// DEPRECATED: Initialize old frontend WebSocket maps for backward compatibility
 		frontendConns:        make(map[string]*websocket.Conn),
 		frontendConnsWriteMu: make(map[string]*sync.Mutex),
 		jobListConns:         make(map[int64]*websocket.Conn),
@@ -156,8 +156,17 @@ func NewServer(db *database.DB, cfg *config.Config, auth *authpkg.Auth, storage
 		jobConnsWriteMu:     make(map[string]*sync.Mutex),
 		progressUpdateTimes: make(map[int64]time.Time),
 		taskUpdateTimes:     make(map[int64]time.Time),
-		clientConns:         make(map[int64]*ClientConnection),
+		clientConns:         make(map[string]*ClientConnection),
 		uploadSessions:      make(map[string]*UploadSession),
+		// Per-job runner WebSocket connections
+		runnerJobConns:          make(map[string]*websocket.Conn),
+		runnerJobConnsWriteMu:   make(map[string]*sync.Mutex),
+		runnerJobConnsWriteMuMu: sync.RWMutex{},
+	}
+
+	// Check for required external tools
+	if err := s.checkRequiredTools(); err != nil {
+		return nil, err
 	}

 	s.setupMiddleware()
@@ -171,6 +180,23 @@ func NewServer(db *database.DB, cfg *config.Config, auth *authpkg.Auth, storage
 	return s, nil
 }

+// checkRequiredTools verifies that required external tools are available
+func (s *Manager) checkRequiredTools() error {
+	// Check for zstd (required for zstd-compressed blend files)
+	if err := exec.Command("zstd", "--version").Run(); err != nil {
+		return fmt.Errorf("zstd not found - required for compressed blend file support. Install with: apt install zstd")
+	}
+	log.Printf("Found zstd for compressed blend file support")
+
+	// Check for xz (required for decompressing blender archives)
+	if err := exec.Command("xz", "--version").Run(); err != nil {
+		return fmt.Errorf("xz not found - required for decompressing blender archives. Install with: apt install xz-utils")
+	}
+	log.Printf("Found xz for blender archive decompression")
+
+	return nil
+}
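
Both probes in checkRequiredTools follow the same run-the-version-flag-and-report pattern. If more external tools are needed later, a table-driven variant keeps each tool next to its install hint (sketch only, not part of this diff):

    // Sketch: the same checks, table-driven.
    func checkTools() error {
    	tools := []struct{ name, hint string }{
    		{"zstd", "apt install zstd"},
    		{"xz", "apt install xz-utils"},
    	}
    	for _, t := range tools {
    		if err := exec.Command(t.name, "--version").Run(); err != nil {
    			return fmt.Errorf("%s not found (install with: %s): %w", t.name, t.hint, err)
    		}
    	}
    	return nil
    }
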
 // checkWebSocketOrigin validates WebSocket connection origins
 // In production mode, only allows same-origin connections or configured allowed origins
 func checkWebSocketOrigin(r *http.Request) bool {
@@ -323,7 +349,7 @@ func rateLimitMiddleware(limiter *RateLimiter) func(http.Handler) http.Handler {
 }

 // setupMiddleware configures middleware
-func (s *Server) setupMiddleware() {
+func (s *Manager) setupMiddleware() {
 	s.router.Use(middleware.Logger)
 	s.router.Use(middleware.Recoverer)
 	// Note: Timeout middleware is NOT applied globally to avoid conflicts with WebSocket connections
@@ -416,7 +442,7 @@ func (w *gzipResponseWriter) WriteHeader(statusCode int) {
 }

 // setupRoutes configures routes
-func (s *Server) setupRoutes() {
+func (s *Manager) setupRoutes() {
 	// Health check endpoint (unauthenticated)
 	s.router.Get("/api/health", s.handleHealthCheck)

@@ -457,13 +483,13 @@ func (s *Server) setupRoutes() {
 			r.Get("/{id}/files/count", s.handleGetJobFilesCount)
 			r.Get("/{id}/context", s.handleListContextArchive)
 			r.Get("/{id}/files/{fileId}/download", s.handleDownloadJobFile)
+			r.Get("/{id}/files/{fileId}/preview-exr", s.handlePreviewEXR)
 			r.Get("/{id}/video", s.handleStreamVideo)
 			r.Get("/{id}/metadata", s.handleGetJobMetadata)
 			r.Get("/{id}/tasks", s.handleListJobTasks)
 			r.Get("/{id}/tasks/summary", s.handleListJobTasksSummary)
 			r.Post("/{id}/tasks/batch", s.handleBatchGetTasks)
 			r.Get("/{id}/tasks/{taskId}/logs", s.handleGetTaskLogs)
-			// Old WebSocket route removed - use client WebSocket with subscriptions instead
 			r.Get("/{id}/tasks/{taskId}/steps", s.handleGetTaskSteps)
 			r.Post("/{id}/tasks/{taskId}/retry", s.handleRetryTask)
 			// WebSocket route for unified client WebSocket
@@ -510,38 +536,40 @@ func (s *Server) setupRoutes() {
 		// Registration doesn't require auth (uses token)
 		r.With(middleware.Timeout(60*time.Second)).Post("/register", s.handleRegisterRunner)

-		// WebSocket endpoint (auth handled in handler) - no timeout middleware
-		r.Get("/ws", s.handleRunnerWebSocket)
+		// Polling-based endpoints (auth handled in handlers)
+		r.Get("/workers/{id}/next-job", s.handleNextJob)

-		// File operations still use HTTP (WebSocket not suitable for large files)
+		// Per-job endpoints with job_token auth (no middleware, auth in handler)
+		r.Get("/jobs/{jobId}/ws", s.handleRunnerJobWebSocket)
+		r.Get("/jobs/{jobId}/context.tar", s.handleDownloadJobContextWithToken)
+		r.Post("/jobs/{jobId}/upload", s.handleUploadFileWithToken)
+
+		// Runner API endpoints (uses API key auth)
 		r.Group(func(r chi.Router) {
 			r.Use(func(next http.Handler) http.Handler {
 				return http.HandlerFunc(s.runnerAuthMiddleware(next.ServeHTTP))
 			})
-			r.Get("/ping", s.handleRunnerPing)
-			r.Post("/tasks/{id}/progress", s.handleUpdateTaskProgress)
-			r.Post("/tasks/{id}/steps", s.handleUpdateTaskStep)
-			r.Get("/jobs/{jobId}/context.tar", s.handleDownloadJobContext)
-			r.Get("/files/{jobId}/{fileName}", s.handleDownloadFileForRunner)
-			r.Post("/files/{jobId}/upload", s.handleUploadFileFromRunner)
-			r.Get("/jobs/{jobId}/status", s.handleGetJobStatusForRunner)
+			r.Get("/blender/download", s.handleDownloadBlender)
 			r.Get("/jobs/{jobId}/files", s.handleGetJobFilesForRunner)
 			r.Get("/jobs/{jobId}/metadata", s.handleGetJobMetadataForRunner)
-			r.Post("/jobs/{jobId}/metadata", s.handleSubmitMetadata)
+			r.Get("/files/{jobId}/{fileName}", s.handleDownloadFileForRunner)
 		})
 	})

+	// Blender versions API (public, for job submission page)
+	s.router.Get("/api/blender/versions", s.handleGetBlenderVersions)
+
 	// Serve static files (embedded React app with SPA fallback)
 	s.router.Handle("/*", web.SPAHandler())
 }

 // ServeHTTP implements http.Handler
-func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+func (s *Manager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	s.router.ServeHTTP(w, r)
 }

 // JSON response helpers
-func (s *Server) respondJSON(w http.ResponseWriter, status int, data interface{}) {
+func (s *Manager) respondJSON(w http.ResponseWriter, status int, data interface{}) {
 	w.Header().Set("Content-Type", "application/json")
 	w.WriteHeader(status)
 	if err := json.NewEncoder(w).Encode(data); err != nil {
@@ -549,7 +577,7 @@ func (s *Server) respondJSON(w http.ResponseWriter, status int, data interface{}
 	}
 }

-func (s *Server) respondError(w http.ResponseWriter, status int, message string) {
+func (s *Manager) respondError(w http.ResponseWriter, status int, message string) {
 	s.respondJSON(w, status, map[string]string{"error": message})
 }

@@ -573,7 +601,7 @@ func createSessionCookie(sessionID string) *http.Cookie {
 }

 // handleHealthCheck returns server health status
-func (s *Server) handleHealthCheck(w http.ResponseWriter, r *http.Request) {
+func (s *Manager) handleHealthCheck(w http.ResponseWriter, r *http.Request) {
 	// Check database connectivity
 	dbHealthy := true
 	if err := s.db.Ping(); err != nil {
@@ -581,10 +609,14 @@ func (s *Server) handleHealthCheck(w http.ResponseWriter, r *http.Request) {
 		log.Printf("Health check: database ping failed: %v", err)
 	}

-	// Count connected runners
-	s.runnerConnsMu.RLock()
-	runnerCount := len(s.runnerConns)
-	s.runnerConnsMu.RUnlock()
+	// Count online runners (based on recent heartbeat)
+	var runnerCount int
+	s.db.With(func(conn *sql.DB) error {
+		return conn.QueryRow(
+			`SELECT COUNT(*) FROM runners WHERE status = ?`,
+			types.RunnerStatusOnline,
+		).Scan(&runnerCount)
+	})

 	// Count connected clients
 	s.clientConnsMu.RLock()
@@ -624,7 +656,7 @@ func (s *Server) handleHealthCheck(w http.ResponseWriter, r *http.Request) {
 }

 // Auth handlers
-func (s *Server) handleGoogleLogin(w http.ResponseWriter, r *http.Request) {
+func (s *Manager) handleGoogleLogin(w http.ResponseWriter, r *http.Request) {
 	url, err := s.auth.GoogleLoginURL()
 	if err != nil {
 		s.respondError(w, http.StatusInternalServerError, err.Error())
@@ -633,7 +665,7 @@ func (s *Server) handleGoogleLogin(w http.ResponseWriter, r *http.Request) {
 	http.Redirect(w, r, url, http.StatusFound)
 }

-func (s *Server) handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
+func (s *Manager) handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
 	code := r.URL.Query().Get("code")
 	if code == "" {
 		s.respondError(w, http.StatusBadRequest, "Missing code parameter")
@@ -657,7 +689,7 @@ func (s *Server) handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
 	http.Redirect(w, r, "/", http.StatusFound)
 }

-func (s *Server) handleDiscordLogin(w http.ResponseWriter, r *http.Request) {
+func (s *Manager) handleDiscordLogin(w http.ResponseWriter, r *http.Request) {
 	url, err := s.auth.DiscordLoginURL()
 	if err != nil {
 		s.respondError(w, http.StatusInternalServerError, err.Error())
@@ -666,7 +698,7 @@ func (s *Server) handleDiscordLogin(w http.ResponseWriter, r *http.Request) {
 	http.Redirect(w, r, url, http.StatusFound)
 }

-func (s *Server) handleDiscordCallback(w http.ResponseWriter, r *http.Request) {
+func (s *Manager) handleDiscordCallback(w http.ResponseWriter, r *http.Request) {
 	code := r.URL.Query().Get("code")
 	if code == "" {
 		s.respondError(w, http.StatusBadRequest, "Missing code parameter")
@@ -690,7 +722,7 @@ func (s *Server) handleDiscordCallback(w http.ResponseWriter, r *http.Request) {
 	http.Redirect(w, r, "/", http.StatusFound)
 }

-func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
+func (s *Manager) handleLogout(w http.ResponseWriter, r *http.Request) {
 	cookie, err := r.Cookie("session_id")
 	if err == nil {
 		s.auth.DeleteSession(cookie.Value)
@@ -712,7 +744,7 @@ func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
 	s.respondJSON(w, http.StatusOK, map[string]string{"message": "Logged out"})
 }

-func (s *Server) handleGetMe(w http.ResponseWriter, r *http.Request) {
+func (s *Manager) handleGetMe(w http.ResponseWriter, r *http.Request) {
 	cookie, err := r.Cookie("session_id")
 	if err != nil {
 		log.Printf("Authentication failed: missing session cookie in /auth/me")
@@ -735,7 +767,7 @@ func (s *Server) handleGetMe(w http.ResponseWriter, r *http.Request) {
 	})
 }

-func (s *Server) handleGetAuthProviders(w http.ResponseWriter, r *http.Request) {
+func (s *Manager) handleGetAuthProviders(w http.ResponseWriter, r *http.Request) {
 	s.respondJSON(w, http.StatusOK, map[string]bool{
 		"google":  s.auth.IsGoogleOAuthConfigured(),
 		"discord": s.auth.IsDiscordOAuthConfigured(),
@@ -743,13 +775,13 @@ func (s *Server) handleGetAuthProviders(w http.ResponseWriter, r *http.Request)
 	})
 }

-func (s *Server) handleLocalLoginAvailable(w http.ResponseWriter, r *http.Request) {
+func (s *Manager) handleLocalLoginAvailable(w http.ResponseWriter, r *http.Request) {
 	s.respondJSON(w, http.StatusOK, map[string]bool{
 		"available": s.auth.IsLocalLoginEnabled(),
 	})
 }

-func (s *Server) handleLocalRegister(w http.ResponseWriter, r *http.Request) {
+func (s *Manager) handleLocalRegister(w http.ResponseWriter, r *http.Request) {
 	var req struct {
 		Email string `json:"email"`
 		Name  string `json:"name"`
@@ -791,7 +823,7 @@ func (s *Server) handleLocalRegister(w http.ResponseWriter, r *http.Request) {
 	})
 }

-func (s *Server) handleLocalLogin(w http.ResponseWriter, r *http.Request) {
+func (s *Manager) handleLocalLogin(w http.ResponseWriter, r *http.Request) {
 	var req struct {
 		Username string `json:"username"`
 		Password string `json:"password"`
@@ -828,7 +860,7 @@ func (s *Server) handleLocalLogin(w http.ResponseWriter, r *http.Request) {
 	})
 }

-func (s *Server) handleChangePassword(w http.ResponseWriter, r *http.Request) {
+func (s *Manager) handleChangePassword(w http.ResponseWriter, r *http.Request) {
 	userID, err := getUserID(r)
 	if err != nil {
 		s.respondError(w, http.StatusUnauthorized, err.Error())
@@ -902,7 +934,7 @@ func parseID(r *http.Request, param string) (int64, error) {
 }

 // StartBackgroundTasks starts background goroutines for error recovery
-func (s *Server) StartBackgroundTasks() {
+func (s *Manager) StartBackgroundTasks() {
 	go s.recoverStuckTasks()
 	go s.cleanupOldRenderJobs()
 	go s.cleanupOldTempDirectories()
@@ -910,100 +942,63 @@ func (s *Server) StartBackgroundTasks() {
 	go s.cleanupOldUploadSessions()
 }

-// recoverRunnersOnStartup checks for runners marked as online but not actually connected
-// This runs once on startup to handle manager restarts where we lose track of connections
-func (s *Server) recoverRunnersOnStartup() {
-	// Wait a short time for runners to reconnect after manager restart
-	// This gives runners a chance to reconnect before we mark them as dead
-	time.Sleep(5 * time.Second)
+// recoverRunnersOnStartup marks runners as offline on startup
+// In the polling model, runners will update their status when they poll for jobs
+func (s *Manager) recoverRunnersOnStartup() {
+	log.Printf("Recovering runners on startup: marking all as offline...")

-	log.Printf("Recovering runners on startup: checking for disconnected runners...")
-
-	var onlineRunnerIDs []int64
+	// Mark all runners as offline - they'll be marked online when they poll
+	var runnersAffected int64
 	err := s.db.With(func(conn *sql.DB) error {
-		rows, err := conn.Query(
-			`SELECT id FROM runners WHERE status = ?`,
-			types.RunnerStatusOnline,
+		result, err := conn.Exec(
			`UPDATE runners SET status = ? WHERE status = ?`,
+			types.RunnerStatusOffline, types.RunnerStatusOnline,
 		)
 		if err != nil {
 			return err
 		}
-		defer rows.Close()
-
-		for rows.Next() {
-			var runnerID int64
-			if err := rows.Scan(&runnerID); err == nil {
-				onlineRunnerIDs = append(onlineRunnerIDs, runnerID)
-			}
-		}
+		runnersAffected, _ = result.RowsAffected()
 		return nil
 	})
-
 	if err != nil {
-		log.Printf("Failed to query online runners on startup: %v", err)
+		log.Printf("Failed to mark runners as offline on startup: %v", err)
 		return
 	}

-	if len(onlineRunnerIDs) == 0 {
-		log.Printf("No runners marked as online on startup")
-		return
+	if runnersAffected > 0 {
+		log.Printf("Marked %d runners as offline on startup", runnersAffected)
 	}

-	log.Printf("Found %d runners marked as online, checking actual connections...", len(onlineRunnerIDs))
-
-	// Check which runners are actually connected
-	s.runnerConnsMu.RLock()
-	deadRunnerIDs := make([]int64, 0)
-	for _, runnerID := range onlineRunnerIDs {
-		if _, connected := s.runnerConns[runnerID]; !connected {
-			deadRunnerIDs = append(deadRunnerIDs, runnerID)
+	// Reset any running tasks that were assigned to runners
+	// They will be picked up by runners when they poll
+	var tasksAffected int64
+	err = s.db.With(func(conn *sql.DB) error {
+		result, err := conn.Exec(
+			`UPDATE tasks SET runner_id = NULL, status = ?, started_at = NULL
+			 WHERE status = ?`,
+			types.TaskStatusPending, types.TaskStatusRunning,
+		)
+		if err != nil {
+			return err
 		}
-	}
-	s.runnerConnsMu.RUnlock()
-
-	if len(deadRunnerIDs) == 0 {
-		log.Printf("All runners marked as online are actually connected")
+		tasksAffected, _ = result.RowsAffected()
+		return nil
+	})
+	if err != nil {
+		log.Printf("Failed to reset running tasks on startup: %v", err)
 		return
 	}

-	log.Printf("Found %d runners marked as online but not connected, redistributing their tasks...", len(deadRunnerIDs))
-
-	// Redistribute tasks for disconnected runners
-	for _, runnerID := range deadRunnerIDs {
-		log.Printf("Recovering runner %d: redistributing tasks and marking as offline", runnerID)
-		s.redistributeRunnerTasks(runnerID)
-
-		// Mark runner as offline
-		s.db.With(func(conn *sql.DB) error {
-			_, _ = conn.Exec(
-				`UPDATE runners SET status = ?, last_heartbeat = ? WHERE id = ?`,
-				types.RunnerStatusOffline, time.Now(), runnerID,
-			)
-			return nil
-		})
+	if tasksAffected > 0 {
+		log.Printf("Reset %d running tasks to pending on startup", tasksAffected)
 	}
-
-	log.Printf("Startup recovery complete: redistributed tasks from %d disconnected runners", len(deadRunnerIDs))
-
-	// Trigger task distribution to assign recovered tasks to available runners
-	s.triggerTaskDistribution()
 }
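
The two startup updates above (runners to offline, running tasks back to pending) run as separate statements, so a crash between them could leave the tables briefly inconsistent. A sketch of the same reset as a single transaction, using the WithTx helper that appears later in this diff:

    // Sketch: startup reset in one transaction.
    err := s.db.WithTx(func(tx *sql.Tx) error {
    	if _, err := tx.Exec(`UPDATE runners SET status = ? WHERE status = ?`,
    		types.RunnerStatusOffline, types.RunnerStatusOnline); err != nil {
    		return err
    	}
    	_, err := tx.Exec(`UPDATE tasks SET runner_id = NULL, status = ?, started_at = NULL
    		WHERE status = ?`, types.TaskStatusPending, types.TaskStatusRunning)
    	return err
    })
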

 // recoverStuckTasks periodically checks for dead runners and stuck tasks
-func (s *Server) recoverStuckTasks() {
-	ticker := time.NewTicker(10 * time.Second)
+func (s *Manager) recoverStuckTasks() {
+	ticker := time.NewTicker(TaskDistributionInterval)
 	defer ticker.Stop()

-	// Also distribute tasks every 10 seconds (reduced frequency since we have event-driven distribution)
-	distributeTicker := time.NewTicker(10 * time.Second)
-	defer distributeTicker.Stop()
-
-	go func() {
-		for range distributeTicker.C {
-			s.triggerTaskDistribution()
-		}
-	}()
-
 	for range ticker.C {
 		func() {
 			defer func() {
@@ -1012,37 +1007,28 @@ func (s *Server) recoverStuckTasks() {
 				}
 			}()

-			// Find dead runners (no heartbeat for 90 seconds)
-			// But only mark as dead if they're not actually connected via WebSocket
+			// Find dead runners (no heartbeat for configured timeout)
+			// In polling model, heartbeat is updated when runner polls for jobs
 			var deadRunnerIDs []int64
-			var stillConnectedIDs []int64
+			cutoffTime := time.Now().Add(-RunnerHeartbeatTimeout)
 			err := s.db.With(func(conn *sql.DB) error {
 				rows, err := conn.Query(
 					`SELECT id FROM runners
-					 WHERE last_heartbeat < datetime('now', '-90 seconds')
+					 WHERE last_heartbeat < ?
 					 AND status = ?`,
-					types.RunnerStatusOnline,
+					cutoffTime, types.RunnerStatusOnline,
 				)
 				if err != nil {
 					return err
 				}
 				defer rows.Close()

-				s.runnerConnsMu.RLock()
 				for rows.Next() {
 					var runnerID int64
 					if err := rows.Scan(&runnerID); err == nil {
-						// Only mark as dead if not actually connected via WebSocket
-						// The WebSocket connection is the source of truth
-						if _, stillConnected := s.runnerConns[runnerID]; !stillConnected {
-							deadRunnerIDs = append(deadRunnerIDs, runnerID)
-						} else {
-							// Runner is still connected but heartbeat is stale - update it
-							stillConnectedIDs = append(stillConnectedIDs, runnerID)
-						}
+						deadRunnerIDs = append(deadRunnerIDs, runnerID)
 					}
 				}
-				s.runnerConnsMu.RUnlock()
 				return nil
 			})
 			if err != nil {
@@ -1050,27 +1036,9 @@ func (s *Server) recoverStuckTasks() {
 				return
 			}

-			// Update heartbeat for runners that are still connected but have stale heartbeats
-			// This ensures the database stays in sync with actual connection state
-			for _, runnerID := range stillConnectedIDs {
-				s.db.With(func(conn *sql.DB) error {
-					_, _ = conn.Exec(
-						`UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`,
-						time.Now(), types.RunnerStatusOnline, runnerID,
-					)
-					return nil
-				})
-			}
-
-			if len(deadRunnerIDs) == 0 {
-				// Check for task timeouts
-				s.recoverTaskTimeouts()
-				return
-			}
-
 			// Reset tasks assigned to dead runners
 			for _, runnerID := range deadRunnerIDs {
-				s.redistributeRunnerTasks(runnerID)
+				s.resetRunnerTasks(runnerID)

 				// Mark runner as offline
 				s.db.With(func(conn *sql.DB) error {
@@ -1084,31 +1052,29 @@ func (s *Server) recoverStuckTasks() {

 			// Check for task timeouts
 			s.recoverTaskTimeouts()
-
-			// Distribute newly recovered tasks
-			s.triggerTaskDistribution()
 		}()
 	}
 }

 // recoverTaskTimeouts handles tasks that have exceeded their timeout
-func (s *Server) recoverTaskTimeouts() {
+// Timeouts are treated as runner failures (not task failures) and retry indefinitely
+func (s *Manager) recoverTaskTimeouts() {
 	// Find tasks running longer than their timeout
 	var tasks []struct {
 		taskID         int64
+		jobID          int64
 		runnerID       sql.NullInt64
-		retryCount     int
-		maxRetries     int
 		timeoutSeconds sql.NullInt64
 		startedAt      time.Time
 	}

 	err := s.db.With(func(conn *sql.DB) error {
 		rows, err := conn.Query(
-			`SELECT t.id, t.runner_id, t.retry_count, t.max_retries, t.timeout_seconds, t.started_at
+			`SELECT t.id, t.job_id, t.runner_id, t.timeout_seconds, t.started_at
 			 FROM tasks t
 			 WHERE t.status = ?
 			 AND t.started_at IS NOT NULL
+			 AND (t.completed_at IS NULL OR t.completed_at < datetime('now', '-30 seconds'))
 			 AND (t.timeout_seconds IS NULL OR
 			      (julianday('now') - julianday(t.started_at)) * 86400 > t.timeout_seconds)`,
 			types.TaskStatusRunning,
@@ -1121,13 +1087,12 @@ func (s *Server) recoverTaskTimeouts() {
 		for rows.Next() {
 			var task struct {
 				taskID         int64
+				jobID          int64
 				runnerID       sql.NullInt64
-				retryCount     int
-				maxRetries     int
 				timeoutSeconds sql.NullInt64
 				startedAt      time.Time
 			}
-			err := rows.Scan(&task.taskID, &task.runnerID, &task.retryCount, &task.maxRetries, &task.timeoutSeconds, &task.startedAt)
+			err := rows.Scan(&task.taskID, &task.jobID, &task.runnerID, &task.timeoutSeconds, &task.startedAt)
 			if err != nil {
 				log.Printf("Failed to scan task row in recoverTaskTimeouts: %v", err)
 				continue
@@ -1143,8 +1108,7 @@ func (s *Server) recoverTaskTimeouts() {

 	for _, task := range tasks {
 		taskID := task.taskID
-		retryCount := task.retryCount
-		maxRetries := task.maxRetries
+		jobID := task.jobID
 		timeoutSeconds := task.timeoutSeconds
 		startedAt := task.startedAt

@@ -1159,51 +1123,60 @@ func (s *Server) recoverTaskTimeouts() {
 			continue
 		}

-		if retryCount >= maxRetries {
-			// Mark as failed
-			err = s.db.With(func(conn *sql.DB) error {
-				_, err := conn.Exec(`UPDATE tasks SET status = ? WHERE id = ?`, types.TaskStatusFailed, taskID)
-				if err != nil {
-					return err
-				}
-				_, err = conn.Exec(`UPDATE tasks SET error_message = ? WHERE id = ?`, "Task timeout exceeded, max retries reached", taskID)
-				if err != nil {
-					return err
-				}
-				_, err = conn.Exec(`UPDATE tasks SET runner_id = NULL WHERE id = ?`, taskID)
-				return err
-			})
+		// Timeouts are runner failures - always reset to pending and increment runner_failure_count
+		// This does NOT count against retry_count (which is for actual task failures like Blender crashes)
+		err = s.db.With(func(conn *sql.DB) error {
+			_, err := conn.Exec(`UPDATE tasks SET status = ? WHERE id = ?`, types.TaskStatusPending, taskID)
 			if err != nil {
-				log.Printf("Failed to mark task %d as failed: %v", taskID, err)
-			}
-		} else {
-			// Reset to pending
-			err = s.db.With(func(conn *sql.DB) error {
-				_, err := conn.Exec(`UPDATE tasks SET status = ? WHERE id = ?`, types.TaskStatusPending, taskID)
-				if err != nil {
-					return err
-				}
-				_, err = conn.Exec(`UPDATE tasks SET runner_id = NULL WHERE id = ?`, taskID)
-				if err != nil {
-					return err
-				}
-				_, err = conn.Exec(`UPDATE tasks SET current_step = NULL WHERE id = ?`, taskID)
-				if err != nil {
-					return err
-				}
-				_, err = conn.Exec(`UPDATE tasks SET retry_count = retry_count + 1 WHERE id = ?`, taskID)
 				return err
-			})
-			if err == nil {
-				// Add log entry using the helper function
-				s.logTaskEvent(taskID, nil, types.LogLevelWarn, fmt.Sprintf("Task timeout exceeded, resetting (retry %d/%d)", retryCount+1, maxRetries), "")
 			}
+			_, err = conn.Exec(`UPDATE tasks SET runner_id = NULL WHERE id = ?`, taskID)
+			if err != nil {
+				return err
+			}
+			_, err = conn.Exec(`UPDATE tasks SET current_step = NULL WHERE id = ?`, taskID)
+			if err != nil {
+				return err
+			}
+			_, err = conn.Exec(`UPDATE tasks SET started_at = NULL WHERE id = ?`, taskID)
+			if err != nil {
+				return err
+			}
+			_, err = conn.Exec(`UPDATE tasks SET runner_failure_count = runner_failure_count + 1 WHERE id = ?`, taskID)
+			if err != nil {
+				return err
+			}
+			// Clear steps and logs for fresh retry
+			_, err = conn.Exec(`DELETE FROM task_steps WHERE task_id = ?`, taskID)
+			if err != nil {
+				return err
+			}
+			_, err = conn.Exec(`DELETE FROM task_logs WHERE task_id = ?`, taskID)
+			return err
+		})
+		if err == nil {
+			// Broadcast task reset to clients (includes steps_cleared and logs_cleared flags)
+			s.broadcastTaskUpdate(jobID, taskID, "task_reset", map[string]interface{}{
+				"status":        types.TaskStatusPending,
+				"runner_id":     nil,
+				"current_step":  nil,
+				"started_at":    nil,
+				"steps_cleared": true,
+				"logs_cleared":  true,
+			})
+
+			// Update job status
+			s.updateJobStatusFromTasks(jobID)
+
+			log.Printf("Reset timed out task %d to pending", taskID)
+		} else {
+			log.Printf("Failed to reset timed out task %d: %v", taskID, err)
 		}
 	}
 }
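
For reference, the SQLite predicate (julianday('now') - julianday(t.started_at)) * 86400 > t.timeout_seconds reads "elapsed seconds exceed the timeout": julianday yields fractional days, and the difference times 86400 converts to seconds. The equivalent in-process check for a task with a non-NULL timeout would be:

    // Sketch: Go-side form of the SQL timeout predicate above.
    elapsed := time.Since(startedAt)
    timedOut := timeoutSeconds.Valid &&
    	elapsed > time.Duration(timeoutSeconds.Int64)*time.Second
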

 // cleanupOldTempDirectories periodically cleans up old temporary directories
-func (s *Server) cleanupOldTempDirectories() {
+func (s *Manager) cleanupOldTempDirectories() {
 	// Run cleanup every hour
 	ticker := time.NewTicker(1 * time.Hour)
 	defer ticker.Stop()
@@ -1217,7 +1190,7 @@ func (s *Server) cleanupOldTempDirectories() {
 }

 // cleanupOldTempDirectoriesOnce removes temp directories older than 1 hour
-func (s *Server) cleanupOldTempDirectoriesOnce() {
+func (s *Manager) cleanupOldTempDirectoriesOnce() {
 	defer func() {
 		if r := recover(); r != nil {
 			log.Printf("Panic in cleanupOldTempDirectories: %v", r)
@@ -1285,7 +1258,7 @@ func (s *Server) cleanupOldTempDirectoriesOnce() {
 }

 // cleanupOldUploadSessions periodically cleans up abandoned upload sessions
-func (s *Server) cleanupOldUploadSessions() {
+func (s *Manager) cleanupOldUploadSessions() {
 	// Run cleanup every 10 minutes
 	ticker := time.NewTicker(10 * time.Minute)
 	defer ticker.Stop()
@@ -1299,7 +1272,7 @@ func (s *Server) cleanupOldUploadSessions() {
 }

 // cleanupOldUploadSessionsOnce removes upload sessions older than 1 hour
-func (s *Server) cleanupOldUploadSessionsOnce() {
+func (s *Manager) cleanupOldUploadSessionsOnce() {
 	defer func() {
 		if r := recover(); r != nil {
 			log.Printf("Panic in cleanupOldUploadSessions: %v", r)
diff --git a/internal/api/metadata.go b/internal/manager/metadata.go
similarity index 64%
rename from internal/api/metadata.go
rename to internal/manager/metadata.go
index 9789a89..dbdf467 100644
--- a/internal/api/metadata.go
+++ b/internal/manager/metadata.go
@@ -19,121 +19,8 @@ import (
 	"jiggablend/pkg/types"
 )

-// handleSubmitMetadata handles metadata submission from runner
-func (s *Server) handleSubmitMetadata(w http.ResponseWriter, r *http.Request) {
-	jobID, err := parseID(r, "jobId")
-	if err != nil {
-		s.respondError(w, http.StatusBadRequest, err.Error())
-		return
-	}
-
-	// Get runner ID from context (set by runnerAuthMiddleware)
-	runnerID, ok := r.Context().Value(runnerIDContextKey).(int64)
-	if !ok {
-		s.respondError(w, http.StatusUnauthorized, "runner_id not found in context")
-		return
-	}
-
-	var metadata types.BlendMetadata
-	if err := json.NewDecoder(r.Body).Decode(&metadata); err != nil {
-		s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid metadata JSON: %v", err))
-		return
-	}
-
-	// Verify job exists
-	var jobUserID int64
-	err = s.db.With(func(conn *sql.DB) error {
-		return conn.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID)
-	})
-	if err == sql.ErrNoRows {
-		s.respondError(w, http.StatusNotFound, "Job not found")
-		return
-	}
-	if err != nil {
-		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify job: %v", err))
-		return
-	}
-
-	// Find the metadata extraction task for this job
-	// First try to find task assigned to this runner, then fall back to any metadata task for this job
-	var taskID int64
-	err = s.db.With(func(conn *sql.DB) error {
-		err := conn.QueryRow(
-			`SELECT id FROM tasks WHERE job_id = ? AND task_type = ? AND runner_id = ?`,
-			jobID, types.TaskTypeMetadata, runnerID,
-		).Scan(&taskID)
-		if err == sql.ErrNoRows {
-			// Fall back to any metadata task for this job (in case assignment changed)
-			err = conn.QueryRow(
-				`SELECT id FROM tasks WHERE job_id = ? AND task_type = ? ORDER BY created_at DESC LIMIT 1`,
-				jobID, types.TaskTypeMetadata,
-			).Scan(&taskID)
-			if err == sql.ErrNoRows {
-				return fmt.Errorf("metadata extraction task not found")
-			}
-			if err != nil {
-				return err
-			}
-			// Update the task to be assigned to this runner if it wasn't already
-			conn.Exec(
-				`UPDATE tasks SET runner_id = ? WHERE id = ? AND runner_id IS NULL`,
-				runnerID, taskID,
-			)
-		}
-		return err
-	})
-	if err != nil {
-		if err.Error() == "metadata extraction task not found" {
-			s.respondError(w, http.StatusNotFound, "Metadata extraction task not found")
-			return
-		}
-		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to find task: %v", err))
-		return
-	}
-
-	// Convert metadata to JSON
-	metadataJSON, err := json.Marshal(metadata)
-	if err != nil {
-		s.respondError(w, http.StatusInternalServerError, "Failed to marshal metadata")
-		return
-	}
-
-	// Update job with metadata
-	err = s.db.With(func(conn *sql.DB) error {
-		_, err := conn.Exec(
-			`UPDATE jobs SET blend_metadata = ? WHERE id = ?`,
-			string(metadataJSON), jobID,
-		)
-		return err
-	})
-	if err != nil {
-		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to update job metadata: %v", err))
-		return
-	}
-
-	// Mark task as completed
-	err = s.db.With(func(conn *sql.DB) error {
-		_, err := conn.Exec(`UPDATE tasks SET status = ? WHERE id = ?`, types.TaskStatusCompleted, taskID)
-		if err != nil {
-			return err
-		}
-		_, err = conn.Exec(`UPDATE tasks SET completed_at = CURRENT_TIMESTAMP WHERE id = ?`, taskID)
-		return err
-	})
-	if err != nil {
-		log.Printf("Failed to mark metadata task as completed: %v", err)
-	} else {
-		// Update job status and progress after metadata task completes
-		s.updateJobStatusFromTasks(jobID)
-	}
-
-	log.Printf("Metadata extracted for job %d: frame_start=%d, frame_end=%d", jobID, metadata.FrameStart, metadata.FrameEnd)
-
-	s.respondJSON(w, http.StatusOK, map[string]string{"message": "Metadata submitted successfully"})
-}
-
 // handleGetJobMetadata retrieves metadata for a job
-func (s *Server) handleGetJobMetadata(w http.ResponseWriter, r *http.Request) {
+func (s *Manager) handleGetJobMetadata(w http.ResponseWriter, r *http.Request) {
 	userID, err := getUserID(r)
 	if err != nil {
 		s.respondError(w, http.StatusUnauthorized, err.Error())
@@ -151,9 +38,9 @@ func (s *Server) handleGetJobMetadata(w http.ResponseWriter, r *http.Request) {
 	var blendMetadataJSON sql.NullString
 	err = s.db.With(func(conn *sql.DB) error {
 		return conn.QueryRow(
-			`SELECT user_id, blend_metadata FROM jobs WHERE id = ?`,
-			jobID,
-		).Scan(&jobUserID, &blendMetadataJSON)
+			`SELECT user_id, blend_metadata FROM jobs WHERE id = ?`,
+			jobID,
+		).Scan(&jobUserID, &blendMetadataJSON)
 	})
 	if err == sql.ErrNoRows {
 		s.respondError(w, http.StatusNotFound, "Job not found")
@@ -184,7 +71,7 @@ func (s *Server) handleGetJobMetadata(w http.ResponseWriter, r *http.Request) {

 // extractMetadataFromContext extracts metadata from the blend file in a context archive
 // Returns the extracted metadata or an error
-func (s *Server) extractMetadataFromContext(jobID int64) (*types.BlendMetadata, error) {
+func (s *Manager) extractMetadataFromContext(jobID int64) (*types.BlendMetadata, error) {
 	contextPath := filepath.Join(s.storage.JobPath(jobID), "context.tar")

 	// Check if context exists
@@ -310,7 +197,7 @@ func (s *Server) extractMetadataFromContext(jobID int64) (*types.BlendMetadata,
 }

 // extractTar extracts a tar archive to a destination directory
-func (s *Server) extractTar(tarPath, destDir string) error {
+func (s *Manager) extractTar(tarPath, destDir string) error {
 	log.Printf("Extracting tar archive: %s -> %s", tarPath, destDir)

 	// Ensure destination directory exists
@@ -355,7 +242,8 @@ func (s *Server) extractTar(tarPath, destDir string) error {
 		}

 		// Write file
-		if header.Typeflag == tar.TypeReg {
+		switch header.Typeflag {
+		case tar.TypeReg:
 			outFile, err := os.Create(target)
 			if err != nil {
 				return fmt.Errorf("failed to create file: %w", err)
@@ -367,7 +255,7 @@ func (s *Server) extractTar(tarPath, destDir string) error {
 			}
 			outFile.Close()
 			fileCount++
-		} else if header.Typeflag == tar.TypeDir {
+		case tar.TypeDir:
 			dirCount++
 		}
 	}
diff --git a/internal/manager/runners.go b/internal/manager/runners.go
new file mode 100644
index 0000000..bc0d85a
--- /dev/null
+++ b/internal/manager/runners.go
@@ -0,0 +1,2501 @@
+package api
+
+import (
+	"context"
+	"database/sql"
+	"encoding/json"
+	"fmt"
+	"io"
+	"log"
+	"net/http"
+	"net/url"
+	"path/filepath"
+	"strconv"
+	"strings"
+	"sync"
+	"time"
+
+	"jiggablend/internal/auth"
+	"jiggablend/pkg/types"
+
+	"github.com/go-chi/chi/v5"
+	"github.com/gorilla/websocket"
+)
+
+type contextKey string
+
+const runnerIDContextKey contextKey = "runner_id"
+
+// runnerAuthMiddleware verifies runner requests using API key
+func (s *Manager) runnerAuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
+	return func(w http.ResponseWriter, r *http.Request) {
+		// Get API key from header
+		apiKey := r.Header.Get("Authorization")
+		if apiKey == "" {
+			// Try alternative header
+			apiKey = r.Header.Get("X-API-Key")
+		}
+		if apiKey == "" {
+			s.respondError(w, http.StatusUnauthorized, "API key required")
+			return
+		}
+
+		// Remove "Bearer " prefix if present
+		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
+
+		// Validate API key and get its ID
+		apiKeyID, _, err := s.secrets.ValidateRunnerAPIKey(apiKey)
+		if err != nil {
+			log.Printf("API key validation failed: %v", err)
+			s.respondError(w, http.StatusUnauthorized, "invalid API key")
+			return
+		}
+
+		// Get runner ID from query string or find runner by API key
+		runnerIDStr := r.URL.Query().Get("runner_id")
+		var runnerID int64
+
+		if runnerIDStr != "" {
+			// Runner ID provided - verify it belongs to this API key
+			_, err := fmt.Sscanf(runnerIDStr, "%d", &runnerID)
+			if err != nil {
+				s.respondError(w, http.StatusBadRequest, "invalid runner_id")
+				return
+			}
+
+			// For fixed API keys, skip database verification
+			if apiKeyID != -1 {
+				// Verify runner exists and uses this API key
+				var dbAPIKeyID sql.NullInt64
+				err = s.db.With(func(conn *sql.DB) error {
+					return conn.QueryRow("SELECT api_key_id FROM runners WHERE id = ?", runnerID).Scan(&dbAPIKeyID)
+				})
+				if err == sql.ErrNoRows {
+					s.respondError(w, http.StatusNotFound, "runner not found")
+					return
+				}
+				if err != nil {
+					s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to query runner API key: %v", err))
+					return
+				}
+				if !dbAPIKeyID.Valid || dbAPIKeyID.Int64 != apiKeyID {
+					s.respondError(w, http.StatusForbidden, "runner does not belong to this API key")
+					return
+				}
+			}
+		} else {
+			// No runner ID provided - find the runner for this API key
+			// For simplicity, assume each API key has one runner
+			err = s.db.With(func(conn *sql.DB) error {
+				return conn.QueryRow("SELECT id FROM runners WHERE api_key_id = ?", apiKeyID).Scan(&runnerID)
+			})
+			if err == sql.ErrNoRows {
+				s.respondError(w, http.StatusNotFound, "no runner found for this API key")
+				return
+			}
+			if err != nil {
+				s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to query runner by API key: %v", err))
+				return
+			}
+		}
+
+		// Add runner ID to context
+		ctx := r.Context()
+		ctx = context.WithValue(ctx, runnerIDContextKey, runnerID)
+		next(w, r.WithContext(ctx))
+	}
+}
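
runnerAuthMiddleware accepts the key in Authorization (with an optional "Bearer " prefix) or in X-API-Key, plus an optional runner_id query parameter. A sketch of a conforming request against one of the routes this middleware guards (URL and values illustrative):

    // Sketch: a request that satisfies runnerAuthMiddleware.
    req, _ := http.NewRequest("GET", managerURL+"/api/runner/jobs/42/files?runner_id=7", nil)
    req.Header.Set("Authorization", "Bearer "+apiKey) // or: req.Header.Set("X-API-Key", apiKey)
    resp, err := http.DefaultClient.Do(req)
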
+
+// handleRegisterRunner registers a new runner using an API key
+func (s *Manager) handleRegisterRunner(w http.ResponseWriter, r *http.Request) {
+	var req struct {
+		types.RegisterRunnerRequest
+		APIKey      string `json:"api_key"`
+		Fingerprint string `json:"fingerprint,omitempty"`
+	}
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
+		return
+	}
+
+	// Lock to prevent concurrent registrations that could create duplicate runners
+	s.secrets.RegistrationMu.Lock()
+	defer s.secrets.RegistrationMu.Unlock()
+
+	// Validate runner name
+	if req.Name == "" {
+		s.respondError(w, http.StatusBadRequest, "Runner name is required")
+		return
+	}
+	if len(req.Name) > 255 {
+		s.respondError(w, http.StatusBadRequest, "Runner name must be 255 characters or less")
+		return
+	}
+
+	// Validate hostname
+	if req.Hostname != "" {
+		// Basic hostname validation (allow IP addresses and domain names)
+		if len(req.Hostname) > 253 {
+			s.respondError(w, http.StatusBadRequest, "Hostname must be 253 characters or less")
+			return
+		}
+	}
+
+	// Validate capabilities JSON if provided
+	if req.Capabilities != "" {
+		var testCapabilities map[string]interface{}
+		if err := json.Unmarshal([]byte(req.Capabilities), &testCapabilities); err != nil {
+			s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid capabilities JSON: %v", err))
+			return
+		}
+	}
+
+	if req.APIKey == "" {
+		s.respondError(w, http.StatusBadRequest, "API key is required")
+		return
+	}
+
+	// Validate API key
+	apiKeyID, apiKeyScope, err := s.secrets.ValidateRunnerAPIKey(req.APIKey)
+	if err != nil {
+		s.respondError(w, http.StatusUnauthorized, fmt.Sprintf("Invalid API key: %v", err))
+		return
+	}
+
+	// For fixed API keys (keyID = -1), skip fingerprint checking
+	// Set default priority if not provided
+	priority := 100
+	if req.Priority != nil {
+		priority = *req.Priority
+	}
+
+	// Register runner
+	var runnerID int64
+	// For fixed API keys, don't store api_key_id in database
+	var dbAPIKeyID interface{}
+	if apiKeyID == -1 {
+		dbAPIKeyID = nil // NULL for fixed API keys
+	} else {
+		dbAPIKeyID = apiKeyID
+	}
+
+	// Determine fingerprint value
+	fingerprint := req.Fingerprint
+	if apiKeyID == -1 || fingerprint == "" {
+		// For fixed API keys or when no fingerprint provided, generate a unique fingerprint
+		// to avoid conflicts while still maintaining some uniqueness
+		fingerprint = fmt.Sprintf("fixed-%s-%d", req.Name, time.Now().UnixNano())
+	}
+
+	// Check fingerprint uniqueness only for non-fixed API keys
+	if apiKeyID != -1 && req.Fingerprint != "" {
+		var existingRunnerID int64
+		var existingAPIKeyID sql.NullInt64
+		err = s.db.With(func(conn *sql.DB) error {
+			return conn.QueryRow(
+				"SELECT id, api_key_id FROM runners WHERE fingerprint = ?",
+				req.Fingerprint,
+			).Scan(&existingRunnerID, &existingAPIKeyID)
+		})
+
+		if err == nil {
+			// Runner already exists with this fingerprint
+			if existingAPIKeyID.Valid && existingAPIKeyID.Int64 == apiKeyID {
+				// Same API key - update and return existing runner
+				log.Printf("Runner with fingerprint %s already exists (ID: %d), updating info", req.Fingerprint, existingRunnerID)
+
+				err = s.db.With(func(conn *sql.DB) error {
+					_, err := conn.Exec(
+						`UPDATE runners SET name = ?, hostname = ?, capabilities = ?, status = ?, last_heartbeat = ? WHERE id = ?`,
+						req.Name, req.Hostname, req.Capabilities, types.RunnerStatusOnline, time.Now(), existingRunnerID,
+					)
+					return err
+				})
+				if err != nil {
+					log.Printf("Warning: Failed to update existing runner info: %v", err)
+				}
+
+				s.respondJSON(w, http.StatusOK, map[string]interface{}{
+					"id":       existingRunnerID,
+					"name":     req.Name,
+					"hostname": req.Hostname,
+					"status":   types.RunnerStatusOnline,
+					"reused":   true, // Indicates this was a re-registration
+				})
+				return
+			} else {
+				// Different API key - reject registration
+				s.respondError(w, http.StatusConflict, "Runner with this fingerprint already registered with different API key")
+				return
+			}
+		}
+		// Otherwise the lookup returned sql.ErrNoRows (or another error): no usable
+		// existing runner with this fingerprint - proceed with a new registration
+	}
+
+	// Insert runner
+	err = s.db.With(func(conn *sql.DB) error {
+		result, err := conn.Exec(
+			`INSERT INTO runners (name, hostname, ip_address, status, last_heartbeat, capabilities,
+			                      api_key_id, api_key_scope, priority, fingerprint)
+			 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
+			req.Name, req.Hostname, "", types.RunnerStatusOnline, time.Now(), req.Capabilities,
+			dbAPIKeyID, apiKeyScope, priority, fingerprint,
+		)
+		if err != nil {
+			return err
+		}
+		runnerID, err = result.LastInsertId()
+		return err
+	})
+	if err != nil {
+		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to register runner: %v", err))
+		return
+	}
+
+	log.Printf("Registered new runner %s (ID: %d) with API key ID: %d", req.Name, runnerID, apiKeyID)
+
+	// Return runner info
+	s.respondJSON(w, http.StatusCreated, map[string]interface{}{
+		"id":       runnerID,
+		"name":     req.Name,
+		"hostname": req.Hostname,
+		"status":   types.RunnerStatusOnline,
+	})
+}
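
Registration carries an optional fingerprint used to deduplicate re-registrations; how runners derive it is not shown in this diff. Any value that is stable across restarts and unique per host works; a purely hypothetical derivation:

    // Hypothetical (not in this diff): stable per-host fingerprint.
    // Uses "crypto/sha256", "encoding/hex", and "os".
    func computeFingerprint(hostname string) string {
    	machineID, _ := os.ReadFile("/etc/machine-id") // Linux-specific; empty elsewhere
    	sum := sha256.Sum256(append([]byte(hostname+"|"), machineID...))
    	return hex.EncodeToString(sum[:])
    }
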
+
+// NextJobResponse is the response for the next-job endpoint
+type NextJobResponse struct {
+	JobToken string          `json:"job_token"`
+	JobPath  string          `json:"job_path"`
+	Task     NextJobTaskInfo `json:"task"`
+}
+
+// NextJobTaskInfo contains task information for the next-job response
+type NextJobTaskInfo struct {
+	TaskID   int64                `json:"task_id"`
+	JobID    int64                `json:"job_id"`
+	JobName  string               `json:"job_name"`
+	Frame    int                  `json:"frame"`
+	TaskType string               `json:"task_type"`
+	Metadata *types.BlendMetadata `json:"metadata,omitempty"`
+}
+
+// handleNextJob handles the polling endpoint for runners to get their next job
+// GET /api/runner/workers/:id/next-job
+func (s *Manager) handleNextJob(w http.ResponseWriter, r *http.Request) {
+	// Get runner ID from URL path
+	runnerIDStr := chi.URLParam(r, "id")
+	if runnerIDStr == "" {
+		s.respondError(w, http.StatusBadRequest, "runner ID required")
+		return
+	}
+	var runnerID int64
+	if _, err := fmt.Sscanf(runnerIDStr, "%d", &runnerID); err != nil {
+		s.respondError(w, http.StatusBadRequest, "invalid runner ID")
+		return
+	}
+
+	// Get API key from header
+	apiKey := r.Header.Get("Authorization")
+	apiKey = strings.TrimPrefix(apiKey, "Bearer ")
+	if apiKey == "" {
+		s.respondError(w, http.StatusUnauthorized, "API key required")
+		return
+	}
+
+	// Validate API key
+	apiKeyID, apiKeyScope, err := s.secrets.ValidateRunnerAPIKey(apiKey)
+	if err != nil {
+		s.respondError(w, http.StatusUnauthorized, fmt.Sprintf("Invalid API key: %v", err))
+		return
+	}
+
+	// Verify runner exists and belongs to this API key
+	var dbAPIKeyID sql.NullInt64
+	var runnerCapabilitiesJSON sql.NullString
+	err = s.db.With(func(conn *sql.DB) error {
+		return conn.QueryRow("SELECT api_key_id, capabilities FROM runners WHERE id = ?", runnerID).Scan(&dbAPIKeyID, &runnerCapabilitiesJSON)
+	})
+	if err == sql.ErrNoRows {
+		s.respondError(w, http.StatusNotFound, "runner not found")
+		return
+	}
+	if err != nil {
+		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to query runner: %v", err))
+		return
+	}
+
+	// For non-fixed API keys, verify ownership
+	if apiKeyID != -1 {
+		if !dbAPIKeyID.Valid || dbAPIKeyID.Int64 != apiKeyID {
+			s.respondError(w, http.StatusForbidden, "runner does not belong to this API key")
+			return
+		}
+	}
+
+	// Update runner heartbeat
+	s.db.With(func(conn *sql.DB) error {
+		_, _ = conn.Exec(
+			`UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`,
+			time.Now(), types.RunnerStatusOnline, runnerID,
+		)
+		return nil
+	})
+
+	// Parse runner capabilities
+	var runnerCapabilities map[string]interface{}
+	if runnerCapabilitiesJSON.Valid && runnerCapabilitiesJSON.String != "" {
+		if err := json.Unmarshal([]byte(runnerCapabilitiesJSON.String), &runnerCapabilities); err != nil {
+			runnerCapabilities = make(map[string]interface{})
+		}
+	} else {
+		runnerCapabilities = make(map[string]interface{})
+	}
+
+	// Check if runner already has an active task
+	var activeTaskCount int
+	err = s.db.With(func(conn *sql.DB) error {
+		return conn.QueryRow(
+			`SELECT COUNT(*) FROM tasks WHERE runner_id = ? AND status IN (?, ?)`,
+			runnerID, types.TaskStatusPending, types.TaskStatusRunning,
+		).Scan(&activeTaskCount)
+	})
+	if err != nil {
+		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to check active tasks: %v", err))
+		return
+	}
+	if activeTaskCount > 0 {
+		// Runner is busy, return 204
+		w.WriteHeader(http.StatusNoContent)
+		return
+	}
+
+	// Find next pending task for this runner
+	// Query pending tasks ordered by created_at (oldest first)
+	type taskCandidate struct {
+		TaskID        int64
+		JobID         int64
+		Frame         int
+		TaskType      string
+		JobName       string
+		JobUserID     int64
+		BlendMetadata sql.NullString
+	}
+	var candidates []taskCandidate
+
+	err = s.db.With(func(conn *sql.DB) error {
+		rows, err := conn.Query(
+			`SELECT t.id, t.job_id, t.frame, t.task_type,
+			        j.name as job_name, j.user_id, j.blend_metadata,
+			        t.condition
+			 FROM tasks t
+			 JOIN jobs j ON t.job_id = j.id
+			 WHERE t.status = ? AND j.status != ?
+			 ORDER BY t.created_at ASC
+			 LIMIT 50`,
+			types.TaskStatusPending, types.JobStatusCancelled,
+		)
+		if err != nil {
+			return err
+		}
+		defer rows.Close()
+
+		for rows.Next() {
+			var task taskCandidate
+			var condition sql.NullString
+			err := rows.Scan(&task.TaskID, &task.JobID, &task.Frame, &task.TaskType,
+				&task.JobName, &task.JobUserID, &task.BlendMetadata, &condition)
+			if err != nil {
+				continue
+			}
+
+			// Check if task condition is met before adding to candidates
+			conditionStr := ""
+			if condition.Valid {
+				conditionStr = condition.String
+			}
+			if !s.evaluateTaskCondition(task.TaskID, task.JobID, conditionStr) {
+				continue // Skip tasks whose conditions are not met
+			}
+
+			candidates = append(candidates, task)
+		}
+		return rows.Err()
+	})
+	if err != nil {
+		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to query tasks: %v", err))
+		return
+	}
+
+	// Find a suitable task from candidates
+	var selectedTask *taskCandidate
+	for i := range candidates {
+		task := &candidates[i]
+
+		// Check runner scope
+		if apiKeyScope == "user" && task.JobUserID != 0 {
+			// User-scoped runner - check if they can work on this job
+			var apiKeyCreatedBy int64
+			if apiKeyID != -1 {
+				s.db.With(func(conn *sql.DB) error {
+					return conn.QueryRow("SELECT created_by FROM runner_api_keys WHERE id = ?", apiKeyID).Scan(&apiKeyCreatedBy)
+				})
+				if apiKeyCreatedBy != task.JobUserID {
+					continue // Skip this task
+				}
+			}
+		}
+
+		// Check required capability (only for ffmpeg - blender is assumed installed)
+		if task.TaskType == string(types.TaskTypeEncode) {
+			hasFFmpeg := false
+			if capVal, ok := runnerCapabilities["ffmpeg"]; ok {
+				if b, ok := capVal.(bool); ok {
+					hasFFmpeg = b
+				} else if f, ok := capVal.(float64); ok {
+					hasFFmpeg = f > 0
+				}
+			}
+			if !hasFFmpeg {
+				continue // Runner doesn't have ffmpeg capability
+			}
+		}
+
+		// Found a suitable task
+		selectedTask = task
+		break
+	}
+
+	if selectedTask == nil {
+		// No task available
+		w.WriteHeader(http.StatusNoContent)
+		return
+	}
+
+	// Atomically assign task to runner
+	now := time.Now()
+	var rowsAffected int64
+	err = s.db.WithTx(func(tx *sql.Tx) error {
+		result, err := tx.Exec(
+			`UPDATE tasks SET runner_id = ?, status = ?, started_at = ?
+			 WHERE id = ? AND runner_id IS NULL AND status = ?`,
+			runnerID, types.TaskStatusRunning, now, selectedTask.TaskID, types.TaskStatusPending,
+		)
+		if err != nil {
+			return err
+		}
+		rowsAffected, err = result.RowsAffected()
+		if err != nil {
+			return err
+		}
+
+		// Also update job's assigned_runner_id to track current worker
+		// For parallel jobs, this will be updated each time a new runner picks up a task
+		_, err = tx.Exec(
+			`UPDATE jobs SET assigned_runner_id = ? WHERE id = ?`,
+			runnerID, selectedTask.JobID,
+		)
+		return err
+	})
+	if err != nil {
+		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to assign task: %v", err))
+		return
+	}
+	if rowsAffected == 0 {
+		// Task was already assigned by another runner, return 204 to retry
+		w.WriteHeader(http.StatusNoContent)
+		return
+	}
+
+	// Generate job token
+	jobToken, err := auth.GenerateJobToken(selectedTask.JobID, runnerID, selectedTask.TaskID)
+	if err != nil {
+		// Rollback task assignment and job runner assignment
+		s.db.With(func(conn *sql.DB) error {
+			_, _ = conn.Exec(
+				`UPDATE tasks SET runner_id = NULL, status = ?, started_at = NULL WHERE id = ?`,
+				types.TaskStatusPending, selectedTask.TaskID,
+			)
+			_, _ = conn.Exec(
+				`UPDATE jobs SET assigned_runner_id = NULL WHERE id = ?`,
+				selectedTask.JobID, // Fixed: was selectedTask.TaskID
+			)
+			return nil
+		})
+		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to generate job token: %v", err))
+		return
+	}
+
+	// Parse metadata
+	var metadata *types.BlendMetadata
+	if selectedTask.BlendMetadata.Valid && selectedTask.BlendMetadata.String != "" {
+		metadata = &types.BlendMetadata{}
+		if err := json.Unmarshal([]byte(selectedTask.BlendMetadata.String), metadata); err != nil {
+			metadata = nil
+		}
+	}
+
+	// Log task assignment
+	log.Printf("Assigned task %d (type: %s, job: %d) to runner %d via polling", selectedTask.TaskID, selectedTask.TaskType, selectedTask.JobID, runnerID)
+	s.logTaskEvent(selectedTask.TaskID, nil, types.LogLevelInfo, fmt.Sprintf("Task assigned to runner %d", runnerID), "")
+
+	// Broadcast task update to frontend
+	s.broadcastTaskUpdate(selectedTask.JobID, selectedTask.TaskID, "task_update", map[string]interface{}{
+		"status":     types.TaskStatusRunning,
+		"runner_id":  runnerID,
+		"started_at": now,
+	})
+
+	// Update job status
+	s.updateJobStatusFromTasks(selectedTask.JobID)
+
+	// Build response
+	response := NextJobResponse{
+		JobToken: jobToken,
+		JobPath:  fmt.Sprintf("/api/runner/jobs/%d", selectedTask.JobID),
+		Task: NextJobTaskInfo{
+			TaskID:   selectedTask.TaskID,
+			JobID:    selectedTask.JobID,
+			JobName:  selectedTask.JobName,
+			Frame:    selectedTask.Frame,
+			TaskType: selectedTask.TaskType,
+			Metadata: metadata,
+		},
+	}
+
+	s.respondJSON(w, http.StatusOK, response)
+}
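
The contract this endpoint defines for runners: poll GET /api/runner/workers/{id}/next-job with the API key; 200 carries a NextJobResponse, 204 means "busy, or nothing suitable, poll again". A sketch of the runner-side loop, where pollInterval corresponds to the --poll-interval flag added earlier in this diff and runTask stands in for the download/render/upload cycle:

    // Sketch: runner-side poll loop for the next-job endpoint (assumes managerURL,
    // runnerID, apiKey, pollInterval and a runTask helper exist in the runner).
    for {
    	url := fmt.Sprintf("%s/api/runner/workers/%d/next-job", managerURL, runnerID)
    	req, _ := http.NewRequest("GET", url, nil)
    	req.Header.Set("Authorization", "Bearer "+apiKey)
    	resp, err := http.DefaultClient.Do(req)
    	if err != nil {
    		time.Sleep(pollInterval) // transient network error; retry
    		continue
    	}
    	if resp.StatusCode == http.StatusOK {
    		var job NextJobResponse
    		if err := json.NewDecoder(resp.Body).Decode(&job); err == nil {
    			runTask(job) // fetch context.tar, render, upload using job.JobToken
    		}
    	}
    	resp.Body.Close() // 204: nothing to do right now
    	time.Sleep(pollInterval)
    }
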
`json:"step_name"` + Status string `json:"status"` // "pending", "running", "completed", "failed", "skipped" + DurationMs *int `json:"duration_ms,omitempty"` + ErrorMessage string `json:"error_message,omitempty"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err)) + return + } + + // Verify task belongs to runner + var taskRunnerID sql.NullInt64 + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", taskID).Scan(&taskRunnerID) + }) + if err == sql.ErrNoRows { + s.respondError(w, http.StatusNotFound, "Task not found") + return + } + if err != nil { + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify task: %v", err)) + return + } + if !taskRunnerID.Valid || taskRunnerID.Int64 != runnerID { + s.respondError(w, http.StatusForbidden, "Task does not belong to this runner") + return + } + + now := time.Now() + var stepID int64 + + // Check if step already exists + var existingStepID sql.NullInt64 + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( + `SELECT id FROM task_steps WHERE task_id = ? AND step_name = ?`, + taskID, req.StepName, + ).Scan(&existingStepID) + }) + + if err == sql.ErrNoRows || !existingStepID.Valid { + // Create new step + var startedAt *time.Time + var completedAt *time.Time + if req.Status == string(types.StepStatusRunning) || req.Status == string(types.StepStatusCompleted) || req.Status == string(types.StepStatusFailed) { + startedAt = &now + } + if req.Status == string(types.StepStatusCompleted) || req.Status == string(types.StepStatusFailed) { + completedAt = &now + } + + err = s.db.With(func(conn *sql.DB) error { + result, err := conn.Exec( + `INSERT INTO task_steps (task_id, step_name, status, started_at, completed_at, duration_ms, error_message) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + taskID, req.StepName, req.Status, startedAt, completedAt, req.DurationMs, req.ErrorMessage, + ) + if err != nil { + return err + } + stepID, err = result.LastInsertId() + return err + }) + if err != nil { + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create step: %v", err)) + return + } + } else { + // Update existing step + stepID = existingStepID.Int64 + var startedAt *time.Time + var completedAt *time.Time + + // Get existing started_at if status is running/completed/failed + if req.Status == string(types.StepStatusRunning) || req.Status == string(types.StepStatusCompleted) || req.Status == string(types.StepStatusFailed) { + var existingStartedAt sql.NullTime + s.db.With(func(conn *sql.DB) error { + return conn.QueryRow(`SELECT started_at FROM task_steps WHERE id = ?`, stepID).Scan(&existingStartedAt) + }) + if existingStartedAt.Valid { + startedAt = &existingStartedAt.Time + } else { + startedAt = &now + } + } + + if req.Status == string(types.StepStatusCompleted) || req.Status == string(types.StepStatusFailed) { + completedAt = &now + } + + err = s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( + `UPDATE task_steps SET status = ?, started_at = ?, completed_at = ?, duration_ms = ?, error_message = ? 
+ WHERE id = ?`, + req.Status, startedAt, completedAt, req.DurationMs, req.ErrorMessage, stepID, + ) + return err + }) + if err != nil { + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to update step: %v", err)) + return + } + } + + // Get job ID for broadcasting + var jobID int64 + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT job_id FROM tasks WHERE id = ?", taskID).Scan(&jobID) + }) + if err == nil { + // Broadcast step update to frontend + s.broadcastTaskUpdate(jobID, taskID, "step_update", map[string]interface{}{ + "step_id": stepID, + "step_name": req.StepName, + "status": req.Status, + "duration_ms": req.DurationMs, + "error_message": req.ErrorMessage, + }) + } + + s.respondJSON(w, http.StatusOK, map[string]interface{}{ + "step_id": stepID, + "message": "Step updated successfully", + }) +} + +// handleDownloadJobContext allows runners to download the job context tar +// DEPRECATED: Use handleDownloadJobContextWithToken for new polling-based flow +func (s *Manager) handleDownloadJobContext(w http.ResponseWriter, r *http.Request) { + jobID, err := parseID(r, "jobId") + if err != nil { + s.respondError(w, http.StatusBadRequest, err.Error()) + return + } + + // Construct the context file path + contextPath := filepath.Join(s.storage.JobPath(jobID), "context.tar") + + // Check if context file exists + if !s.storage.FileExists(contextPath) { + log.Printf("Context archive not found for job %d", jobID) + s.respondError(w, http.StatusNotFound, "Context archive not found. The file may not have been uploaded successfully.") + return + } + + // Open and serve file + file, err := s.storage.GetFile(contextPath) + if err != nil { + s.respondError(w, http.StatusNotFound, "Context file not found on disk") + return + } + defer file.Close() + + // Set appropriate headers for tar file + w.Header().Set("Content-Type", "application/x-tar") + w.Header().Set("Content-Disposition", "attachment; filename=context.tar") + + // Stream the file to the response + io.Copy(w, file) +} + +// handleDownloadJobContextWithToken allows runners to download job context using job_token +// GET /api/runner/jobs/:jobId/context.tar +func (s *Manager) handleDownloadJobContextWithToken(w http.ResponseWriter, r *http.Request) { + jobID, err := parseID(r, "jobId") + if err != nil { + s.respondError(w, http.StatusBadRequest, err.Error()) + return + } + + // Get job token from Authorization header + jobToken := r.Header.Get("Authorization") + jobToken = strings.TrimPrefix(jobToken, "Bearer ") + if jobToken == "" { + s.respondError(w, http.StatusUnauthorized, "job token required") + return + } + + // Validate job token + claims, err := auth.ValidateJobToken(jobToken) + if err != nil { + s.respondError(w, http.StatusUnauthorized, fmt.Sprintf("invalid job token: %v", err)) + return + } + + // Verify job ID matches + if claims.JobID != jobID { + s.respondError(w, http.StatusForbidden, "job ID mismatch") + return + } + + // Construct the context file path + contextPath := filepath.Join(s.storage.JobPath(jobID), "context.tar") + + // Check if context file exists + if !s.storage.FileExists(contextPath) { + log.Printf("Context archive not found for job %d", jobID) + s.respondError(w, http.StatusNotFound, "Context archive not found. 
The file may not have been uploaded successfully.") + return + } + + // Open and serve file + file, err := s.storage.GetFile(contextPath) + if err != nil { + s.respondError(w, http.StatusNotFound, "Context file not found on disk") + return + } + defer file.Close() + + // Set appropriate headers for tar file + w.Header().Set("Content-Type", "application/x-tar") + w.Header().Set("Content-Disposition", "attachment; filename=context.tar") + + // Stream the file to the response + io.Copy(w, file) +} + +// handleUploadFileFromRunner allows runners to upload output files +// DEPRECATED: Use handleUploadFileWithToken for new polling-based flow +func (s *Manager) handleUploadFileFromRunner(w http.ResponseWriter, r *http.Request) { + jobID, err := parseID(r, "jobId") + if err != nil { + s.respondError(w, http.StatusBadRequest, err.Error()) + return + } + + err = r.ParseMultipartForm(MaxUploadSize) // 50 GB (for large output files) + if err != nil { + s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Failed to parse multipart form: %v", err)) + return + } + + file, header, err := r.FormFile("file") + if err != nil { + s.respondError(w, http.StatusBadRequest, "No file provided") + return + } + defer file.Close() + + // Save file + filePath, err := s.storage.SaveOutput(jobID, header.Filename, file) + if err != nil { + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to save file: %v", err)) + return + } + + // Record in database - check for existing file first to avoid duplicates + var fileID int64 + err = s.db.With(func(conn *sql.DB) error { + // Check if file with same name already exists + var existingID int64 + err := conn.QueryRow( + `SELECT id FROM job_files WHERE job_id = ? AND file_type = ? AND file_name = ?`, + jobID, types.JobFileTypeOutput, header.Filename, + ).Scan(&existingID) + + switch err { + case nil: + // File exists - update it instead of creating duplicate + log.Printf("File %s already exists for job %d (ID: %d), updating record", header.Filename, jobID, existingID) + _, err = conn.Exec( + `UPDATE job_files SET file_path = ?, file_size = ? 
WHERE id = ?`, + filePath, header.Size, existingID, + ) + if err != nil { + return err + } + fileID = existingID + return nil + case sql.ErrNoRows: + // File doesn't exist - insert new record + result, err := conn.Exec( + `INSERT INTO job_files (job_id, file_type, file_path, file_name, file_size) + VALUES (?, ?, ?, ?, ?)`, + jobID, types.JobFileTypeOutput, filePath, header.Filename, header.Size, + ) + if err != nil { + return err + } + fileID, err = result.LastInsertId() + return err + default: + return err + } + }) + if err != nil { + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to record file: %v", err)) + return + } + + // Broadcast file addition + s.broadcastJobUpdate(jobID, "file_added", map[string]interface{}{ + "file_id": fileID, + "file_type": types.JobFileTypeOutput, + "file_name": header.Filename, + "file_size": header.Size, + }) + + s.respondJSON(w, http.StatusCreated, map[string]interface{}{ + "file_path": filePath, + "file_name": header.Filename, + }) +} + +// handleUploadFileWithToken allows runners to upload output files using job_token +// POST /api/runner/jobs/:jobId/upload +func (s *Manager) handleUploadFileWithToken(w http.ResponseWriter, r *http.Request) { + jobID, err := parseID(r, "jobId") + if err != nil { + s.respondError(w, http.StatusBadRequest, err.Error()) + return + } + + // Get job token from Authorization header + jobToken := r.Header.Get("Authorization") + jobToken = strings.TrimPrefix(jobToken, "Bearer ") + if jobToken == "" { + s.respondError(w, http.StatusUnauthorized, "job token required") + return + } + + // Validate job token + claims, err := auth.ValidateJobToken(jobToken) + if err != nil { + s.respondError(w, http.StatusUnauthorized, fmt.Sprintf("invalid job token: %v", err)) + return + } + + // Verify job ID matches + if claims.JobID != jobID { + s.respondError(w, http.StatusForbidden, "job ID mismatch") + return + } + + err = r.ParseMultipartForm(MaxUploadSize) // 50 GB (for large output files) + if err != nil { + s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Failed to parse multipart form: %v", err)) + return + } + + file, header, err := r.FormFile("file") + if err != nil { + s.respondError(w, http.StatusBadRequest, "No file provided") + return + } + defer file.Close() + + // Save file + filePath, err := s.storage.SaveOutput(jobID, header.Filename, file) + if err != nil { + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to save file: %v", err)) + return + } + + // Record in database + var fileID int64 + err = s.db.With(func(conn *sql.DB) error { + result, err := conn.Exec( + `INSERT INTO job_files (job_id, file_type, file_path, file_name, file_size) + VALUES (?, ?, ?, ?, ?)`, + jobID, types.JobFileTypeOutput, filePath, header.Filename, header.Size, + ) + if err != nil { + return err + } + fileID, err = result.LastInsertId() + return err + }) + if err != nil { + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to record file: %v", err)) + return + } + + // Broadcast file addition + s.broadcastJobUpdate(jobID, "file_added", map[string]interface{}{ + "file_id": fileID, + "file_type": types.JobFileTypeOutput, + "file_name": header.Filename, + "file_size": header.Size, + }) + + log.Printf("Runner uploaded file %s for job %d (task %d)", header.Filename, jobID, claims.TaskID) + + s.respondJSON(w, http.StatusCreated, map[string]interface{}{ + "file_id": fileID, + "file_path": filePath, + "file_name": header.Filename, + }) +} + +// handleGetJobStatusForRunner allows runners to 
check job status
+func (s *Manager) handleGetJobStatusForRunner(w http.ResponseWriter, r *http.Request) {
+	jobID, err := parseID(r, "jobId")
+	if err != nil {
+		s.respondError(w, http.StatusBadRequest, err.Error())
+		return
+	}
+
+	var job types.Job
+	var startedAt, completedAt sql.NullTime
+	var errorMessage sql.NullString
+
+	var jobType string
+	var frameStart, frameEnd sql.NullInt64
+	var outputFormat sql.NullString
+	err = s.db.With(func(conn *sql.DB) error {
+		return conn.QueryRow(
+			`SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format,
+			created_at, started_at, completed_at, error_message
+			FROM jobs WHERE id = ?`,
+			jobID,
+		).Scan(
+			&job.ID, &job.UserID, &jobType, &job.Name, &job.Status, &job.Progress,
+			&frameStart, &frameEnd, &outputFormat,
+			&job.CreatedAt, &startedAt, &completedAt, &errorMessage,
+		)
+	})
+	// Check ErrNoRows before the generic error case; checking it after the
+	// 500 response (as before) made the 404 branch unreachable
+	if err == sql.ErrNoRows {
+		s.respondError(w, http.StatusNotFound, "Job not found")
+		return
+	}
+	if err != nil {
+		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query job: %v", err))
+		return
+	}
+
+	job.JobType = types.JobType(jobType)
+	if frameStart.Valid {
+		fs := int(frameStart.Int64)
+		job.FrameStart = &fs
+	}
+	if frameEnd.Valid {
+		fe := int(frameEnd.Int64)
+		job.FrameEnd = &fe
+	}
+	if outputFormat.Valid {
+		job.OutputFormat = &outputFormat.String
+	}
+	if startedAt.Valid {
+		job.StartedAt = &startedAt.Time
+	}
+	if completedAt.Valid {
+		job.CompletedAt = &completedAt.Time
+	}
+	if errorMessage.Valid {
+		job.ErrorMessage = errorMessage.String
+	}
+
+	s.respondJSON(w, http.StatusOK, job)
+}
+
+// handleGetJobFilesForRunner allows runners to get job files
+func (s *Manager) handleGetJobFilesForRunner(w http.ResponseWriter, r *http.Request) {
+	jobID, err := parseID(r, "jobId")
+	if err != nil {
+		s.respondError(w, http.StatusBadRequest, err.Error())
+		return
+	}
+
+	runnerID := r.URL.Query().Get("runner_id")
+	log.Printf("GetJobFiles request for job %d from runner %s", jobID, runnerID)
+
+	var rows *sql.Rows
+	var fileCount int
+	err = s.db.With(func(conn *sql.DB) error {
+		var err error
+		rows, err = conn.Query(
+			`SELECT id, job_id, file_type, file_path, file_name, file_size, created_at
+			FROM job_files WHERE job_id = ?
ORDER BY file_name`, + jobID, + ) + if err != nil { + return err + } + // Count files + var count int + err = conn.QueryRow(`SELECT COUNT(*) FROM job_files WHERE job_id = ?`, jobID).Scan(&count) + if err == nil { + fileCount = count + } + return nil + }) + if err != nil { + log.Printf("GetJobFiles query error for job %d: %v", jobID, err) + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query files: %v", err)) + return + } + defer rows.Close() + + files := []types.JobFile{} + for rows.Next() { + var file types.JobFile + err := rows.Scan( + &file.ID, &file.JobID, &file.FileType, &file.FilePath, + &file.FileName, &file.FileSize, &file.CreatedAt, + ) + if err != nil { + log.Printf("GetJobFiles scan error for job %d: %v", jobID, err) + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan file: %v", err)) + return + } + files = append(files, file) + log.Printf("GetJobFiles: returning file %s (type: %s, size: %d) for job %d", file.FileName, file.FileType, file.FileSize, jobID) + } + + log.Printf("GetJobFiles returning %d files for job %d (total in DB: %d)", len(files), jobID, fileCount) + s.respondJSON(w, http.StatusOK, files) +} + +// handleGetJobMetadataForRunner allows runners to get job metadata +func (s *Manager) handleGetJobMetadataForRunner(w http.ResponseWriter, r *http.Request) { + jobID, err := parseID(r, "jobId") + if err != nil { + s.respondError(w, http.StatusBadRequest, err.Error()) + return + } + + var blendMetadataJSON sql.NullString + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( + `SELECT blend_metadata FROM jobs WHERE id = ?`, + jobID, + ).Scan(&blendMetadataJSON) + }) + + if err == sql.ErrNoRows { + s.respondError(w, http.StatusNotFound, "Job not found") + return + } + if err != nil { + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query job: %v", err)) + return + } + + if !blendMetadataJSON.Valid || blendMetadataJSON.String == "" { + s.respondJSON(w, http.StatusOK, nil) + return + } + + var metadata types.BlendMetadata + if err := json.Unmarshal([]byte(blendMetadataJSON.String), &metadata); err != nil { + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to parse metadata JSON: %v", err)) + return + } + + s.respondJSON(w, http.StatusOK, metadata) +} + +// handleDownloadFileForRunner allows runners to download a file by fileName +func (s *Manager) handleDownloadFileForRunner(w http.ResponseWriter, r *http.Request) { + jobID, err := parseID(r, "jobId") + if err != nil { + s.respondError(w, http.StatusBadRequest, err.Error()) + return + } + + // Get fileName from URL path (may need URL decoding) + fileName := chi.URLParam(r, "fileName") + if fileName == "" { + s.respondError(w, http.StatusBadRequest, "fileName is required") + return + } + + // URL decode the fileName in case it contains encoded characters + decodedFileName, err := url.QueryUnescape(fileName) + if err != nil { + // If decoding fails, use original fileName + decodedFileName = fileName + } + + // Get file info from database + var filePath string + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( + `SELECT file_path FROM job_files WHERE job_id = ? 
AND file_name = ?`,
+			jobID, decodedFileName,
+		).Scan(&filePath)
+	})
+	if err == sql.ErrNoRows {
+		s.respondError(w, http.StatusNotFound, "File not found")
+		return
+	}
+	if err != nil {
+		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query file: %v", err))
+		return
+	}
+
+	// Open file
+	file, err := s.storage.GetFile(filePath)
+	if err != nil {
+		s.respondError(w, http.StatusNotFound, "File not found on disk")
+		return
+	}
+	defer file.Close()
+
+	// Determine content type based on file extension
+	contentType := "application/octet-stream"
+	fileNameLower := strings.ToLower(decodedFileName)
+	switch {
+	case strings.HasSuffix(fileNameLower, ".png"):
+		contentType = "image/png"
+	case strings.HasSuffix(fileNameLower, ".jpg") || strings.HasSuffix(fileNameLower, ".jpeg"):
+		contentType = "image/jpeg"
+	case strings.HasSuffix(fileNameLower, ".gif"):
+		contentType = "image/gif"
+	case strings.HasSuffix(fileNameLower, ".webp"):
+		contentType = "image/webp"
+	case strings.HasSuffix(fileNameLower, ".exr"):
+		// fileNameLower is already lowercased, so a separate ".EXR" check is redundant
+		contentType = "image/x-exr"
+	case strings.HasSuffix(fileNameLower, ".mp4"):
+		contentType = "video/mp4"
+	case strings.HasSuffix(fileNameLower, ".webm"):
+		contentType = "video/webm"
+	}
+
+	// Set headers (quote the filename so names containing spaces survive intact)
+	w.Header().Set("Content-Type", contentType)
+	w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", decodedFileName))
+
+	// Stream file
+	io.Copy(w, file)
+}
+
+// WebSocket message types
+type WSMessage struct {
+	Type      string          `json:"type"`
+	Data      json.RawMessage `json:"data"`
+	Timestamp int64           `json:"timestamp"`
+}
+
+type WSTaskAssignment struct {
+	TaskID       int64    `json:"task_id"`
+	JobID        int64    `json:"job_id"`
+	JobName      string   `json:"job_name"`
+	OutputFormat string   `json:"output_format"`
+	Frame        int      `json:"frame"`
+	TaskType     string   `json:"task_type"`
+	InputFiles   []string `json:"input_files"`
+}
+
+type WSLogEntry struct {
+	TaskID   int64  `json:"task_id"`
+	LogLevel string `json:"log_level"`
+	Message  string `json:"message"`
+	StepName string `json:"step_name,omitempty"`
+}
+
+type WSTaskUpdate struct {
+	TaskID     int64  `json:"task_id"`
+	Status     string `json:"status"`
+	OutputPath string `json:"output_path,omitempty"`
+	Success    bool   `json:"success"`
+	Error      string `json:"error,omitempty"`
+}
+
+// handleRunnerJobWebSocket handles per-job WebSocket connections from runners
+// WS /api/runner/jobs/:jobId/ws
+func (s *Manager) handleRunnerJobWebSocket(w http.ResponseWriter, r *http.Request) {
+	// Get job ID from URL path
+	jobIDStr := chi.URLParam(r, "jobId")
+	if jobIDStr == "" {
+		s.respondError(w, http.StatusBadRequest, "job ID required")
+		return
+	}
+	var jobID int64
+	if _, err := fmt.Sscanf(jobIDStr, "%d", &jobID); err != nil {
+		s.respondError(w, http.StatusBadRequest, "invalid job ID")
+		return
+	}
+
+	// Upgrade to WebSocket
+	conn, err := s.wsUpgrader.Upgrade(w, r, nil)
+	if err != nil {
+		log.Printf("Failed to upgrade job WebSocket: %v", err)
+		return
+	}
+	defer conn.Close()
+
+	// First message must be auth
+	conn.SetReadDeadline(time.Now().Add(WSPingInterval))
+	var authMsg struct {
+		Type     string `json:"type"`
+		JobToken string `json:"job_token"`
+	}
+	if err := conn.ReadJSON(&authMsg); err != nil {
+		log.Printf("Job WebSocket auth read error: %v", err)
+		conn.WriteJSON(map[string]string{"type": "error", "message": "failed to read auth message"})
+		return
+	}
+	if authMsg.Type != "auth" {
+		conn.WriteJSON(map[string]string{"type": "error", "message": "first message must be
auth"})
+		return
+	}
+
+	// Validate job token
+	claims, err := auth.ValidateJobToken(authMsg.JobToken)
+	if err != nil {
+		log.Printf("Job WebSocket invalid token: %v", err)
+		conn.WriteJSON(map[string]string{"type": "error", "message": fmt.Sprintf("invalid job token: %v", err)})
+		return
+	}
+
+	// Verify job ID matches
+	if claims.JobID != jobID {
+		conn.WriteJSON(map[string]string{"type": "error", "message": "job ID mismatch"})
+		return
+	}
+
+	runnerID := claims.RunnerID
+	taskID := claims.TaskID
+
+	// Verify task is still assigned to this runner
+	var taskRunnerID sql.NullInt64
+	var taskStatus string
+	err = s.db.With(func(conn *sql.DB) error {
+		return conn.QueryRow("SELECT runner_id, status FROM tasks WHERE id = ?", taskID).Scan(&taskRunnerID, &taskStatus)
+	})
+	if err != nil {
+		conn.WriteJSON(map[string]string{"type": "error", "message": "task not found"})
+		return
+	}
+	if !taskRunnerID.Valid || taskRunnerID.Int64 != runnerID {
+		conn.WriteJSON(map[string]string{"type": "error", "message": "task not assigned to this runner"})
+		return
+	}
+
+	// Send auth_ok
+	if err := conn.WriteJSON(map[string]string{"type": "auth_ok"}); err != nil {
+		log.Printf("Failed to send auth_ok: %v", err)
+		return
+	}
+
+	log.Printf("Job WebSocket authenticated: job=%d, runner=%d, task=%d", jobID, runnerID, taskID)
+
+	// Track this connection for the task
+	connKey := fmt.Sprintf("job-%d-task-%d", jobID, taskID)
+	var writeMu sync.Mutex
+
+	// Store connection for potential server->runner messages
+	s.runnerJobConnsMu.Lock()
+	s.runnerJobConns[connKey] = conn
+	s.runnerJobConnsWriteMu[connKey] = &writeMu
+	s.runnerJobConnsMu.Unlock()
+
+	// Cleanup on disconnect
+	defer func() {
+		s.runnerJobConnsMu.Lock()
+		delete(s.runnerJobConns, connKey)
+		delete(s.runnerJobConnsWriteMu, connKey)
+		s.runnerJobConnsMu.Unlock()
+
+		// Check if task is still running - if so, mark as failed
+		var currentStatus string
+		s.db.With(func(conn *sql.DB) error {
+			return conn.QueryRow("SELECT status FROM tasks WHERE id = ?", taskID).Scan(&currentStatus)
+		})
+		if currentStatus == string(types.TaskStatusRunning) {
+			log.Printf("Job WebSocket disconnected unexpectedly for task %d, marking as failed", taskID)
+			s.db.With(func(conn *sql.DB) error {
+				_, err := conn.Exec(
+					`UPDATE tasks SET status = ?, error_message = ?, completed_at = ?
WHERE id = ?`, + types.TaskStatusFailed, "WebSocket connection lost", time.Now(), taskID, + ) + return err + }) + s.broadcastTaskUpdate(jobID, taskID, "task_update", map[string]interface{}{ + "status": types.TaskStatusFailed, + "error_message": "WebSocket connection lost", + }) + s.updateJobStatusFromTasks(jobID) + } + + log.Printf("Job WebSocket closed: job=%d, runner=%d, task=%d", jobID, runnerID, taskID) + }() + + // Set up ping/pong keepalive + conn.SetPongHandler(func(string) error { + conn.SetReadDeadline(time.Now().Add(WSReadDeadline)) + return nil + }) + + // Send pings periodically + go func() { + ticker := time.NewTicker(WSPingInterval) + defer ticker.Stop() + for range ticker.C { + s.runnerJobConnsMu.RLock() + currentConn, exists := s.runnerJobConns[connKey] + mu, hasMu := s.runnerJobConnsWriteMu[connKey] + s.runnerJobConnsMu.RUnlock() + if !exists || currentConn != conn || !hasMu { + return + } + mu.Lock() + err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(WSWriteDeadline)) + mu.Unlock() + if err != nil { + return + } + } + }() + + // Handle incoming messages + for { + conn.SetReadDeadline(time.Now().Add(WSReadDeadline)) + + var msg WSMessage + err := conn.ReadJSON(&msg) + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) { + log.Printf("Job WebSocket error for task %d: %v", taskID, err) + } + break + } + + switch msg.Type { + case "log_entry": + var logEntry WSLogEntry + if err := json.Unmarshal(msg.Data, &logEntry); err == nil { + // Verify task ID matches + if logEntry.TaskID == taskID { + s.handleWebSocketLog(runnerID, logEntry) + } + } + + case "progress": + var progress struct { + TaskID int64 `json:"task_id"` + Progress float64 `json:"progress"` + } + if err := json.Unmarshal(msg.Data, &progress); err == nil { + if progress.TaskID == taskID { + // Broadcast progress update + s.broadcastTaskUpdate(jobID, taskID, "progress", map[string]interface{}{ + "progress": progress.Progress, + }) + } + } + + case "output_uploaded": + var output struct { + TaskID int64 `json:"task_id"` + FileName string `json:"file_name"` + } + if err := json.Unmarshal(msg.Data, &output); err == nil { + if output.TaskID == taskID { + log.Printf("Task %d uploaded output: %s", taskID, output.FileName) + // Broadcast file upload notification + s.broadcastJobUpdate(jobID, "file_uploaded", map[string]interface{}{ + "task_id": taskID, + "file_name": output.FileName, + }) + } + } + + case "task_complete": + var taskUpdate WSTaskUpdate + if err := json.Unmarshal(msg.Data, &taskUpdate); err == nil { + if taskUpdate.TaskID == taskID { + s.handleWebSocketTaskComplete(runnerID, taskUpdate) + // Task is done, close connection + return + } + } + case "runner_heartbeat": + // Lookup runner ID from job's assigned_runner_id + var assignedRunnerID sql.NullInt64 + err := s.db.With(func(db *sql.DB) error { + return db.QueryRow( + "SELECT assigned_runner_id FROM jobs WHERE id = ?", + jobID, + ).Scan(&assignedRunnerID) + }) + if err != nil { + log.Printf("Failed to lookup runner for job %d heartbeat: %v", jobID, err) + // Send error response + response := map[string]interface{}{ + "type": "error", + "message": "Failed to process heartbeat", + } + s.sendWebSocketMessage(conn, response) + continue + } + + if !assignedRunnerID.Valid { + log.Printf("Job %d has no assigned runner, skipping heartbeat update", jobID) + // Send acknowledgment but no database update + response := map[string]interface{}{ + "type": 
"heartbeat_ack", + "timestamp": time.Now().Unix(), + "message": "No assigned runner for this job", + } + s.sendWebSocketMessage(conn, response) + continue + } + + runnerID := assignedRunnerID.Int64 + + // Update runner heartbeat + err = s.db.With(func(db *sql.DB) error { + _, err := db.Exec( + "UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?", + time.Now(), types.RunnerStatusOnline, runnerID, + ) + return err + }) + if err != nil { + log.Printf("Failed to update runner %d heartbeat for job %d: %v", runnerID, jobID, err) + // Send error response + response := map[string]interface{}{ + "type": "error", + "message": "Failed to update heartbeat", + } + s.sendWebSocketMessage(conn, response) + continue + } + + // Send acknowledgment + response := map[string]interface{}{ + "type": "heartbeat_ack", + "timestamp": time.Now().Unix(), + } + s.sendWebSocketMessage(conn, response) + + continue + } + } +} + +// handleWebSocketLog handles log entries from WebSocket +func (s *Manager) handleWebSocketLog(runnerID int64, logEntry WSLogEntry) { + // Store log in database + err := s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( + `INSERT INTO task_logs (task_id, runner_id, log_level, message, step_name, created_at) + VALUES (?, ?, ?, ?, ?, ?)`, + logEntry.TaskID, runnerID, logEntry.LogLevel, logEntry.Message, logEntry.StepName, time.Now(), + ) + return err + }) + if err != nil { + log.Printf("Failed to store log: %v", err) + return + } + + // Broadcast to frontend clients + s.broadcastLogToFrontend(logEntry.TaskID, logEntry) + + // If this log contains a frame number (Fra:), update progress for single-runner render jobs + if strings.Contains(logEntry.Message, "Fra:") { + // Get job ID from task + var jobID int64 + err := s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT job_id FROM tasks WHERE id = ?", logEntry.TaskID).Scan(&jobID) + }) + if err == nil { + // Throttle progress updates (max once per 2 seconds per job) + s.progressUpdateTimesMu.RLock() + lastUpdate, exists := s.progressUpdateTimes[jobID] + s.progressUpdateTimesMu.RUnlock() + + shouldUpdate := !exists || time.Since(lastUpdate) >= ProgressUpdateThrottle + if shouldUpdate { + s.progressUpdateTimesMu.Lock() + s.progressUpdateTimes[jobID] = time.Now() + s.progressUpdateTimesMu.Unlock() + + // Update progress in background to avoid blocking log processing + go s.updateJobStatusFromTasks(jobID) + } + } + } +} + +// handleWebSocketTaskUpdate handles task status updates from WebSocket +func (s *Manager) handleWebSocketTaskUpdate(runnerID int64, taskUpdate WSTaskUpdate) { + // This can be used for progress updates + // For now, we'll just log it + log.Printf("Task %d update from runner %d: %s", taskUpdate.TaskID, runnerID, taskUpdate.Status) +} + +// handleWebSocketTaskComplete handles task completion from WebSocket +func (s *Manager) handleWebSocketTaskComplete(runnerID int64, taskUpdate WSTaskUpdate) { + // Verify task belongs to runner and get task info + var taskRunnerID sql.NullInt64 + var jobID int64 + var retryCount, maxRetries int + err := s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( + "SELECT runner_id, job_id, retry_count, max_retries FROM tasks WHERE id = ?", + taskUpdate.TaskID, + ).Scan(&taskRunnerID, &jobID, &retryCount, &maxRetries) + }) + if err != nil { + log.Printf("Failed to get task %d info: %v", taskUpdate.TaskID, err) + return + } + if !taskRunnerID.Valid || taskRunnerID.Int64 != runnerID { + log.Printf("Task %d does not belong to runner %d", taskUpdate.TaskID, runnerID) 
+ return + } + + now := time.Now() + + // Handle successful completion + if taskUpdate.Success { + err = s.db.WithTx(func(tx *sql.Tx) error { + _, err := tx.Exec(`UPDATE tasks SET status = ? WHERE id = ?`, types.TaskStatusCompleted, taskUpdate.TaskID) + if err != nil { + return err + } + if taskUpdate.OutputPath != "" { + _, err = tx.Exec(`UPDATE tasks SET output_path = ? WHERE id = ?`, taskUpdate.OutputPath, taskUpdate.TaskID) + if err != nil { + return err + } + } + _, err = tx.Exec(`UPDATE tasks SET completed_at = ? WHERE id = ?`, now, taskUpdate.TaskID) + return err + }) + if err != nil { + log.Printf("Failed to update task %d: %v", taskUpdate.TaskID, err) + return + } + + // Broadcast task update + s.broadcastTaskUpdate(jobID, taskUpdate.TaskID, "task_update", map[string]interface{}{ + "status": types.TaskStatusCompleted, + "output_path": taskUpdate.OutputPath, + "completed_at": now, + }) + s.updateJobStatusFromTasks(jobID) + return + } + + // Handle task failure - this is an actual task failure (e.g., Blender crash) + // Check if we have retries remaining + if retryCount < maxRetries { + // Reset to pending for retry - increment retry_count + err = s.db.WithTx(func(tx *sql.Tx) error { + _, err := tx.Exec( + `UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL, + retry_count = retry_count + 1, started_at = NULL, completed_at = NULL + WHERE id = ?`, + types.TaskStatusPending, taskUpdate.TaskID, + ) + if err != nil { + return err + } + // Clear steps and logs for fresh retry + _, err = tx.Exec(`DELETE FROM task_steps WHERE task_id = ?`, taskUpdate.TaskID) + if err != nil { + return err + } + _, err = tx.Exec(`DELETE FROM task_logs WHERE task_id = ?`, taskUpdate.TaskID) + return err + }) + if err != nil { + log.Printf("Failed to reset task %d for retry: %v", taskUpdate.TaskID, err) + return + } + + // Broadcast task reset to clients (includes steps_cleared and logs_cleared flags) + s.broadcastTaskUpdate(jobID, taskUpdate.TaskID, "task_reset", map[string]interface{}{ + "status": types.TaskStatusPending, + "retry_count": retryCount + 1, + "error_message": taskUpdate.Error, + "steps_cleared": true, + "logs_cleared": true, + }) + + log.Printf("Task %d failed but has retries remaining (%d/%d), reset to pending", taskUpdate.TaskID, retryCount+1, maxRetries) + } else { + // No retries remaining - mark as failed + err = s.db.WithTx(func(tx *sql.Tx) error { + _, err := tx.Exec(`UPDATE tasks SET status = ? WHERE id = ?`, types.TaskStatusFailed, taskUpdate.TaskID) + if err != nil { + return err + } + _, err = tx.Exec(`UPDATE tasks SET completed_at = ? WHERE id = ?`, now, taskUpdate.TaskID) + if err != nil { + return err + } + if taskUpdate.Error != "" { + _, err = tx.Exec(`UPDATE tasks SET error_message = ? 
WHERE id = ?`, taskUpdate.Error, taskUpdate.TaskID) + if err != nil { + return err + } + } + return nil + }) + if err != nil { + log.Printf("Failed to mark task %d as failed: %v", taskUpdate.TaskID, err) + return + } + + // Log the final failure + s.logTaskEvent(taskUpdate.TaskID, &runnerID, types.LogLevelError, + fmt.Sprintf("Task failed permanently after %d retries: %s", maxRetries, taskUpdate.Error), "") + + // Broadcast task update + s.broadcastTaskUpdate(jobID, taskUpdate.TaskID, "task_update", map[string]interface{}{ + "status": types.TaskStatusFailed, + "completed_at": now, + "error_message": taskUpdate.Error, + }) + + log.Printf("Task %d failed permanently after %d retries", taskUpdate.TaskID, maxRetries) + } + + // Update job status and progress + s.updateJobStatusFromTasks(jobID) +} + +// parseBlenderFrame extracts the current frame number from Blender log messages +// Looks for patterns like "Fra:2470" in log messages +func parseBlenderFrame(logMessage string) (int, bool) { + // Look for "Fra:" followed by digits + // Pattern: "Fra:2470" or "Fra: 2470" or similar variations + fraIndex := strings.Index(logMessage, "Fra:") + if fraIndex == -1 { + return 0, false + } + + // Find the number after "Fra:" + start := fraIndex + 4 // Skip "Fra:" + // Skip whitespace + for start < len(logMessage) && (logMessage[start] == ' ' || logMessage[start] == '\t') { + start++ + } + + // Extract digits + end := start + for end < len(logMessage) && logMessage[end] >= '0' && logMessage[end] <= '9' { + end++ + } + + if end > start { + frame, err := strconv.Atoi(logMessage[start:end]) + if err == nil { + return frame, true + } + } + + return 0, false +} + +// getCurrentFrameFromLogs gets the highest frame number found in logs for a job's render tasks +func (s *Manager) getCurrentFrameFromLogs(jobID int64) (int, bool) { + // Get all render tasks for this job + var rows *sql.Rows + err := s.db.With(func(conn *sql.DB) error { + var err error + rows, err = conn.Query( + `SELECT id FROM tasks WHERE job_id = ? AND task_type = ? AND status = ?`, + jobID, types.TaskTypeRender, types.TaskStatusRunning, + ) + return err + }) + if err != nil { + return 0, false + } + defer rows.Close() + + maxFrame := 0 + found := false + + for rows.Next() { + var taskID int64 + if err := rows.Scan(&taskID); err != nil { + log.Printf("Failed to scan task ID in getCurrentFrameFromLogs: %v", err) + continue + } + + // Get the most recent log entries for this task (last 100 to avoid scanning all logs) + var logRows *sql.Rows + err := s.db.With(func(conn *sql.DB) error { + var err error + logRows, err = conn.Query( + `SELECT message FROM task_logs + WHERE task_id = ? AND message LIKE '%Fra:%' + ORDER BY id DESC LIMIT 100`, + taskID, + ) + return err + }) + if err != nil { + continue + } + + for logRows.Next() { + var message string + if err := logRows.Scan(&message); err != nil { + continue + } + + if frame, ok := parseBlenderFrame(message); ok { + if frame > maxFrame { + maxFrame = frame + found = true + } + } + } + logRows.Close() + } + + return maxFrame, found +} + +// resetFailedTasksAndRedistribute resets all failed tasks for a job to pending and redistributes them +func (s *Manager) resetFailedTasksAndRedistribute(jobID int64) error { + // Reset all failed tasks to pending and clear their retry_count + err := s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( + `UPDATE tasks SET status = ?, retry_count = 0, runner_id = NULL, started_at = NULL, completed_at = NULL, error_message = NULL + WHERE job_id = ? 
AND status = ?`,
+			types.TaskStatusPending, jobID, types.TaskStatusFailed,
+		)
+		if err != nil {
+			return fmt.Errorf("failed to reset failed tasks: %v", err)
+		}
+
+		// Increment job retry_count
+		_, err = conn.Exec(
+			`UPDATE jobs SET retry_count = retry_count + 1 WHERE id = ?`,
+			jobID,
+		)
+		if err != nil {
+			return fmt.Errorf("failed to increment job retry_count: %v", err)
+		}
+		return nil
+	})
+	if err != nil {
+		return err
+	}
+
+	log.Printf("Reset failed tasks for job %d and incremented retry_count", jobID)
+
+	return nil
+}
+
+// cancelActiveTasksForJob cancels all active (pending or running) tasks for a job
+func (s *Manager) cancelActiveTasksForJob(jobID int64) error {
+	// Tasks don't have a cancelled status - mark them as failed instead
+	err := s.db.With(func(conn *sql.DB) error {
+		_, err := conn.Exec(
+			`UPDATE tasks SET status = ?, error_message = ? WHERE job_id = ? AND status IN (?, ?)`,
+			types.TaskStatusFailed, "Job cancelled", jobID, types.TaskStatusPending, types.TaskStatusRunning,
+		)
+		if err != nil {
+			return fmt.Errorf("failed to cancel active tasks: %v", err)
+		}
+		return nil
+	})
+	if err != nil {
+		return err
+	}
+
+	log.Printf("Cancelled all active tasks for job %d", jobID)
+	return nil
+}
+
+// evaluateTaskCondition checks if a task's condition is met
+// Returns true if the task can be assigned, false otherwise
+func (s *Manager) evaluateTaskCondition(taskID int64, jobID int64, conditionJSON string) bool {
+	if conditionJSON == "" {
+		// No condition means task can always be assigned
+		return true
+	}
+
+	var condition map[string]interface{}
+	if err := json.Unmarshal([]byte(conditionJSON), &condition); err != nil {
+		log.Printf("Failed to parse condition for task %d: %v", taskID, err)
+		// If we can't parse the condition, err on the side of caution and don't assign
+		return false
+	}
+
+	conditionType, ok := condition["type"].(string)
+	if !ok {
+		log.Printf("Invalid condition format for task %d: missing type", taskID)
+		return false
+	}
+
+	switch conditionType {
+	case "all_render_tasks_completed":
+		// Check if all render tasks for this job are completed.
+		// Propagate Scan errors so the err check below is actually reachable
+		// (previously they were discarded and err was always nil).
+		var totalRenderTasks, completedRenderTasks int
+		err := s.db.With(func(conn *sql.DB) error {
+			if err := conn.QueryRow(
+				`SELECT COUNT(*) FROM tasks
+				WHERE job_id = ? AND task_type = ? AND status IN (?, ?, ?)`,
+				jobID, types.TaskTypeRender, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted,
+			).Scan(&totalRenderTasks); err != nil {
+				return err
+			}
+			return conn.QueryRow(
+				`SELECT COUNT(*) FROM tasks
+				WHERE job_id = ? AND task_type = ? AND status = ?`,
+				jobID, types.TaskTypeRender, types.TaskStatusCompleted,
+			).Scan(&completedRenderTasks)
+		})
+		if err != nil {
+			log.Printf("Failed to check render task completion for task %d: %v", taskID, err)
+			return false
+		}
+		return totalRenderTasks > 0 && completedRenderTasks == totalRenderTasks
+
+	default:
+		log.Printf("Unknown condition type '%s' for task %d", conditionType, taskID)
+		return false
+	}
+}
+
+// updateJobStatusFromTasks updates job status and progress based on task states
+func (s *Manager) updateJobStatusFromTasks(jobID int64) {
+	now := time.Now()
+
+	// All jobs now use parallel runners (one task per frame), so we always use task-based progress
+
+	// Get current job status to detect changes
+	var currentStatus string
+	err := s.db.With(func(conn *sql.DB) error {
+		return conn.QueryRow(`SELECT status FROM jobs WHERE id = ?`, jobID).Scan(&currentStatus)
+	})
+	if err != nil {
+		log.Printf("Failed to get current job status for job %d: %v", jobID, err)
+		return
+	}
+
+	// Count total tasks and completed tasks
+	var totalTasks, completedTasks int
+	err = s.db.With(func(conn *sql.DB) error {
+		err := conn.QueryRow(
+			`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status IN (?, ?, ?, ?)`,
+			jobID, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted, types.TaskStatusFailed,
+		).Scan(&totalTasks)
+		if err != nil {
+			return err
+		}
+		return conn.QueryRow(
+			`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
+			jobID, types.TaskStatusCompleted,
+		).Scan(&completedTasks)
+	})
+	if err != nil {
+		log.Printf("Failed to count completed tasks for job %d: %v", jobID, err)
+		return
+	}
+
+	// Calculate progress
+	var progress float64
+	if totalTasks == 0 {
+		// All tasks cancelled or no tasks, set progress to 0
+		progress = 0.0
+	} else {
+		// Standard task-based progress
+		progress = float64(completedTasks) / float64(totalTasks) * 100.0
+	}
+
+	var jobStatus string
+
+	// Check if all non-cancelled tasks are completed
+	var pendingOrRunningTasks int
+	err = s.db.With(func(conn *sql.DB) error {
+		return conn.QueryRow(
+			`SELECT COUNT(*) FROM tasks
+			WHERE job_id = ? AND status IN (?, ?)`,
+			jobID, types.TaskStatusPending, types.TaskStatusRunning,
+		).Scan(&pendingOrRunningTasks)
+	})
+	if err != nil {
+		log.Printf("Failed to count pending/running tasks for job %d: %v", jobID, err)
+		return
+	}
+
+	if pendingOrRunningTasks == 0 && totalTasks > 0 {
+		// All tasks are either completed or failed/cancelled
+		// Check if any tasks failed
+		var failedTasks int
+		s.db.With(func(conn *sql.DB) error {
+			conn.QueryRow(
+				`SELECT COUNT(*) FROM tasks WHERE job_id = ?
AND status = ?`, + jobID, types.TaskStatusFailed, + ).Scan(&failedTasks) + return nil + }) + + if failedTasks > 0 { + // Some tasks failed - check if job has retries left + var retryCount, maxRetries int + err := s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( + `SELECT retry_count, max_retries FROM jobs WHERE id = ?`, + jobID, + ).Scan(&retryCount, &maxRetries) + }) + if err != nil { + log.Printf("Failed to get retry info for job %d: %v", jobID, err) + // Fall back to marking job as failed + jobStatus = string(types.JobStatusFailed) + } else if retryCount < maxRetries { + // Job has retries left - reset failed tasks and redistribute + if err := s.resetFailedTasksAndRedistribute(jobID); err != nil { + log.Printf("Failed to reset failed tasks for job %d: %v", jobID, err) + // If reset fails, mark job as failed + jobStatus = string(types.JobStatusFailed) + } else { + // Tasks reset successfully - job remains in running/pending state + // Don't update job status, just update progress + jobStatus = currentStatus // Keep current status + // Recalculate progress after reset (failed tasks are now pending again) + var newTotalTasks, newCompletedTasks int + s.db.With(func(conn *sql.DB) error { + conn.QueryRow( + `SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status IN (?, ?, ?, ?)`, + jobID, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted, types.TaskStatusFailed, + ).Scan(&newTotalTasks) + conn.QueryRow( + `SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`, + jobID, types.TaskStatusCompleted, + ).Scan(&newCompletedTasks) + return nil + }) + if newTotalTasks > 0 { + progress = float64(newCompletedTasks) / float64(newTotalTasks) * 100.0 + } + // Update progress only + err := s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( + `UPDATE jobs SET progress = ? WHERE id = ?`, + progress, jobID, + ) + return err + }) + if err != nil { + log.Printf("Failed to update job %d progress: %v", jobID, err) + } else { + // Broadcast job update via WebSocket + s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{ + "status": jobStatus, + "progress": progress, + }) + } + return // Exit early since we've handled the retry + } + } else { + // No retries left - mark job as failed and cancel active tasks + jobStatus = string(types.JobStatusFailed) + if err := s.cancelActiveTasksForJob(jobID); err != nil { + log.Printf("Failed to cancel active tasks for job %d: %v", jobID, err) + } + } + } else { + // All tasks completed successfully + jobStatus = string(types.JobStatusCompleted) + progress = 100.0 // Ensure progress is 100% when all tasks complete + } + + // Update job status (if we didn't return early from retry logic) + if jobStatus != "" { + err := s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( + `UPDATE jobs SET status = ?, progress = ?, completed_at = ? 
WHERE id = ?`, + jobStatus, progress, now, jobID, + ) + return err + }) + if err != nil { + log.Printf("Failed to update job %d status to %s: %v", jobID, jobStatus, err) + } else { + // Only log if status actually changed + if currentStatus != jobStatus { + log.Printf("Updated job %d status from %s to %s (progress: %.1f%%, completed tasks: %d/%d)", jobID, currentStatus, jobStatus, progress, completedTasks, totalTasks) + } + // Broadcast job update via WebSocket + s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{ + "status": jobStatus, + "progress": progress, + "completed_at": now, + }) + } + } + + // Encode tasks are now created immediately when the job is created + // with a condition that prevents assignment until all render tasks are completed. + // No need to create them here anymore. + } else { + // Job has pending or running tasks - determine if it's running or still pending + var runningTasks int + s.db.With(func(conn *sql.DB) error { + conn.QueryRow( + `SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`, + jobID, types.TaskStatusRunning, + ).Scan(&runningTasks) + return nil + }) + + if runningTasks > 0 { + // Has running tasks - job is running + jobStatus = string(types.JobStatusRunning) + var startedAt sql.NullTime + s.db.With(func(conn *sql.DB) error { + conn.QueryRow(`SELECT started_at FROM jobs WHERE id = ?`, jobID).Scan(&startedAt) + if !startedAt.Valid { + conn.Exec(`UPDATE jobs SET started_at = ? WHERE id = ?`, now, jobID) + } + return nil + }) + } else { + // All tasks are pending - job is pending + jobStatus = string(types.JobStatusPending) + } + + err := s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( + `UPDATE jobs SET status = ?, progress = ? WHERE id = ?`, + jobStatus, progress, jobID, + ) + return err + }) + if err != nil { + log.Printf("Failed to update job %d status to %s: %v", jobID, jobStatus, err) + } else { + // Only log if status actually changed + if currentStatus != jobStatus { + log.Printf("Updated job %d status from %s to %s (progress: %.1f%%, completed: %d/%d, pending: %d, running: %d)", jobID, currentStatus, jobStatus, progress, completedTasks, totalTasks, pendingOrRunningTasks-runningTasks, runningTasks) + } + // Broadcast job update during execution (not just on completion) + s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{ + "status": jobStatus, + "progress": progress, + }) + } + } +} + +// broadcastLogToFrontend broadcasts log to connected frontend clients +func (s *Manager) broadcastLogToFrontend(taskID int64, logEntry WSLogEntry) { + // Get job_id, user_id, and task status from task + var jobID, userID int64 + var taskStatus string + var taskRunnerID sql.NullInt64 + var taskStartedAt sql.NullTime + err := s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( + `SELECT t.job_id, j.user_id, t.status, t.runner_id, t.started_at + FROM tasks t + JOIN jobs j ON t.job_id = j.id + WHERE t.id = ?`, + taskID, + ).Scan(&jobID, &userID, &taskStatus, &taskRunnerID, &taskStartedAt) + }) + if err != nil { + return + } + + // Get full log entry from database for consistency + // Use a more reliable query that gets the most recent log with matching message + // This avoids race conditions with concurrent inserts + var taskLog types.TaskLog + var runnerID sql.NullInt64 + var stepName sql.NullString + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( + `SELECT id, task_id, runner_id, log_level, message, step_name, created_at + FROM task_logs WHERE task_id = ? AND message = ? 
ORDER BY id DESC LIMIT 1`, + taskID, logEntry.Message, + ).Scan(&taskLog.ID, &taskLog.TaskID, &runnerID, &taskLog.LogLevel, &taskLog.Message, &stepName, &taskLog.CreatedAt) + }) + if err != nil { + return + } + if runnerID.Valid { + taskLog.RunnerID = &runnerID.Int64 + } + if stepName.Valid { + taskLog.StepName = stepName.String + } + + msg := map[string]interface{}{ + "type": "log", + "task_id": taskID, + "job_id": jobID, + "data": taskLog, + "timestamp": time.Now().Unix(), + } + + // Only broadcast if client is connected + if !s.isClientConnected(userID) { + if s.verboseWSLogging { + log.Printf("broadcastLogToFrontend: Client %d not connected, skipping log broadcast for task %d (job %d)", userID, taskID, jobID) + } + // Still broadcast to old WebSocket connections for backwards compatibility + } else { + // Broadcast to client WebSocket if subscribed to logs:{jobId}:{taskId} + channel := fmt.Sprintf("logs:%d:%d", jobID, taskID) + if s.verboseWSLogging { + runnerIDStr := "none" + if taskRunnerID.Valid { + runnerIDStr = fmt.Sprintf("%d", taskRunnerID.Int64) + } + log.Printf("broadcastLogToFrontend: Broadcasting log for task %d (job %d, user %d) on channel %s, log_id=%d, task_status=%s, runner_id=%s", taskID, jobID, userID, channel, taskLog.ID, taskStatus, runnerIDStr) + } + s.broadcastToClient(userID, channel, msg) + } + + // If task status is pending but logs are coming in, log a warning + // This indicates the initial assignment broadcast may have been missed or the database update failed + if taskStatus == string(types.TaskStatusPending) { + log.Printf("broadcastLogToFrontend: ERROR - Task %d has logs but status is 'pending'. This indicates the initial task assignment failed or the task_update broadcast was missed.", taskID) + } + + // Also broadcast to old WebSocket connection (for backwards compatibility during migration) + key := fmt.Sprintf("%d:%d", jobID, taskID) + s.frontendConnsMu.RLock() + conn, exists := s.frontendConns[key] + s.frontendConnsMu.RUnlock() + + if exists && conn != nil { + // Serialize writes to prevent concurrent write panics + s.frontendConnsWriteMuMu.RLock() + writeMu, hasMu := s.frontendConnsWriteMu[key] + s.frontendConnsWriteMuMu.RUnlock() + + if hasMu && writeMu != nil { + writeMu.Lock() + conn.WriteJSON(msg) + writeMu.Unlock() + } else { + // Fallback if mutex doesn't exist yet (shouldn't happen, but be safe) + conn.WriteJSON(msg) + } + } +} + +// resetRunnerTasks resets tasks assigned to a disconnected/dead runner +// In the polling model, tasks are picked up by runners when they poll +func (s *Manager) resetRunnerTasks(runnerID int64) { + log.Printf("Resetting tasks for disconnected runner %d", runnerID) + + // Find running tasks assigned to this runner (exclude completed/failed for safety) + var taskRows *sql.Rows + err := s.db.With(func(conn *sql.DB) error { + var err error + taskRows, err = conn.Query( + `SELECT id, job_id FROM tasks + WHERE runner_id = ? AND status = ? 
+ AND (completed_at IS NULL OR completed_at < datetime('now', '-30 seconds'))`, + runnerID, types.TaskStatusRunning, + ) + return err + }) + if err != nil { + log.Printf("Failed to query tasks for runner %d: %v", runnerID, err) + return + } + defer taskRows.Close() + + var tasksToReset []struct { + ID int64 + JobID int64 + } + + for taskRows.Next() { + var t struct { + ID int64 + JobID int64 + } + if err := taskRows.Scan(&t.ID, &t.JobID); err != nil { + log.Printf("Failed to scan task for runner %d: %v", runnerID, err) + continue + } + tasksToReset = append(tasksToReset, t) + } + + if len(tasksToReset) == 0 { + log.Printf("No running tasks found for runner %d to redistribute", runnerID) + return + } + + log.Printf("Redistributing %d running tasks from disconnected runner %d", len(tasksToReset), runnerID) + + // Runner disconnections always get retried - increment runner_failure_count for tracking only + // This does NOT count against the task's retry_count (which is for actual task failures like Blender crashes) + resetCount := 0 + + for _, task := range tasksToReset { + // Always reset to pending - runner failures retry indefinitely + err = s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( + `UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL, + runner_failure_count = runner_failure_count + 1, started_at = NULL WHERE id = ? AND runner_id = ?`, + types.TaskStatusPending, task.ID, runnerID, + ) + if err != nil { + return err + } + // Clear steps and logs for fresh retry + _, err = conn.Exec(`DELETE FROM task_steps WHERE task_id = ?`, task.ID) + if err != nil { + return err + } + _, err = conn.Exec(`DELETE FROM task_logs WHERE task_id = ?`, task.ID) + return err + }) + if err != nil { + log.Printf("Failed to reset task %d: %v", task.ID, err) + } else { + resetCount++ + + // Broadcast task reset to clients (includes steps_cleared and logs_cleared flags) + s.broadcastTaskUpdate(task.JobID, task.ID, "task_reset", map[string]interface{}{ + "status": types.TaskStatusPending, + "runner_id": nil, + "current_step": nil, + "started_at": nil, + "steps_cleared": true, + "logs_cleared": true, + }) + } + } + + log.Printf("Task reset complete for runner %d: %d tasks reset for retry", runnerID, resetCount) + + // Update job statuses for affected jobs + jobIDs := make(map[int64]bool) + for _, task := range tasksToReset { + jobIDs[task.JobID] = true + } + + for jobID := range jobIDs { + // Update job status based on remaining tasks + go s.updateJobStatusFromTasks(jobID) + } +} + +// logTaskEvent logs an event to a task's log (manager-side logging) +func (s *Manager) logTaskEvent(taskID int64, runnerID *int64, logLevel types.LogLevel, message, stepName string) { + var runnerIDValue interface{} + if runnerID != nil { + runnerIDValue = *runnerID + } + + err := s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( + `INSERT INTO task_logs (task_id, runner_id, log_level, message, step_name, created_at) + VALUES (?, ?, ?, ?, ?, ?)`, + taskID, runnerIDValue, logLevel, message, stepName, time.Now(), + ) + return err + }) + if err != nil { + log.Printf("Failed to log task event for task %d: %v", taskID, err) + return + } + + // Broadcast to frontend if there are connected clients + s.broadcastLogToFrontend(taskID, WSLogEntry{ + TaskID: taskID, + LogLevel: string(logLevel), + Message: message, + StepName: stepName, + }) +} + +// cleanupOldOfflineRunners periodically deletes runners that have been offline for more than 1 month +func (s *Manager) cleanupOldOfflineRunners() { + // Run 
cleanup every 24 hours + ticker := time.NewTicker(24 * time.Hour) + defer ticker.Stop() + + // Run once immediately on startup + s.cleanupOldOfflineRunnersOnce() + + for range ticker.C { + s.cleanupOldOfflineRunnersOnce() + } +} + +// cleanupOldOfflineRunnersOnce finds and deletes runners that have been offline for more than 1 month +func (s *Manager) cleanupOldOfflineRunnersOnce() { + defer func() { + if r := recover(); r != nil { + log.Printf("Panic in cleanupOldOfflineRunners: %v", r) + } + }() + + // Find runners that: + // 1. Are offline + // 2. Haven't had a heartbeat in over 1 month + var rows *sql.Rows + err := s.db.With(func(conn *sql.DB) error { + var err error + rows, err = conn.Query( + `SELECT id, name FROM runners + WHERE status = ? + AND last_heartbeat < datetime('now', '-1 month')`, + types.RunnerStatusOffline, + ) + return err + }) + if err != nil { + log.Printf("Failed to query old offline runners: %v", err) + return + } + defer rows.Close() + + type runnerInfo struct { + ID int64 + Name string + } + var runnersToDelete []runnerInfo + + for rows.Next() { + var info runnerInfo + if err := rows.Scan(&info.ID, &info.Name); err == nil { + runnersToDelete = append(runnersToDelete, info) + } + } + rows.Close() + + if len(runnersToDelete) == 0 { + return + } + + log.Printf("Cleaning up %d old offline runners (offline for more than 1 month)", len(runnersToDelete)) + + // Delete each runner + for _, runner := range runnersToDelete { + // First, check if there are any tasks still assigned to this runner + // If so, reset them to pending before deleting the runner + var assignedTaskCount int + err := s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( + `SELECT COUNT(*) FROM tasks WHERE runner_id = ? AND status IN (?, ?)`, + runner.ID, types.TaskStatusRunning, types.TaskStatusPending, + ).Scan(&assignedTaskCount) + }) + if err != nil { + log.Printf("Failed to check assigned tasks for runner %d: %v", runner.ID, err) + continue + } + + if assignedTaskCount > 0 { + // Reset any tasks assigned to this runner + log.Printf("Resetting %d tasks assigned to runner %d before deletion", assignedTaskCount, runner.ID) + err = s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( + `UPDATE tasks SET runner_id = NULL, status = ? WHERE runner_id = ? 
AND status IN (?, ?)`, + types.TaskStatusPending, runner.ID, types.TaskStatusRunning, types.TaskStatusPending, + ) + return err + }) + if err != nil { + log.Printf("Failed to reset tasks for runner %d: %v", runner.ID, err) + continue + } + } + + // Delete the runner + err = s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec("DELETE FROM runners WHERE id = ?", runner.ID) + return err + }) + if err != nil { + log.Printf("Failed to delete runner %d (%s): %v", runner.ID, runner.Name, err) + continue + } + + log.Printf("Deleted old offline runner: %d (%s)", runner.ID, runner.Name) + } +} + +// sendWebSocketMessage safely sends a message over a WebSocket connection with write locking +func (s *Manager) sendWebSocketMessage(conn *websocket.Conn, msg interface{}) error { + // For simplicity in the polling model, we'll use a global write mutex + // since we typically have one connection per job/task + s.runnerJobConnsMu.RLock() + defer s.runnerJobConnsMu.RUnlock() + + // Set write deadline + conn.SetWriteDeadline(time.Now().Add(WSWriteDeadline)) + + // Write the message directly - the RWMutex read lock provides basic synchronization + // For production, consider using a per-connection mutex pool + if err := conn.WriteJSON(msg); err != nil { + log.Printf("Failed to send WebSocket message: %v", err) + return err + } + + return nil +} diff --git a/internal/runner/api/jobconn.go b/internal/runner/api/jobconn.go new file mode 100644 index 0000000..426e18b --- /dev/null +++ b/internal/runner/api/jobconn.go @@ -0,0 +1,333 @@ +package api + +import ( + "fmt" + "log" + "strings" + "sync" + "time" + + "jiggablend/pkg/types" + + "github.com/gorilla/websocket" +) + +// JobConnection wraps a WebSocket connection for job communication. +type JobConnection struct { + conn *websocket.Conn + writeMu sync.Mutex + stopPing chan struct{} + stopHeartbeat chan struct{} + isConnected bool + connMu sync.RWMutex +} + +// NewJobConnection creates a new job connection wrapper. +func NewJobConnection() *JobConnection { + return &JobConnection{} +} + +// Connect establishes a WebSocket connection for a job (no runnerID needed). 
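+//
+// The manager expects an auth handshake as the first exchange: the runner
+// sends {"type": "auth", "job_token": ...} and must receive {"type": "auth_ok"}
+// before any log/progress/complete traffic; Connect performs that handshake.
+// A minimal sketch of a call site (jobPath and jobToken come from the
+// manager's next-job response; the URL is illustrative):
+//
+//	jc := NewJobConnection()
+//	if err := jc.Connect("http://localhost:8080", "/api/runner/jobs/42", jobToken); err != nil {
+//		return fmt.Errorf("job websocket: %w", err)
+//	}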
+func (j *JobConnection) Connect(managerURL, jobPath, jobToken string) error { + wsPath := jobPath + "/ws" + wsURL := strings.Replace(managerURL, "http://", "ws://", 1) + wsURL = strings.Replace(wsURL, "https://", "wss://", 1) + wsURL += wsPath + + log.Printf("Connecting to job WebSocket: %s", wsPath) + + dialer := websocket.Dialer{ + HandshakeTimeout: 10 * time.Second, + } + conn, _, err := dialer.Dial(wsURL, nil) + if err != nil { + return fmt.Errorf("failed to connect job WebSocket: %w", err) + } + + j.conn = conn + + // Send auth message + authMsg := map[string]interface{}{ + "type": "auth", + "job_token": jobToken, + } + if err := conn.WriteJSON(authMsg); err != nil { + conn.Close() + return fmt.Errorf("failed to send auth: %w", err) + } + + // Wait for auth_ok + conn.SetReadDeadline(time.Now().Add(30 * time.Second)) + var authResp map[string]string + if err := conn.ReadJSON(&authResp); err != nil { + conn.Close() + return fmt.Errorf("failed to read auth response: %w", err) + } + if authResp["type"] == "error" { + conn.Close() + return fmt.Errorf("auth failed: %s", authResp["message"]) + } + if authResp["type"] != "auth_ok" { + conn.Close() + return fmt.Errorf("unexpected auth response: %s", authResp["type"]) + } + + // Clear read deadline after auth + conn.SetReadDeadline(time.Time{}) + + // Set up ping/pong handler for keepalive + conn.SetPongHandler(func(string) error { + conn.SetReadDeadline(time.Now().Add(90 * time.Second)) + return nil + }) + + // Start ping goroutine + j.stopPing = make(chan struct{}) + j.connMu.Lock() + j.isConnected = true + j.connMu.Unlock() + go j.pingLoop() + + // Start WebSocket heartbeat goroutine + j.stopHeartbeat = make(chan struct{}) + go j.heartbeatLoop() + + return nil +} + +// pingLoop sends periodic pings to keep the WebSocket connection alive. +func (j *JobConnection) pingLoop() { + defer func() { + if rec := recover(); rec != nil { + log.Printf("Ping loop panicked: %v", rec) + } + }() + + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-j.stopPing: + return + case <-ticker.C: + j.writeMu.Lock() + if j.conn != nil { + deadline := time.Now().Add(10 * time.Second) + if err := j.conn.WriteControl(websocket.PingMessage, []byte{}, deadline); err != nil { + log.Printf("Failed to send ping, closing connection: %v", err) + j.connMu.Lock() + j.isConnected = false + if j.conn != nil { + j.conn.Close() + j.conn = nil + } + j.connMu.Unlock() + } + } + j.writeMu.Unlock() + } + } +} + +// Heartbeat sends a heartbeat message over WebSocket to keep runner online. +func (j *JobConnection) Heartbeat() { + if j.conn == nil { + return + } + + j.writeMu.Lock() + defer j.writeMu.Unlock() + + msg := map[string]interface{}{ + "type": "runner_heartbeat", + "timestamp": time.Now().Unix(), + } + + if err := j.conn.WriteJSON(msg); err != nil { + log.Printf("Failed to send WebSocket heartbeat: %v", err) + // Handle connection failure + j.connMu.Lock() + j.isConnected = false + if j.conn != nil { + j.conn.Close() + j.conn = nil + } + j.connMu.Unlock() + } +} + +// heartbeatLoop sends periodic heartbeat messages over WebSocket. +func (j *JobConnection) heartbeatLoop() { + defer func() { + if rec := recover(); rec != nil { + log.Printf("WebSocket heartbeat loop panicked: %v", rec) + } + }() + + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-j.stopHeartbeat: + return + case <-ticker.C: + j.Heartbeat() + } + } +} + +// Close closes the WebSocket connection. 
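+//
+// The stop channels and the connection are nilled as they are torn down, so
+// calling Close after a failed Connect, or calling it twice sequentially, is
+// a no-op rather than a double-close panic (an observation about this code,
+// not a guarantee against a concurrent Connect). Typical teardown is simply:
+//
+//	defer jc.Close()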
+func (j *JobConnection) Close() {
+	j.connMu.Lock()
+	j.isConnected = false
+	j.connMu.Unlock()
+
+	// Stop heartbeat goroutine
+	if j.stopHeartbeat != nil {
+		close(j.stopHeartbeat)
+		j.stopHeartbeat = nil
+	}
+
+	// Stop ping goroutine
+	if j.stopPing != nil {
+		close(j.stopPing)
+		j.stopPing = nil
+	}
+
+	// Take the write lock so the connection is never closed out from under
+	// an in-flight write.
+	j.writeMu.Lock()
+	if j.conn != nil {
+		j.conn.Close()
+		j.conn = nil
+	}
+	j.writeMu.Unlock()
+}
+
+// IsConnected returns true if the connection is established.
+func (j *JobConnection) IsConnected() bool {
+	j.connMu.RLock()
+	defer j.connMu.RUnlock()
+	return j.isConnected
+}
+
+// send writes a typed JSON message while holding the write lock, closing the
+// connection on write failure. All job messages share this path so the
+// nil-check and error handling live in one place.
+func (j *JobConnection) send(msgType string, data map[string]interface{}) {
+	j.writeMu.Lock()
+	defer j.writeMu.Unlock()
+
+	// Re-check under the lock: the connection may have been torn down.
+	if j.conn == nil {
+		return
+	}
+
+	msg := map[string]interface{}{
+		"type":      msgType,
+		"data":      data,
+		"timestamp": time.Now().Unix(),
+	}
+	if err := j.conn.WriteJSON(msg); err != nil {
+		log.Printf("Failed to send %s, connection may be broken: %v", msgType, err)
+		// Close the connection on write error
+		j.connMu.Lock()
+		j.isConnected = false
+		j.conn.Close()
+		j.conn = nil
+		j.connMu.Unlock()
+	}
+}
+
+// Log sends a log entry to the manager.
+func (j *JobConnection) Log(taskID int64, level types.LogLevel, message string) {
+	j.send("log_entry", map[string]interface{}{
+		"task_id":   taskID,
+		"log_level": string(level),
+		"message":   message,
+	})
+}
+
+// Progress sends a progress update to the manager.
+func (j *JobConnection) Progress(taskID int64, progress float64) {
+	j.send("progress", map[string]interface{}{
+		"task_id":  taskID,
+		"progress": progress,
+	})
+}
+
+// OutputUploaded notifies that an output file was uploaded.
+func (j *JobConnection) OutputUploaded(taskID int64, fileName string) {
+	j.send("output_uploaded", map[string]interface{}{
+		"task_id":   taskID,
+		"file_name": fileName,
+	})
+}
+
+// Complete sends task completion to the manager.
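+// Hypothetical call sites (sketch):
+//
+//	jc.Complete(taskID, true, nil)        // report success
+//	jc.Complete(taskID, false, renderErr) // report failure with detail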
+func (j *JobConnection) Complete(taskID int64, success bool, errorMsg error) {
+	// Send the error as a string: encoding/json marshals most error types to
+	// "{}" because they expose no exported fields.
+	errStr := ""
+	if errorMsg != nil {
+		errStr = errorMsg.Error()
+	}
+	j.send("task_complete", map[string]interface{}{
+		"task_id": taskID,
+		"success": success,
+		"error":   errStr,
+	})
+}
diff --git a/internal/runner/api/manager.go b/internal/runner/api/manager.go
new file mode 100644
index 0000000..9d1c437
--- /dev/null
+++ b/internal/runner/api/manager.go
@@ -0,0 +1,421 @@
+// Package api provides HTTP and WebSocket communication with the manager server.
+package api
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"io"
+	"mime/multipart"
+	"net/http"
+	"net/url"
+	"os"
+	"path/filepath"
+	"strings"
+	"time"
+
+	"jiggablend/pkg/types"
+)
+
+// ManagerClient handles all HTTP communication with the manager server.
+type ManagerClient struct {
+	baseURL    string
+	apiKey     string
+	runnerID   int64
+	httpClient *http.Client // Standard timeout for quick requests
+	longClient *http.Client // No timeout for large file transfers
+}
+
+// NewManagerClient creates a new manager client.
+func NewManagerClient(baseURL string) *ManagerClient {
+	return &ManagerClient{
+		baseURL:    strings.TrimSuffix(baseURL, "/"),
+		httpClient: &http.Client{Timeout: 30 * time.Second},
+		longClient: &http.Client{Timeout: 0}, // No timeout for large transfers
+	}
+}
+
+// SetCredentials sets the API key and runner ID after registration.
+func (m *ManagerClient) SetCredentials(runnerID int64, apiKey string) {
+	m.runnerID = runnerID
+	m.apiKey = apiKey
+}
+
+// GetRunnerID returns the registered runner ID.
+func (m *ManagerClient) GetRunnerID() int64 {
+	return m.runnerID
+}
+
+// GetAPIKey returns the API key.
+func (m *ManagerClient) GetAPIKey() string {
+	return m.apiKey
+}
+
+// GetBaseURL returns the base URL.
+func (m *ManagerClient) GetBaseURL() string {
+	return m.baseURL
+}
+
+// Request performs an authenticated HTTP request with standard timeout.
+func (m *ManagerClient) Request(method, path string, body []byte) (*http.Response, error) {
+	return m.doRequest(method, path, body, m.httpClient)
+}
+
+// RequestLong performs an authenticated HTTP request with no timeout.
+// Use for large file uploads/downloads.
+func (m *ManagerClient) RequestLong(method, path string, body []byte) (*http.Response, error) {
+	return m.doRequest(method, path, body, m.longClient)
+}
+
+func (m *ManagerClient) doRequest(method, path string, body []byte, client *http.Client) (*http.Response, error) {
+	if m.apiKey == "" {
+		return nil, fmt.Errorf("not authenticated")
+	}
+
+	fullURL := m.baseURL + path
+	req, err := http.NewRequest(method, fullURL, bytes.NewReader(body))
+	if err != nil {
+		return nil, err
+	}
+
+	req.Header.Set("Authorization", "Bearer "+m.apiKey)
+	if len(body) > 0 {
+		req.Header.Set("Content-Type", "application/json")
+	}
+
+	return client.Do(req)
+}
+
+// RequestWithToken performs an authenticated HTTP request using a specific token.
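+// Sketch of a token-scoped call (the path and token are illustrative):
+//
+//	resp, err := m.RequestWithToken("GET", jobPath+"/files", jobToken, nil)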
+func (m *ManagerClient) RequestWithToken(method, path, token string, body []byte) (*http.Response, error) {
+	return m.doRequestWithToken(method, path, token, body, m.httpClient)
+}
+
+// RequestLongWithToken performs a long-running request with a specific token.
+func (m *ManagerClient) RequestLongWithToken(method, path, token string, body []byte) (*http.Response, error) {
+	return m.doRequestWithToken(method, path, token, body, m.longClient)
+}
+
+func (m *ManagerClient) doRequestWithToken(method, path, token string, body []byte, client *http.Client) (*http.Response, error) {
+	fullURL := m.baseURL + path
+	req, err := http.NewRequest(method, fullURL, bytes.NewReader(body))
+	if err != nil {
+		return nil, err
+	}
+
+	req.Header.Set("Authorization", "Bearer "+token)
+	if len(body) > 0 {
+		req.Header.Set("Content-Type", "application/json")
+	}
+
+	return client.Do(req)
+}
+
+// RegisterRequest is the request body for runner registration.
+type RegisterRequest struct {
+	Name         string `json:"name"`
+	Hostname     string `json:"hostname"`
+	Capabilities string `json:"capabilities"`
+	APIKey       string `json:"api_key"`
+	Fingerprint  string `json:"fingerprint,omitempty"`
+}
+
+// RegisterResponse is the response from runner registration.
+type RegisterResponse struct {
+	ID int64 `json:"id"`
+}
+
+// Register registers the runner with the manager.
+func (m *ManagerClient) Register(name, hostname string, capabilities map[string]interface{}, registrationToken, fingerprint string) (int64, error) {
+	capsJSON, err := json.Marshal(capabilities)
+	if err != nil {
+		return 0, fmt.Errorf("failed to marshal capabilities: %w", err)
+	}
+
+	reqBody := RegisterRequest{
+		Name:         name,
+		Hostname:     hostname,
+		Capabilities: string(capsJSON),
+		APIKey:       registrationToken,
+	}
+
+	// Only send fingerprint for non-fixed API keys
+	if !strings.HasPrefix(registrationToken, "jk_r0_") {
+		reqBody.Fingerprint = fingerprint
+	}
+
+	body, _ := json.Marshal(reqBody)
+	resp, err := m.httpClient.Post(
+		m.baseURL+"/api/runner/register",
+		"application/json",
+		bytes.NewReader(body),
+	)
+	if err != nil {
+		return 0, fmt.Errorf("connection error: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusCreated {
+		bodyBytes, _ := io.ReadAll(resp.Body)
+		errorBody := string(bodyBytes)
+
+		// Check for token-related errors (should not retry)
+		if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusBadRequest {
+			errorLower := strings.ToLower(errorBody)
+			if strings.Contains(errorLower, "invalid") ||
+				strings.Contains(errorLower, "expired") ||
+				strings.Contains(errorLower, "already used") ||
+				strings.Contains(errorLower, "token") {
+				return 0, fmt.Errorf("token error: %s", errorBody)
+			}
+		}
+
+		return 0, fmt.Errorf("registration failed (status %d): %s", resp.StatusCode, errorBody)
+	}
+
+	var result RegisterResponse
+	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+		return 0, fmt.Errorf("failed to decode response: %w", err)
+	}
+
+	m.runnerID = result.ID
+	m.apiKey = registrationToken
+
+	return result.ID, nil
+}
+
+// NextJobResponse represents the response from the next-job endpoint.
+type NextJobResponse struct {
+	JobToken string          `json:"job_token"`
+	JobPath  string          `json:"job_path"`
+	Task     NextJobTaskInfo `json:"task"`
+}
+
+// NextJobTaskInfo contains task information from the next-job response.
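+// An illustrative payload with made-up values (the shape follows the JSON
+// tags below):
+//
+//	{"task_id": 7, "job_id": 42, "job_name": "shot_010", "frame": 101, "task_type": "render"}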
+type NextJobTaskInfo struct {
+	TaskID   int64                `json:"task_id"`
+	JobID    int64                `json:"job_id"`
+	JobName  string               `json:"job_name"`
+	Frame    int                  `json:"frame"`
+	TaskType string               `json:"task_type"`
+	Metadata *types.BlendMetadata `json:"metadata,omitempty"`
+}
+
+// PollNextJob polls the manager for the next available job.
+// Returns nil, nil if no job is available.
+func (m *ManagerClient) PollNextJob() (*NextJobResponse, error) {
+	if m.runnerID == 0 || m.apiKey == "" {
+		return nil, fmt.Errorf("runner not authenticated")
+	}
+
+	path := fmt.Sprintf("/api/runner/workers/%d/next-job", m.runnerID)
+	resp, err := m.Request("GET", path, nil)
+	if err != nil {
+		return nil, fmt.Errorf("failed to poll for job: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode == http.StatusNoContent {
+		return nil, nil // No job available
+	}
+
+	if resp.StatusCode != http.StatusOK {
+		body, _ := io.ReadAll(resp.Body)
+		return nil, fmt.Errorf("unexpected status %d: %s", resp.StatusCode, string(body))
+	}
+
+	var job NextJobResponse
+	if err := json.NewDecoder(resp.Body).Decode(&job); err != nil {
+		return nil, fmt.Errorf("failed to decode job response: %w", err)
+	}
+
+	return &job, nil
+}
+
+// DownloadContext downloads the job context tar file.
+func (m *ManagerClient) DownloadContext(contextPath, jobToken string) (io.ReadCloser, error) {
+	resp, err := m.RequestLongWithToken("GET", contextPath, jobToken, nil)
+	if err != nil {
+		return nil, fmt.Errorf("failed to download context: %w", err)
+	}
+
+	if resp.StatusCode != http.StatusOK {
+		body, _ := io.ReadAll(resp.Body)
+		resp.Body.Close()
+		return nil, fmt.Errorf("context download failed with status %d: %s", resp.StatusCode, string(body))
+	}
+
+	return resp.Body, nil
+}
+
+// UploadFile uploads a file to the manager.
+func (m *ManagerClient) UploadFile(uploadPath, jobToken, filePath string) error {
+	file, err := os.Open(filePath)
+	if err != nil {
+		return fmt.Errorf("failed to open file: %w", err)
+	}
+	defer file.Close()
+
+	// Create multipart form
+	body := &bytes.Buffer{}
+	writer := multipart.NewWriter(body)
+	part, err := writer.CreateFormFile("file", filepath.Base(filePath))
+	if err != nil {
+		return fmt.Errorf("failed to create form file: %w", err)
+	}
+	if _, err := io.Copy(part, file); err != nil {
+		return fmt.Errorf("failed to copy file to form: %w", err)
+	}
+	writer.Close()
+
+	fullURL := m.baseURL + uploadPath
+	req, err := http.NewRequest("POST", fullURL, body)
+	if err != nil {
+		return err
+	}
+	req.Header.Set("Authorization", "Bearer "+jobToken)
+	req.Header.Set("Content-Type", writer.FormDataContentType())
+
+	resp, err := m.longClient.Do(req)
+	if err != nil {
+		return fmt.Errorf("failed to upload file: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
+		respBody, _ := io.ReadAll(resp.Body)
+		return fmt.Errorf("upload failed with status %d: %s", resp.StatusCode, string(respBody))
+	}
+
+	return nil
+}
+
+// GetJobMetadata retrieves job metadata from the manager.
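+// A 404 from the manager is reported as (nil, nil) rather than an error, so
+// callers should handle the no-metadata case explicitly (sketch):
+//
+//	md, err := m.GetJobMetadata(jobID)
+//	if err != nil {
+//		return err
+//	}
+//	if md == nil {
+//		// fall back to blend-file defaults
+//	}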
+func (m *ManagerClient) GetJobMetadata(jobID int64) (*types.BlendMetadata, error) {
+	path := fmt.Sprintf("/api/runner/jobs/%d/metadata?runner_id=%d", jobID, m.runnerID)
+	resp, err := m.Request("GET", path, nil)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode == http.StatusNotFound {
+		return nil, nil // No metadata found
+	}
+	if resp.StatusCode != http.StatusOK {
+		body, _ := io.ReadAll(resp.Body)
+		return nil, fmt.Errorf("failed to get job metadata: %s", string(body))
+	}
+
+	var metadata types.BlendMetadata
+	if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
+		return nil, err
+	}
+
+	return &metadata, nil
+}
+
+// JobFile represents a file associated with a job.
+type JobFile struct {
+	ID       int64  `json:"id"`
+	JobID    int64  `json:"job_id"`
+	FileType string `json:"file_type"`
+	FilePath string `json:"file_path"`
+	FileName string `json:"file_name"`
+	FileSize int64  `json:"file_size"`
+}
+
+// GetJobFiles retrieves the list of files for a job.
+func (m *ManagerClient) GetJobFiles(jobID int64) ([]JobFile, error) {
+	path := fmt.Sprintf("/api/runner/jobs/%d/files?runner_id=%d", jobID, m.runnerID)
+	resp, err := m.Request("GET", path, nil)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		body, _ := io.ReadAll(resp.Body)
+		return nil, fmt.Errorf("failed to get job files: %s", string(body))
+	}
+
+	var files []JobFile
+	if err := json.NewDecoder(resp.Body).Decode(&files); err != nil {
+		return nil, err
+	}
+
+	return files, nil
+}
+
+// DownloadFrame downloads a frame file from the manager.
+func (m *ManagerClient) DownloadFrame(jobID int64, fileName, destPath string) error {
+	encodedFileName := url.PathEscape(fileName)
+	path := fmt.Sprintf("/api/runner/files/%d/%s?runner_id=%d", jobID, encodedFileName, m.runnerID)
+	resp, err := m.RequestLong("GET", path, nil)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		body, _ := io.ReadAll(resp.Body)
+		return fmt.Errorf("download failed: %s", string(body))
+	}
+
+	file, err := os.Create(destPath)
+	if err != nil {
+		return err
+	}
+	if _, err := io.Copy(file, resp.Body); err != nil {
+		file.Close()
+		return err
+	}
+	// Surface close errors: on a written file they can indicate a failed
+	// flush to disk.
+	return file.Close()
+}
+
+// SubmitMetadata submits extracted metadata to the manager.
+func (m *ManagerClient) SubmitMetadata(jobID int64, metadata types.BlendMetadata) error {
+	metadataJSON, err := json.Marshal(metadata)
+	if err != nil {
+		return fmt.Errorf("failed to marshal metadata: %w", err)
+	}
+
+	// Request already sets the Authorization and Content-Type headers, so no
+	// need to build the request by hand here.
+	path := fmt.Sprintf("/api/runner/jobs/%d/metadata?runner_id=%d", jobID, m.runnerID)
+	resp, err := m.Request("POST", path, metadataJSON)
+	if err != nil {
+		return fmt.Errorf("failed to submit metadata: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		body, _ := io.ReadAll(resp.Body)
+		return fmt.Errorf("metadata submission failed: %s", string(body))
+	}
+
+	return nil
+}
+
+// DownloadBlender downloads a Blender version from the manager.
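+// The caller owns the returned body and must close it (sketch; the version
+// string is illustrative):
+//
+//	rc, err := m.DownloadBlender("4.2")
+//	if err != nil {
+//		return err
+//	}
+//	defer rc.Close()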
+func (m *ManagerClient) DownloadBlender(version string) (io.ReadCloser, error) {
+	path := fmt.Sprintf("/api/runner/blender/download?version=%s&runner_id=%d", version, m.runnerID)
+	resp, err := m.RequestLong("GET", path, nil)
+	if err != nil {
+		return nil, fmt.Errorf("failed to download blender from manager: %w", err)
+	}
+
+	if resp.StatusCode != http.StatusOK {
+		body, _ := io.ReadAll(resp.Body)
+		resp.Body.Close()
+		return nil, fmt.Errorf("failed to download blender: status %d, body: %s", resp.StatusCode, string(body))
+	}
+
+	return resp.Body, nil
+}
diff --git a/internal/runner/blender/binary.go b/internal/runner/blender/binary.go
new file mode 100644
index 0000000..d00c958
--- /dev/null
+++ b/internal/runner/blender/binary.go
@@ -0,0 +1,87 @@
+// Package blender handles Blender binary management and execution.
+package blender
+
+import (
+	"fmt"
+	"log"
+	"os"
+	"path/filepath"
+
+	"jiggablend/internal/runner/api"
+	"jiggablend/internal/runner/workspace"
+)
+
+// Manager handles Blender binary downloads and management.
+type Manager struct {
+	manager      *api.ManagerClient
+	workspaceDir string
+}
+
+// NewManager creates a new Blender manager.
+func NewManager(managerClient *api.ManagerClient, workspaceDir string) *Manager {
+	return &Manager{
+		manager:      managerClient,
+		workspaceDir: workspaceDir,
+	}
+}
+
+// GetBinaryPath returns the path to the Blender binary for a specific version.
+// Downloads from manager and extracts if not already present.
+func (m *Manager) GetBinaryPath(version string) (string, error) {
+	blenderDir := filepath.Join(m.workspaceDir, "blender-versions")
+	if err := os.MkdirAll(blenderDir, 0755); err != nil {
+		return "", fmt.Errorf("failed to create blender directory: %w", err)
+	}
+
+	// Check if already installed - look for version folder first
+	versionDir := filepath.Join(blenderDir, version)
+	binaryPath := filepath.Join(versionDir, "blender")
+
+	// Check if version folder exists and contains the binary
+	if versionInfo, err := os.Stat(versionDir); err == nil && versionInfo.IsDir() {
+		// Version folder exists, check if binary is present
+		if binaryInfo, err := os.Stat(binaryPath); err == nil {
+			// Verify it's actually a file (not a directory)
+			if !binaryInfo.IsDir() {
+				log.Printf("Found existing Blender %s installation at %s", version, binaryPath)
+				return binaryPath, nil
+			}
+		}
+		// Version folder exists but binary is missing - likely an incomplete
+		// installation. Remove it so stale partial files cannot survive the
+		// re-extraction below.
+		log.Printf("Version folder %s exists but binary not found, will re-download", versionDir)
+		if err := os.RemoveAll(versionDir); err != nil {
+			return "", fmt.Errorf("failed to remove incomplete installation: %w", err)
+		}
+	}
+
+	// Download from manager
+	log.Printf("Downloading Blender %s from manager", version)
+
+	reader, err := m.manager.DownloadBlender(version)
+	if err != nil {
+		return "", err
+	}
+	defer reader.Close()
+
+	// Manager serves pre-decompressed .tar files - extract directly
+	log.Printf("Extracting Blender %s...", version)
+	if err := workspace.ExtractTarStripPrefix(reader, versionDir); err != nil {
+		return "", fmt.Errorf("failed to extract blender: %w", err)
+	}
+
+	// Verify binary exists
+	if _, err := os.Stat(binaryPath); err != nil {
+		return "", fmt.Errorf("blender binary not found after extraction")
+	}
+
+	log.Printf("Blender %s installed at %s", version, binaryPath)
+	return binaryPath, nil
+}
+
+// GetBinaryForJob returns the Blender binary path for a job.
+// Uses the version from metadata or falls back to system blender.
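+// Sketch (version strings are illustrative):
+//
+//	bin, err := m.GetBinaryForJob("")    // "" -> system "blender" on PATH
+//	bin, err = m.GetBinaryForJob("4.2")  // downloads and caches 4.2 if needed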
+func (m *Manager) GetBinaryForJob(version string) (string, error) {
+	if version == "" {
+		return "blender", nil // System blender
+	}
+
+	return m.GetBinaryPath(version)
+}
+
diff --git a/internal/runner/blender/logfilter.go b/internal/runner/blender/logfilter.go
new file mode 100644
index 0000000..2347c2f
--- /dev/null
+++ b/internal/runner/blender/logfilter.go
@@ -0,0 +1,100 @@
+package blender
+
+import (
+	"regexp"
+	"strings"
+
+	"jiggablend/pkg/types"
+)
+
+// Compile the trace-line patterns once at package init instead of on every
+// FilterLog call.
+var (
+	traceSeparatorRe = regexp.MustCompile(`^[-]+\s+[-]+\s+[-]+$`)
+	traceDepthLineRe = regexp.MustCompile(`^\d+\s+\w+\s+\w+`)
+)
+
+// FilterLog checks if a Blender log line should be filtered.
+// Returns (shouldFilter, logLevel) - if shouldFilter is true, the log should
+// be skipped. The level is currently always LogLevelInfo; the second return
+// value exists so callers can downgrade noisy lines later without an API change.
+func FilterLog(line string) (shouldFilter bool, logLevel types.LogLevel) {
+	trimmed := strings.TrimSpace(line)
+
+	// Filter out empty lines
+	if trimmed == "" {
+		return true, types.LogLevelInfo
+	}
+
+	// Filter out trace headers and separator lines. Since trimmed is a
+	// substring of line, a single case-insensitive check on the original
+	// line covers all the "Depth Type Name" variants.
+	upperOriginal := strings.ToUpper(line)
+	if trimmed == "Trace:" ||
+		(strings.Contains(upperOriginal, "DEPTH") && strings.Contains(upperOriginal, "TYPE") && strings.Contains(upperOriginal, "NAME")) ||
+		strings.Contains(line, "----- ---- ----") ||
+		strings.HasPrefix(trimmed, "-----") ||
+		traceSeparatorRe.MatchString(trimmed) {
+		return true, types.LogLevelInfo
+	}
+
+	// Completely filter out dependency graph messages (they're just noise)
+	dependencyGraphPatterns := []string{
+		"Failed to add relation",
+		"Could not find op_from",
+		"OperationKey",
+		"find_node_operation: Failed for",
+		"BONE_DONE",
+		"component name:",
+		"operation code:",
+		"rope_ctrl_rot_",
+	}
+
+	for _, pattern := range dependencyGraphPatterns {
+		if strings.Contains(line, pattern) {
+			return true, types.LogLevelInfo
+		}
+	}
+
+	// Filter out animation system warnings (invalid drivers are common and harmless)
+	animationSystemPatterns := []string{
+		"BKE_animsys_eval_driver: invalid driver",
+		"bke.anim_sys",
+		"rotation_quaternion[",
+		"constraints[",
+		".influence[0]",
+		"pose.bones[",
+	}
+
+	for _, pattern := range animationSystemPatterns {
+		if strings.Contains(line, pattern) {
+			return true, types.LogLevelInfo
+		}
+	}
+
+	// Filter out modifier warnings (common when vertices change)
+	modifierPatterns := []string{
+		"BKE_modifier_set_error",
+		"bke.modifier",
+		"Vertices changed from",
+		"Modifier:",
+	}
+
+	for _, pattern := range modifierPatterns {
+		if strings.Contains(line, pattern) {
+			return true, types.LogLevelInfo
+		}
+	}
+
+	// Filter out trace depth indicator lines
+	// Pattern: number, word, word (e.g., "1 Object timer_box_franck")
+	if traceDepthLineRe.MatchString(trimmed) {
+		return true, types.LogLevelInfo
+	}
+
+	return false, types.LogLevelInfo
+}
+
diff --git a/internal/runner/blender/version.go b/internal/runner/blender/version.go
new file mode 100644
index 0000000..f702b8f
--- /dev/null
+++ b/internal/runner/blender/version.go
@@ -0,0 +1,143 @@
+package blender
+
+import (
"compress/gzip" + "fmt" + "io" + "os" + "os/exec" +) + +// ParseVersionFromFile parses the Blender version that a .blend file was saved with. +// Returns major and minor version numbers. +func ParseVersionFromFile(blendPath string) (major, minor int, err error) { + file, err := os.Open(blendPath) + if err != nil { + return 0, 0, fmt.Errorf("failed to open blend file: %w", err) + } + defer file.Close() + + // Read the first 12 bytes of the blend file header + // Format: BLENDER-v or BLENDER_v + // The header is: "BLENDER" (7 bytes) + pointer size (1 byte: '-' for 64-bit, '_' for 32-bit) + // + endianness (1 byte: 'v' for little-endian, 'V' for big-endian) + // + version (3 bytes: e.g., "402" for 4.02) + header := make([]byte, 12) + n, err := file.Read(header) + if err != nil || n < 12 { + return 0, 0, fmt.Errorf("failed to read blend file header: %w", err) + } + + // Check for BLENDER magic + if string(header[:7]) != "BLENDER" { + // Might be compressed - try to decompress + file.Seek(0, 0) + return parseCompressedVersion(file) + } + + // Parse version from bytes 9-11 (3 digits) + versionStr := string(header[9:12]) + + // Version format changed in Blender 3.0 + // Pre-3.0: "279" = 2.79, "280" = 2.80 + // 3.0+: "300" = 3.0, "402" = 4.02, "410" = 4.10 + if len(versionStr) == 3 { + // First digit is major version + fmt.Sscanf(string(versionStr[0]), "%d", &major) + // Next two digits are minor version + fmt.Sscanf(versionStr[1:3], "%d", &minor) + } + + return major, minor, nil +} + +// parseCompressedVersion handles gzip and zstd compressed blend files. +func parseCompressedVersion(file *os.File) (major, minor int, err error) { + magic := make([]byte, 4) + if _, err := file.Read(magic); err != nil { + return 0, 0, err + } + file.Seek(0, 0) + + if magic[0] == 0x1f && magic[1] == 0x8b { + // gzip compressed + gzReader, err := gzip.NewReader(file) + if err != nil { + return 0, 0, fmt.Errorf("failed to create gzip reader: %w", err) + } + defer gzReader.Close() + + header := make([]byte, 12) + n, err := gzReader.Read(header) + if err != nil || n < 12 { + return 0, 0, fmt.Errorf("failed to read compressed blend header: %w", err) + } + + if string(header[:7]) != "BLENDER" { + return 0, 0, fmt.Errorf("invalid blend file format") + } + + versionStr := string(header[9:12]) + if len(versionStr) == 3 { + fmt.Sscanf(string(versionStr[0]), "%d", &major) + fmt.Sscanf(versionStr[1:3], "%d", &minor) + } + + return major, minor, nil + } + + // Check for zstd magic (Blender 3.0+): 0x28 0xB5 0x2F 0xFD + if magic[0] == 0x28 && magic[1] == 0xb5 && magic[2] == 0x2f && magic[3] == 0xfd { + return parseZstdVersion(file) + } + + return 0, 0, fmt.Errorf("unknown blend file format") +} + +// parseZstdVersion handles zstd-compressed blend files (Blender 3.0+). +// Uses zstd command line tool since Go doesn't have native zstd support. 
+func parseZstdVersion(file *os.File) (major, minor int, err error) { + file.Seek(0, 0) + + cmd := exec.Command("zstd", "-d", "-c") + cmd.Stdin = file + + stdout, err := cmd.StdoutPipe() + if err != nil { + return 0, 0, fmt.Errorf("failed to create zstd stdout pipe: %w", err) + } + + if err := cmd.Start(); err != nil { + return 0, 0, fmt.Errorf("failed to start zstd decompression: %w", err) + } + + // Read just the header (12 bytes) + header := make([]byte, 12) + n, readErr := io.ReadFull(stdout, header) + + // Kill the process early - we only need the header + cmd.Process.Kill() + cmd.Wait() + + if readErr != nil || n < 12 { + return 0, 0, fmt.Errorf("failed to read zstd compressed blend header: %v", readErr) + } + + if string(header[:7]) != "BLENDER" { + return 0, 0, fmt.Errorf("invalid blend file format in zstd archive") + } + + versionStr := string(header[9:12]) + if len(versionStr) == 3 { + fmt.Sscanf(string(versionStr[0]), "%d", &major) + fmt.Sscanf(versionStr[1:3], "%d", &minor) + } + + return major, minor, nil +} + +// VersionString returns a formatted version string like "4.2". +func VersionString(major, minor int) string { + return fmt.Sprintf("%d.%d", major, minor) +} + diff --git a/internal/runner/client.go b/internal/runner/client.go deleted file mode 100644 index eae4160..0000000 --- a/internal/runner/client.go +++ /dev/null @@ -1,3726 +0,0 @@ -package runner - -import ( - "archive/tar" - "bufio" - "bytes" - "crypto/sha256" - _ "embed" - "encoding/hex" - "encoding/json" - "errors" - "fmt" - "io" - "log" - "mime/multipart" - "net" - "net/http" - "net/url" - "os" - "os/exec" - "path/filepath" - "regexp" - "sort" - "strings" - "sync" - "time" - - "jiggablend/pkg/executils" - "jiggablend/pkg/scripts" - "jiggablend/pkg/types" - - "github.com/gorilla/websocket" -) - -// Client represents a runner client -type Client struct { - managerURL string - name string - hostname string - httpClient *http.Client - runnerID int64 - apiKey string // API key for authentication - wsConn *websocket.Conn - wsConnMu sync.RWMutex - wsWriteMu sync.Mutex // Protects concurrent writes to WebSocket (WebSocket is not thread-safe) - stopChan chan struct{} - stepStartTimes map[string]time.Time // key: "taskID:stepName" - stepTimesMu sync.RWMutex - workspaceDir string // Persistent workspace directory for this runner - processTracker *executils.ProcessTracker // Tracks running processes for cleanup - capabilities map[string]interface{} // Cached capabilities from initial probe (includes bools and numbers) - capabilitiesMu sync.RWMutex // Protects capabilities - hwAccelCache map[string]bool // Cached hardware acceleration detection results - hwAccelCacheMu sync.RWMutex // Protects hwAccelCache - vaapiDevices []string // Cached VAAPI device paths (all available devices) - vaapiDevicesMu sync.RWMutex // Protects vaapiDevices - allocatedDevices map[int64]string // map[taskID]device - tracks which device is allocated to which task - allocatedDevicesMu sync.RWMutex // Protects allocatedDevices - longRunningClient *http.Client // HTTP client for long-running operations (no timeout) - fingerprint string // Unique hardware fingerprint for this runner - fingerprintMu sync.RWMutex // Protects fingerprint -} - -// NewClient creates a new runner client -func NewClient(managerURL, name, hostname string) *Client { - client := &Client{ - managerURL: managerURL, - name: name, - hostname: hostname, - httpClient: &http.Client{Timeout: 30 * time.Second}, - longRunningClient: &http.Client{Timeout: 0}, // No timeout for 
long-running operations (context downloads, file uploads/downloads) - stopChan: make(chan struct{}), - stepStartTimes: make(map[string]time.Time), - processTracker: executils.NewProcessTracker(), - } - // Generate fingerprint immediately - client.generateFingerprint() - return client -} - -// generateFingerprint creates a unique hardware fingerprint for this runner -// This fingerprint should be stable across restarts but unique per physical/virtual machine -func (c *Client) generateFingerprint() { - c.fingerprintMu.Lock() - defer c.fingerprintMu.Unlock() - - // Use a combination of stable hardware identifiers - var components []string - - // Add hostname (stable on most systems) - components = append(components, c.hostname) - - // Try to get machine ID from /etc/machine-id (Linux) - if machineID, err := os.ReadFile("/etc/machine-id"); err == nil { - components = append(components, strings.TrimSpace(string(machineID))) - } - - // Try to get product UUID from /sys/class/dmi/id/product_uuid (Linux) - if productUUID, err := os.ReadFile("/sys/class/dmi/id/product_uuid"); err == nil { - components = append(components, strings.TrimSpace(string(productUUID))) - } - - // Try to get MAC address of first network interface (cross-platform) - if macAddr, err := c.getMACAddress(); err == nil { - components = append(components, macAddr) - } - - // If no stable identifiers found, fall back to hostname + process ID + timestamp - // This is less ideal but ensures uniqueness - if len(components) <= 1 { - components = append(components, fmt.Sprintf("%d", os.Getpid())) - components = append(components, fmt.Sprintf("%d", time.Now().Unix())) - } - - // Create fingerprint by hashing the components - h := sha256.New() - for _, comp := range components { - h.Write([]byte(comp)) - h.Write([]byte{0}) // separator - } - - c.fingerprint = hex.EncodeToString(h.Sum(nil)) -} - -// getMACAddress returns the MAC address of the first non-loopback network interface -func (c *Client) getMACAddress() (string, error) { - interfaces, err := net.Interfaces() - if err != nil { - return "", err - } - - for _, iface := range interfaces { - // Skip loopback and down interfaces - if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 { - continue - } - // Skip interfaces without hardware address - if iface.HardwareAddr == nil || len(iface.HardwareAddr) == 0 { - continue - } - return iface.HardwareAddr.String(), nil - } - - return "", fmt.Errorf("no suitable network interface found") -} - -// GetFingerprint returns the runner's hardware fingerprint -func (c *Client) GetFingerprint() string { - c.fingerprintMu.RLock() - defer c.fingerprintMu.RUnlock() - return c.fingerprint -} - -// SetAPIKey sets the runner ID and API key -func (c *Client) SetAPIKey(runnerID int64, apiKey string) { - c.runnerID = runnerID - c.apiKey = apiKey - - // Initialize runner workspace directory if not already initialized - if c.workspaceDir == "" { - c.initWorkspace() - } -} - -// initWorkspace creates the persistent workspace directory for this runner -func (c *Client) initWorkspace() { - // Use runner name if available, otherwise use runner ID - workspaceName := c.name - if workspaceName == "" { - workspaceName = fmt.Sprintf("runner-%d", c.runnerID) - } - // Sanitize workspace name (remove invalid characters) - workspaceName = strings.ReplaceAll(workspaceName, " ", "_") - workspaceName = strings.ReplaceAll(workspaceName, "/", "_") - workspaceName = strings.ReplaceAll(workspaceName, "\\", "_") - workspaceName = strings.ReplaceAll(workspaceName, 
":", "_") - - // Create workspace in a jiggablend directory under temp or current directory - baseDir := os.TempDir() - if cwd, err := os.Getwd(); err == nil { - // Prefer current directory if writable - baseDir = cwd - } - - c.workspaceDir = filepath.Join(baseDir, "jiggablend-workspaces", workspaceName) - if err := os.MkdirAll(c.workspaceDir, 0755); err != nil { - log.Printf("Warning: Failed to create workspace directory %s: %v", c.workspaceDir, err) - // Fallback to temp directory - c.workspaceDir = filepath.Join(os.TempDir(), "jiggablend-workspaces", workspaceName) - if err := os.MkdirAll(c.workspaceDir, 0755); err != nil { - log.Printf("Error: Failed to create fallback workspace directory: %v", err) - // Last resort: use temp directory with runner ID - c.workspaceDir = filepath.Join(os.TempDir(), fmt.Sprintf("jiggablend-runner-%d", c.runnerID)) - os.MkdirAll(c.workspaceDir, 0755) - } - } - log.Printf("Runner workspace initialized at: %s", c.workspaceDir) -} - -// getWorkspaceDir returns the workspace directory, initializing it if needed -func (c *Client) getWorkspaceDir() string { - if c.workspaceDir == "" { - c.initWorkspace() - } - return c.workspaceDir -} - -// probeCapabilities checks what capabilities the runner has by probing for blender and ffmpeg -// Returns a map that includes both boolean capabilities and numeric values (like GPU count) -func (c *Client) probeCapabilities() map[string]interface{} { - capabilities := make(map[string]interface{}) - - // Check for blender - blenderCmd := exec.Command("blender", "--version") - if err := blenderCmd.Run(); err == nil { - capabilities["blender"] = true - } else { - capabilities["blender"] = false - } - - // Check for ffmpeg - ffmpegCmd := exec.Command("ffmpeg", "-version") - if err := ffmpegCmd.Run(); err == nil { - capabilities["ffmpeg"] = true - - // Immediately probe GPU capabilities when ffmpeg is detected - log.Printf("FFmpeg detected, probing GPU hardware acceleration capabilities...") - c.probeGPUCapabilities(capabilities) - } else { - capabilities["ffmpeg"] = false - } - - return capabilities -} - -// probeGPUCapabilities probes GPU hardware acceleration capabilities for ffmpeg -// This is called immediately after detecting ffmpeg during initial capability probe -func (c *Client) probeGPUCapabilities(capabilities map[string]interface{}) { - // First, probe all available hardware acceleration methods - log.Printf("Probing all hardware acceleration methods...") - hwaccels := c.probeAllHardwareAccelerators() - if len(hwaccels) > 0 { - log.Printf("Available hardware acceleration methods: %v", getKeys(hwaccels)) - } else { - log.Printf("No hardware acceleration methods found") - } - - // Probe all hardware encoders - log.Printf("Probing all hardware encoders...") - hwEncoders := c.probeAllHardwareEncoders() - if len(hwEncoders) > 0 { - log.Printf("Available hardware encoders: %v", getKeys(hwEncoders)) - } - - // Check for other hardware encoders (for completeness) - log.Printf("Checking for other hardware encoders...") - if c.checkEncoderAvailable("h264_qsv") { - capabilities["qsv"] = true - capabilities["qsv_gpu_count"] = 1 - log.Printf("Intel Quick Sync (QSV) detected") - } else { - capabilities["qsv"] = false - capabilities["qsv_gpu_count"] = 0 - } - - if c.checkEncoderAvailable("h264_videotoolbox") { - capabilities["videotoolbox"] = true - capabilities["videotoolbox_gpu_count"] = 1 - log.Printf("VideoToolbox (macOS) detected") - } else { - capabilities["videotoolbox"] = false - capabilities["videotoolbox_gpu_count"] = 0 - } - 
- if c.checkEncoderAvailable("h264_amf") { - capabilities["amf"] = true - capabilities["amf_gpu_count"] = 1 - log.Printf("AMD AMF detected") - } else { - capabilities["amf"] = false - capabilities["amf_gpu_count"] = 0 - } - - // Check for V4L2M2M (Video4Linux2) - if c.checkEncoderAvailable("h264_v4l2m2m") { - capabilities["v4l2m2m"] = true - capabilities["v4l2m2m_gpu_count"] = 1 - log.Printf("V4L2 M2M detected") - } else { - capabilities["v4l2m2m"] = false - capabilities["v4l2m2m_gpu_count"] = 0 - } - - // Check for OpenMAX (Raspberry Pi) - if c.checkEncoderAvailable("h264_omx") { - capabilities["omx"] = true - capabilities["omx_gpu_count"] = 1 - log.Printf("OpenMAX detected") - } else { - capabilities["omx"] = false - capabilities["omx_gpu_count"] = 0 - } - - // Check for MediaCodec (Android) - if c.checkEncoderAvailable("h264_mediacodec") { - capabilities["mediacodec"] = true - capabilities["mediacodec_gpu_count"] = 1 - log.Printf("MediaCodec detected") - } else { - capabilities["mediacodec"] = false - capabilities["mediacodec_gpu_count"] = 0 - } -} - -// getKeys returns all keys from a map as a slice (helper function) -func getKeys(m map[string]bool) []string { - keys := make([]string, 0, len(m)) - for k := range m { - keys = append(keys, k) - } - return keys -} - -// ProbeCapabilities probes and caches capabilities (should be called once at startup) -func (c *Client) ProbeCapabilities() { - capabilities := c.probeCapabilities() - c.capabilitiesMu.Lock() - c.capabilities = capabilities - c.capabilitiesMu.Unlock() -} - -// GetCapabilities returns the cached capabilities -func (c *Client) GetCapabilities() map[string]interface{} { - c.capabilitiesMu.RLock() - defer c.capabilitiesMu.RUnlock() - // Return a copy to prevent external modification - result := make(map[string]interface{}) - for k, v := range c.capabilities { - result[k] = v - } - return result -} - -// Register registers the runner with the manager using a registration token -func (c *Client) Register(registrationToken string) (int64, string, string, error) { - // Use cached capabilities (should have been probed once at startup) - c.capabilitiesMu.RLock() - capabilities := c.capabilities - c.capabilitiesMu.RUnlock() - - // If capabilities weren't probed yet, probe them now (fallback) - if capabilities == nil { - capabilities = c.probeCapabilities() - c.capabilitiesMu.Lock() - c.capabilities = capabilities - c.capabilitiesMu.Unlock() - } - - capabilitiesJSON, err := json.Marshal(capabilities) - if err != nil { - return 0, "", "", fmt.Errorf("failed to marshal capabilities: %w", err) - } - - req := map[string]interface{}{ - "name": c.name, - "hostname": c.hostname, - "capabilities": string(capabilitiesJSON), - "api_key": registrationToken, // API key passed as registrationToken param for compatibility - } - - // Only send fingerprint for non-fixed API keys to avoid uniqueness conflicts - if !strings.HasPrefix(registrationToken, "jk_r0_") { // Fixed test key - req["fingerprint"] = c.GetFingerprint() - } - - body, _ := json.Marshal(req) - resp, err := c.httpClient.Post( - fmt.Sprintf("%s/api/runner/register", c.managerURL), - "application/json", - bytes.NewReader(body), - ) - if err != nil { - // Network/connection error - should retry - return 0, "", "", fmt.Errorf("connection error: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusCreated { - bodyBytes, _ := io.ReadAll(resp.Body) - errorBody := string(bodyBytes) - - // Check if it's a token-related error (should not retry) - if resp.StatusCode == 
http.StatusUnauthorized || resp.StatusCode == http.StatusBadRequest { - // Check error message for token-related issues - errorLower := strings.ToLower(errorBody) - if strings.Contains(errorLower, "invalid") || - strings.Contains(errorLower, "expired") || - strings.Contains(errorLower, "already used") || - strings.Contains(errorLower, "token") { - return 0, "", "", fmt.Errorf("token error: %s", errorBody) - } - } - - // Other errors (like 500) might be retryable - return 0, "", "", fmt.Errorf("registration failed (status %d): %s", resp.StatusCode, errorBody) - } - - var result struct { - ID int64 `json:"id"` - } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return 0, "", "", fmt.Errorf("failed to decode response: %w", err) - } - - c.runnerID = result.ID - c.apiKey = registrationToken // Store the API key for future use - - return result.ID, registrationToken, "", nil // Return API key as "runner secret" for compatibility -} - -// doSignedRequest performs an authenticated HTTP request using shared secret -// queryParams is optional and will be appended to the URL -func (c *Client) doSignedRequest(method, path string, body []byte, queryParams ...string) (*http.Response, error) { - return c.doSignedRequestWithClient(method, path, body, c.httpClient, queryParams...) -} - -// doSignedRequestLong performs an authenticated HTTP request using the long-running client (no timeout) -// Use this for context downloads, file uploads/downloads, and other operations that may take a long time -func (c *Client) doSignedRequestLong(method, path string, body []byte, queryParams ...string) (*http.Response, error) { - return c.doSignedRequestWithClient(method, path, body, c.longRunningClient, queryParams...) -} - -// doSignedRequestWithClient performs an authenticated HTTP request using the specified client -func (c *Client) doSignedRequestWithClient(method, path string, body []byte, client *http.Client, queryParams ...string) (*http.Response, error) { - if c.apiKey == "" { - return nil, fmt.Errorf("runner not authenticated") - } - - // Build URL with query params if provided - url := fmt.Sprintf("%s%s", c.managerURL, path) - if len(queryParams) > 0 { - url += "?" 
+ strings.Join(queryParams, "&") - } - - req, err := http.NewRequest(method, url, bytes.NewReader(body)) - if err != nil { - return nil, err - } - - // Add authentication - use API key in Authorization header - req.Header.Set("Authorization", "Bearer "+c.apiKey) - if len(body) > 0 { - req.Header.Set("Content-Type", "application/json") - } - - return client.Do(req) -} - -// ConnectWebSocket establishes a WebSocket connection to the manager -func (c *Client) ConnectWebSocket() error { - if c.runnerID == 0 || c.apiKey == "" { - return fmt.Errorf("runner not authenticated") - } - - // Build WebSocket URL with authentication - path := "/api/runner/ws" - - // Convert HTTP URL to WebSocket URL - wsURL := strings.Replace(c.managerURL, "http://", "ws://", 1) - wsURL = strings.Replace(wsURL, "https://", "wss://", 1) - wsURL = fmt.Sprintf("%s%s?runner_id=%d&api_key=%s", - wsURL, path, c.runnerID, url.QueryEscape(c.apiKey)) - - // Parse URL - u, err := url.Parse(wsURL) - if err != nil { - return fmt.Errorf("invalid WebSocket URL: %w", err) - } - - // Connect - dialer := websocket.Dialer{ - HandshakeTimeout: 10 * time.Second, - } - conn, _, err := dialer.Dial(u.String(), nil) - if err != nil { - return fmt.Errorf("failed to connect WebSocket: %w", err) - } - - c.wsConnMu.Lock() - if c.wsConn != nil { - c.wsConn.Close() - } - c.wsConn = conn - c.wsConnMu.Unlock() - - log.Printf("WebSocket connected to manager") - return nil -} - -// ConnectWebSocketWithReconnect connects with automatic reconnection -func (c *Client) ConnectWebSocketWithReconnect() { - backoff := 1 * time.Second - maxBackoff := 60 * time.Second - - for { - err := c.ConnectWebSocket() - if err == nil { - backoff = 1 * time.Second // Reset on success - c.HandleWebSocketMessages() - } else { - log.Printf("WebSocket connection failed: %v, retrying in %v", err, backoff) - time.Sleep(backoff) - backoff *= 2 - if backoff > maxBackoff { - backoff = maxBackoff - } - } - - // Check if we should stop - select { - case <-c.stopChan: - return - default: - } - } -} - -// HandleWebSocketMessages handles incoming WebSocket messages -func (c *Client) HandleWebSocketMessages() { - c.wsConnMu.Lock() - conn := c.wsConn - c.wsConnMu.Unlock() - - if conn == nil { - return - } - - // Set pong handler to respond to ping messages - // Also reset read deadline to keep connection alive - conn.SetPongHandler(func(string) error { - conn.SetReadDeadline(time.Now().Add(90 * time.Second)) // Increased to 90 seconds - return nil - }) - - // Set ping handler to respond with pong - // Also reset read deadline to keep connection alive - conn.SetPingHandler(func(string) error { - conn.SetReadDeadline(time.Now().Add(90 * time.Second)) // Increased to 90 seconds - // Respond to ping with pong - protect with write mutex - c.wsWriteMu.Lock() - defer c.wsWriteMu.Unlock() - return conn.WriteControl(websocket.PongMessage, []byte{}, time.Now().Add(10*time.Second)) - }) - - // Set read deadline to ensure we process control frames - conn.SetReadDeadline(time.Now().Add(90 * time.Second)) // Increased to 90 seconds - - // Handle messages - for { - // Reset read deadline for each message to allow ping/pong processing - conn.SetReadDeadline(time.Now().Add(90 * time.Second)) // Increased to 90 seconds - - var msg map[string]interface{} - err := conn.ReadJSON(&msg) - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { - log.Printf("WebSocket error: %v", err) - } - c.wsConnMu.Lock() - c.wsConn = nil - c.wsConnMu.Unlock() - 
return - } - - // Reset read deadline after successfully reading a message - // This ensures the connection stays alive as long as we're receiving messages - conn.SetReadDeadline(time.Now().Add(90 * time.Second)) - - msgType, _ := msg["type"].(string) - switch msgType { - case "task_assignment": - c.handleTaskAssignment(msg) - case "ping": - // Respond to ping with pong (automatic) - } - } -} - -// handleTaskAssignment handles a task assignment message -func (c *Client) handleTaskAssignment(msg map[string]interface{}) { - data, ok := msg["data"].(map[string]interface{}) - if !ok { - log.Printf("Invalid task assignment message") - return - } - - taskID, _ := data["task_id"].(float64) - jobID, _ := data["job_id"].(float64) - jobName, _ := data["job_name"].(string) - outputFormat, _ := data["output_format"].(string) - frameStart, _ := data["frame_start"].(float64) - frameEnd, _ := data["frame_end"].(float64) - taskType, _ := data["task_type"].(string) - inputFilesRaw, _ := data["input_files"].([]interface{}) - - // Log that task assignment was received - taskIDInt := int64(taskID) - c.sendLog(taskIDInt, types.LogLevelInfo, fmt.Sprintf("Task assignment received from manager (job: %d, type: %s, frames: %d-%d)", int64(jobID), taskType, int(frameStart), int(frameEnd)), "") - - // Convert to task map format - taskMap := map[string]interface{}{ - "id": taskID, - "job_id": jobID, - "frame_start": frameStart, - "frame_end": frameEnd, - } - - // Process the task based on type - go func() { - var err error - switch taskType { - case "metadata": - if len(inputFilesRaw) == 0 { - log.Printf("No input files for metadata task %v", taskID) - c.sendTaskComplete(int64(taskID), "", false, "No input files") - return - } - err = c.processMetadataTask(taskMap, int64(jobID), inputFilesRaw) - case "video_generation": - err = c.processVideoGenerationTask(taskMap, int64(jobID)) - default: - if len(inputFilesRaw) == 0 { - errMsg := fmt.Sprintf("No input files provided for task %d (job %d). 
Task assignment data: job_name=%s, output_format=%s, task_type=%s", - int64(taskID), int64(jobID), jobName, outputFormat, taskType) - log.Printf("ERROR: %s", errMsg) - c.sendLog(int64(taskID), types.LogLevelError, errMsg, "") - c.sendTaskComplete(int64(taskID), "", false, "No input files provided") - return - } - log.Printf("Processing render task %d with %d input files: %v", int64(taskID), len(inputFilesRaw), inputFilesRaw) - err = c.processTask(taskMap, jobName, outputFormat, inputFilesRaw) - } - if err != nil { - errMsg := fmt.Sprintf("Task %d failed: %v", int64(taskID), err) - log.Printf("ERROR: %s", errMsg) - c.sendLog(int64(taskID), types.LogLevelError, errMsg, "") - c.sendTaskComplete(int64(taskID), "", false, err.Error()) - } - }() -} - -// HeartbeatLoop sends periodic heartbeats via WebSocket -func (c *Client) HeartbeatLoop() { - ticker := time.NewTicker(30 * time.Second) - defer ticker.Stop() - - for range ticker.C { - c.wsConnMu.RLock() - conn := c.wsConn - c.wsConnMu.RUnlock() - - if conn != nil { - // Send heartbeat via WebSocket - protect with write mutex - c.wsWriteMu.Lock() - msg := map[string]interface{}{ - "type": "heartbeat", - "timestamp": time.Now().Unix(), - } - err := conn.WriteJSON(msg) - c.wsWriteMu.Unlock() - - if err != nil { - log.Printf("Failed to send heartbeat: %v", err) - } - } - } -} - -// shouldFilterBlenderLog checks if a Blender log line should be filtered or downgraded -// Returns (shouldFilter, logLevel) - if shouldFilter is true, the log should be skipped -func shouldFilterBlenderLog(line string) (bool, types.LogLevel) { - // Filter out common Blender dependency graph noise - trimmed := strings.TrimSpace(line) - - // Filter out empty lines - if trimmed == "" { - return true, types.LogLevelInfo - } - - // Filter out separator lines (check both original and trimmed) - if trimmed == "--------------------------------------------------------------------" || - strings.HasPrefix(trimmed, "-----") && strings.Contains(trimmed, "----") { - return true, types.LogLevelInfo - } - - // Filter out trace headers (check both original and trimmed, case-insensitive) - upperLine := strings.ToUpper(trimmed) - upperOriginal := strings.ToUpper(line) - - // Check for "Depth Type Name" - match even if words are separated by different spacing - if trimmed == "Trace:" || - trimmed == "Depth Type Name" || - trimmed == "----- ---- ----" || - line == "Depth Type Name" || - line == "----- ---- ----" || - (strings.Contains(upperLine, "DEPTH") && strings.Contains(upperLine, "TYPE") && strings.Contains(upperLine, "NAME")) || - (strings.Contains(upperOriginal, "DEPTH") && strings.Contains(upperOriginal, "TYPE") && strings.Contains(upperOriginal, "NAME")) || - strings.Contains(line, "Depth Type Name") || - strings.Contains(line, "----- ---- ----") || - strings.HasPrefix(trimmed, "-----") || - regexp.MustCompile(`^[-]+\s+[-]+\s+[-]+$`).MatchString(trimmed) { - return true, types.LogLevelInfo - } - - // Completely filter out dependency graph messages (they're just noise) - dependencyGraphPatterns := []string{ - "Failed to add relation", - "Could not find op_from", - "OperationKey", - "find_node_operation: Failed for", - "BONE_DONE", - "component name:", - "operation code:", - "rope_ctrl_rot_", - } - - for _, pattern := range dependencyGraphPatterns { - if strings.Contains(line, pattern) { - return true, types.LogLevelInfo // Completely filter out - } - } - - // Filter out animation system warnings (invalid drivers are common and harmless) - animationSystemPatterns := []string{ - 
"BKE_animsys_eval_driver: invalid driver", - "bke.anim_sys", - "rotation_quaternion[", - "constraints[", - ".influence[0]", - "pose.bones[", - } - - for _, pattern := range animationSystemPatterns { - if strings.Contains(line, pattern) { - return true, types.LogLevelInfo // Completely filter out - } - } - - // Filter out modifier warnings (common when vertices change) - modifierPatterns := []string{ - "BKE_modifier_set_error", - "bke.modifier", - "Vertices changed from", - "Modifier:", - } - - for _, pattern := range modifierPatterns { - if strings.Contains(line, pattern) { - return true, types.LogLevelInfo // Completely filter out - } - } - - // Filter out lines that are just numbers or trace depth indicators - // Pattern: number, word, word (e.g., "1 Object timer_box_franck") - if matched, _ := regexp.MatchString(`^\d+\s+\w+\s+\w+`, trimmed); matched { - return true, types.LogLevelInfo - } - - return false, types.LogLevelInfo -} - -// sendLog sends a log entry to the manager via WebSocket -func (c *Client) sendLog(taskID int64, logLevel types.LogLevel, message, stepName string) { - c.wsConnMu.RLock() - conn := c.wsConn - c.wsConnMu.RUnlock() - - if conn != nil { - // Serialize all WebSocket writes to prevent concurrent write panics - c.wsWriteMu.Lock() - defer c.wsWriteMu.Unlock() - - msg := map[string]interface{}{ - "type": "log_entry", - "data": map[string]interface{}{ - "task_id": taskID, - "log_level": string(logLevel), - "message": message, - "step_name": stepName, - }, - "timestamp": time.Now().Unix(), - } - if err := conn.WriteJSON(msg); err != nil { - log.Printf("Failed to send log: %v", err) - } - } else { - log.Printf("WebSocket not connected, cannot send log") - } -} - -// KillAllProcesses kills all running processes tracked by this client -func (c *Client) KillAllProcesses() { - log.Printf("Killing all running processes...") - killedCount := c.processTracker.KillAll() - // Release all allocated VAAPI devices - c.allocatedDevicesMu.Lock() - for taskID := range c.allocatedDevices { - delete(c.allocatedDevices, taskID) - } - c.allocatedDevicesMu.Unlock() - log.Printf("Killed %d process(es)", killedCount) -} - -// CleanupWorkspace removes the runner's workspace directory and all contents -func (c *Client) CleanupWorkspace() { - log.Printf("DEBUG: CleanupWorkspace method called") - log.Printf("CleanupWorkspace called, workspaceDir: %s", c.workspaceDir) - if c.workspaceDir != "" { - log.Printf("Cleaning up workspace directory: %s", c.workspaceDir) - if err := os.RemoveAll(c.workspaceDir); err != nil { - log.Printf("Warning: Failed to remove workspace directory %s: %v", c.workspaceDir, err) - } else { - log.Printf("Successfully removed workspace directory: %s", c.workspaceDir) - } - } - - // Also clean up any orphaned jiggablend directories that might exist - // This ensures zero persistence even if workspaceDir wasn't set - cleanupOrphanedWorkspaces() -} - -// cleanupOrphanedWorkspaces removes any jiggablend workspace directories -// that might be left behind from previous runs or crashes -func cleanupOrphanedWorkspaces() { - log.Printf("Cleaning up orphaned jiggablend workspace directories...") - - // Clean up jiggablend-workspaces directories in current and temp directories - dirsToCheck := []string{".", os.TempDir()} - for _, baseDir := range dirsToCheck { - workspaceDir := filepath.Join(baseDir, "jiggablend-workspaces") - if _, err := os.Stat(workspaceDir); err == nil { - log.Printf("Removing orphaned workspace directory: %s", workspaceDir) - if err := 
os.RemoveAll(workspaceDir); err != nil { - log.Printf("Warning: Failed to remove workspace directory %s: %v", workspaceDir, err) - } else { - log.Printf("Successfully removed workspace directory: %s", workspaceDir) - } - } - } -} - -// sendStepUpdate sends a step start/complete event to the manager -func (c *Client) sendStepUpdate(taskID int64, stepName string, status types.StepStatus, errorMsg string) { - key := fmt.Sprintf("%d:%s", taskID, stepName) - var durationMs *int - - // Track step start time - if status == types.StepStatusRunning { - c.stepTimesMu.Lock() - c.stepStartTimes[key] = time.Now() - c.stepTimesMu.Unlock() - } - - // Calculate duration if step is completing - if status == types.StepStatusCompleted || status == types.StepStatusFailed { - c.stepTimesMu.RLock() - startTime, exists := c.stepStartTimes[key] - c.stepTimesMu.RUnlock() - if exists { - duration := int(time.Since(startTime).Milliseconds()) - durationMs = &duration - c.stepTimesMu.Lock() - delete(c.stepStartTimes, key) - c.stepTimesMu.Unlock() - } - } - - // Send step update via HTTP API - reqBody := map[string]interface{}{ - "step_name": stepName, - "status": string(status), - } - if durationMs != nil { - reqBody["duration_ms"] = *durationMs - } - if errorMsg != "" { - reqBody["error_message"] = errorMsg - } - - body, _ := json.Marshal(reqBody) - // Sign with path only (without query params) to match manager verification - path := fmt.Sprintf("/api/runner/tasks/%d/steps", taskID) - resp, err := c.doSignedRequest("POST", path, body, fmt.Sprintf("runner_id=%d", c.runnerID)) - if err != nil { - log.Printf("Failed to send step update: %v", err) - // Fallback to log-based tracking - msg := fmt.Sprintf("Step %s: %s", stepName, status) - if errorMsg != "" { - msg += " - " + errorMsg - } - logLevel := types.LogLevelInfo - if status == types.StepStatusFailed { - logLevel = types.LogLevelError - } - c.sendLog(taskID, logLevel, msg, stepName) - return - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - log.Printf("Step update failed: %s", string(body)) - // Fallback to log-based tracking - msg := fmt.Sprintf("Step %s: %s", stepName, status) - if errorMsg != "" { - msg += " - " + errorMsg - } - logLevel := types.LogLevelInfo - if status == types.StepStatusFailed { - logLevel = types.LogLevelError - } - c.sendLog(taskID, logLevel, msg, stepName) - return - } - - // Also send log for debugging - msg := fmt.Sprintf("Step %s: %s", stepName, status) - if errorMsg != "" { - msg += " - " + errorMsg - } - logLevel := types.LogLevelInfo - if status == types.StepStatusFailed { - logLevel = types.LogLevelError - } - c.sendLog(taskID, logLevel, msg, stepName) -} - -// processTask processes a single task -func (c *Client) processTask(task map[string]interface{}, jobName string, outputFormat string, inputFiles []interface{}) (err error) { - _ = jobName - - taskID := int64(task["id"].(float64)) - jobID := int64(task["job_id"].(float64)) - frameStart := int(task["frame_start"].(float64)) - frameEnd := int(task["frame_end"].(float64)) - - // Create temporary job workspace within runner workspace - workDir := filepath.Join(c.getWorkspaceDir(), fmt.Sprintf("job-%d-task-%d", jobID, taskID)) - if mkdirErr := os.MkdirAll(workDir, 0755); mkdirErr != nil { - return fmt.Errorf("failed to create work directory: %w", mkdirErr) - } - - // Guaranteed cleanup even on panic - defer func() { - if cleanupErr := os.RemoveAll(workDir); cleanupErr != nil { - log.Printf("Warning: Failed to cleanup work 
-
- // processTask processes a single task
- func (c *Client) processTask(task map[string]interface{}, jobName string, outputFormat string, inputFiles []interface{}) (err error) {
- _ = jobName
-
- taskID := int64(task["id"].(float64))
- jobID := int64(task["job_id"].(float64))
- frameStart := int(task["frame_start"].(float64))
- frameEnd := int(task["frame_end"].(float64))
-
- // Create temporary job workspace within runner workspace
- workDir := filepath.Join(c.getWorkspaceDir(), fmt.Sprintf("job-%d-task-%d", jobID, taskID))
- if mkdirErr := os.MkdirAll(workDir, 0755); mkdirErr != nil {
- return fmt.Errorf("failed to create work directory: %w", mkdirErr)
- }
-
- // Guaranteed cleanup even on panic
- defer func() {
- if cleanupErr := os.RemoveAll(workDir); cleanupErr != nil {
- log.Printf("Warning: Failed to cleanup work directory %s: %v", workDir, cleanupErr)
- }
- }()
-
- // Panic recovery for this task
- defer func() {
- if r := recover(); r != nil {
- log.Printf("Task %d panicked: %v", taskID, r)
- err = fmt.Errorf("task panicked: %v", r)
- }
- }()
-
- c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Starting task: job %d, frames %d-%d, format: %s", jobID, frameStart, frameEnd, outputFormat), "")
- log.Printf("Processing task %d: job %d, frames %d-%d, format: %s (from task assignment)", taskID, jobID, frameStart, frameEnd, outputFormat)
-
- // Step: download
- c.sendStepUpdate(taskID, "download", types.StepStatusRunning, "")
- c.sendLog(taskID, types.LogLevelInfo, "Downloading job context...", "download")
-
- // Clean up expired cache entries periodically
- c.cleanupExpiredContextCache()
-
- // Download context tar
- contextPath := filepath.Join(workDir, "context.tar")
- if err := c.downloadJobContext(jobID, contextPath); err != nil {
- c.sendStepUpdate(taskID, "download", types.StepStatusFailed, err.Error())
- return fmt.Errorf("failed to download context: %w", err)
- }
-
- // Extract context tar
- c.sendLog(taskID, types.LogLevelInfo, "Extracting context...", "download")
- if err := c.extractTar(contextPath, workDir); err != nil {
- c.sendStepUpdate(taskID, "download", types.StepStatusFailed, err.Error())
- return fmt.Errorf("failed to extract context: %w", err)
- }
-
- // Find .blend file in extracted contents
- blendFile := ""
- err = filepath.Walk(workDir, func(path string, info os.FileInfo, err error) error {
- if err != nil {
- return err
- }
- if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".blend") {
- // Check it's not a Blender save file (.blend1, .blend2, etc.)
- lower := strings.ToLower(info.Name())
- idx := strings.LastIndex(lower, ".blend")
- if idx != -1 {
- suffix := lower[idx+len(".blend"):]
- // If there are digits after .blend, it's a save file
- isSaveFile := false
- if len(suffix) > 0 {
- isSaveFile = true
- for _, r := range suffix {
- if r < '0' || r > '9' {
- isSaveFile = false
- break
- }
- }
- }
- if !isSaveFile {
- blendFile = path
- return filepath.SkipAll // Stop walking once we find a blend file
- }
- }
- }
- return nil
- })
-
- if err != nil {
- c.sendStepUpdate(taskID, "download", types.StepStatusFailed, err.Error())
- return fmt.Errorf("failed to find blend file: %w", err)
- }
-
- if blendFile == "" {
- err := fmt.Errorf("no .blend file found in context - the uploaded context archive must contain at least one .blend file to render")
- c.sendStepUpdate(taskID, "download", types.StepStatusFailed, err.Error())
- return err
- }
-
- c.sendStepUpdate(taskID, "download", types.StepStatusCompleted, "")
- c.sendLog(taskID, types.LogLevelInfo, "Context downloaded and extracted successfully", "download")
-
- // Fetch job metadata to get render settings
- var jobMetadata *types.BlendMetadata
- metadata, err := c.getJobMetadata(jobID)
- if err == nil && metadata != nil {
- jobMetadata = metadata
- c.sendLog(taskID, types.LogLevelInfo, "Loaded render settings from job metadata", "render_blender")
- } else {
- c.sendLog(taskID, types.LogLevelInfo, "No render settings found in job metadata, using blend file defaults", "render_blender")
- }
-
- // Render frames
- outputDir := filepath.Join(workDir, "output")
- if err := os.MkdirAll(outputDir, 0755); err != nil {
- return fmt.Errorf("failed to create output directory: %w", err)
- }
-
- // For EXR_264_MP4 and EXR_AV1_MP4, render as EXR (OpenEXR) first for highest fidelity, then combine into video
- renderFormat := outputFormat
- if outputFormat == "EXR_264_MP4" || outputFormat == "EXR_AV1_MP4" {
- renderFormat = "EXR" // Use EXR for maximum quality (32-bit float, HDR)
- }
-
- // Step: render_blender
- c.sendStepUpdate(taskID, "render_blender", types.StepStatusRunning, "")
- if frameStart == frameEnd {
- c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Starting Blender render for frame %d...", frameStart), "render_blender")
- } else {
- c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Starting Blender render for frames %d-%d...", frameStart, frameEnd), "render_blender")
- }
-
- // Always render frames individually for precise control over file naming
- // This avoids Blender's automatic frame numbering quirks
-
- // Override output format and render settings from job submission
- // For MP4, we render as EXR (handled above) for highest fidelity, so renderFormat is already EXR
- // This script will override the blend file's settings based on job metadata
- formatFilePath := filepath.Join(workDir, "output_format.txt")
- renderSettingsFilePath := filepath.Join(workDir, "render_settings.json")
-
- // Check if unhide_objects is enabled
- unhideObjects := false
- if jobMetadata != nil && jobMetadata.UnhideObjects != nil && *jobMetadata.UnhideObjects {
- unhideObjects = true
- }
-
- // Build unhide code conditionally from embedded script
- unhideCode := ""
- if unhideObjects {
- unhideCode = scripts.UnhideObjects
- }
-
- // Load template and replace placeholders
- scriptContent := scripts.RenderBlenderTemplate
- scriptContent = strings.ReplaceAll(scriptContent, "{{UNHIDE_CODE}}", unhideCode)
- scriptContent = strings.ReplaceAll(scriptContent, "{{FORMAT_FILE_PATH}}", fmt.Sprintf("%q", formatFilePath))
- scriptContent = strings.ReplaceAll(scriptContent, "{{RENDER_SETTINGS_FILE}}", fmt.Sprintf("%q", renderSettingsFilePath))
- scriptPath := filepath.Join(workDir, "enable_gpu.py")
- if err := os.WriteFile(scriptPath, []byte(scriptContent), 0644); err != nil {
- errMsg := fmt.Sprintf("failed to create GPU enable script: %v", err)
- c.sendLog(taskID, types.LogLevelError, errMsg, "render_blender")
- c.sendStepUpdate(taskID, "render_blender", types.StepStatusFailed, errMsg)
- return errors.New(errMsg)
- }
-
- // Write output format to a temporary file for the script to read
- // (Blender's argument parsing makes it tricky to pass custom args to Python scripts)
- // IMPORTANT: Write the user's selected outputFormat, NOT renderFormat
- // renderFormat might be "EXR" for video, but we want the user's actual selection (PNG, JPEG, etc.)
- formatFile := filepath.Join(workDir, "output_format.txt")
- c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Writing output format '%s' to format file (user selected: '%s', render format: '%s')", outputFormat, outputFormat, renderFormat), "render_blender")
- if err := os.WriteFile(formatFile, []byte(outputFormat), 0644); err != nil {
- errMsg := fmt.Sprintf("failed to create format file: %v", err)
- c.sendLog(taskID, types.LogLevelError, errMsg, "render_blender")
- c.sendStepUpdate(taskID, "render_blender", types.StepStatusFailed, errMsg)
- return errors.New(errMsg)
- }
-
- // Write render settings to a JSON file if we have metadata with render settings
- renderSettingsFile := filepath.Join(workDir, "render_settings.json")
- if jobMetadata != nil && jobMetadata.RenderSettings.EngineSettings != nil {
- settingsJSON, err := json.Marshal(jobMetadata.RenderSettings)
- if err == nil {
- if err := os.WriteFile(renderSettingsFile, settingsJSON, 0644); err != nil {
- c.sendLog(taskID, types.LogLevelWarn, fmt.Sprintf("Failed to write render settings file: %v", err), "render_blender")
- }
- }
- }
-
- // Check if execution should be enabled (defaults to false/off)
- enableExecution := false
- if jobMetadata != nil && jobMetadata.EnableExecution != nil && *jobMetadata.EnableExecution {
- enableExecution = true
- }
-
- // Run Blender with GPU enabled via Python script
- // Use -s (start) and -e (end) for frame ranges, or -f for single frame
- // Use Blender's automatic frame numbering with #### pattern
- var cmd *exec.Cmd
- args := []string{"-b", blendFile, "--python", scriptPath}
- if enableExecution {
- args = append(args, "--enable-autoexec")
- }
-
- // Output pattern uses #### which Blender will replace with frame numbers
- outputPattern := filepath.Join(outputDir, fmt.Sprintf("frame_####.%s", strings.ToLower(renderFormat)))
- outputAbsPattern, _ := filepath.Abs(outputPattern)
- args = append(args, "-o", outputAbsPattern)
-
- if frameStart == frameEnd {
- // Single frame
- args = append(args, "-f", fmt.Sprintf("%d", frameStart))
- c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Rendering frame %d...", frameStart), "render_blender")
- } else {
- // Frame range
- args = append(args, "-s", fmt.Sprintf("%d", frameStart), "-e", fmt.Sprintf("%d", frameEnd))
- c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Rendering frames %d-%d...", frameStart, frameEnd), "render_blender")
- }
-
- // Create and run Blender command
- cmd = exec.Command("blender", args...)
- cmd.Dir = workDir
- cmd.Env = os.Environ()
-
- // Blender will handle headless rendering automatically
- // We preserve the environment to allow GPU access if available
-
- // Capture stdout and stderr separately for line-by-line streaming
- stdoutPipe, err := cmd.StdoutPipe()
- if err != nil {
- errMsg := fmt.Sprintf("failed to create stdout pipe: %v", err)
- c.sendLog(taskID, types.LogLevelError, errMsg, "render_blender")
- c.sendStepUpdate(taskID, "render_blender", types.StepStatusFailed, errMsg)
- return errors.New(errMsg)
- }
-
- stderrPipe, err := cmd.StderrPipe()
- if err != nil {
- errMsg := fmt.Sprintf("failed to create stderr pipe: %v", err)
- c.sendLog(taskID, types.LogLevelError, errMsg, "render_blender")
- c.sendStepUpdate(taskID, "render_blender", types.StepStatusFailed, errMsg)
- return errors.New(errMsg)
- }
-
- // Start the command
- if err := cmd.Start(); err != nil {
- errMsg := fmt.Sprintf("failed to start blender: %v", err)
- c.sendLog(taskID, types.LogLevelError, errMsg, "render_blender")
- c.sendStepUpdate(taskID, "render_blender", types.StepStatusFailed, errMsg)
- return errors.New(errMsg)
- }
-
- // Register process for cleanup on shutdown
- c.processTracker.Track(taskID, cmd)
- defer c.processTracker.Untrack(taskID)
-
- // Stream stdout line by line
- stdoutDone := make(chan bool)
- go func() {
- defer close(stdoutDone)
- scanner := bufio.NewScanner(stdoutPipe)
- for scanner.Scan() {
- line := scanner.Text()
- if line != "" {
- shouldFilter, logLevel := shouldFilterBlenderLog(line)
- if !shouldFilter {
- c.sendLog(taskID, logLevel, line, "render_blender")
- }
- }
- }
- }()
-
- // Stream stderr line by line
- stderrDone := make(chan bool)
- go func() {
- defer close(stderrDone)
- scanner := bufio.NewScanner(stderrPipe)
- for scanner.Scan() {
- line := scanner.Text()
- if line != "" {
- shouldFilter, logLevel := shouldFilterBlenderLog(line)
- if !shouldFilter {
- // Use the filtered log level, but if it's still WARN, keep it as WARN
- if logLevel == types.LogLevelInfo {
- logLevel = types.LogLevelWarn
- }
- c.sendLog(taskID, logLevel, line, "render_blender")
- }
- }
- }
- }()
-
- // Wait for command to complete
- err = cmd.Wait()
-
- // Wait for streaming goroutines to finish
- <-stdoutDone
- <-stderrDone
- if err != nil {
- var errMsg string
- if exitErr, ok := err.(*exec.ExitError); ok {
- if exitErr.ExitCode() == 137 {
- errMsg = "Blender was killed due to excessive memory usage (OOM)"
- } else {
- errMsg = fmt.Sprintf("blender failed: %v", err)
- }
- } else {
- errMsg = fmt.Sprintf("blender failed: %v", err)
- }
- c.sendLog(taskID, types.LogLevelError, errMsg, "render_blender")
- c.sendStepUpdate(taskID, "render_blender", types.StepStatusFailed, errMsg)
- return errors.New(errMsg)
- }
-
- // Blender has rendered frames with automatic numbering using the #### pattern
- // Files will be named like frame_0001.png, frame_0002.png, etc.
-
- // Find rendered output file(s)
- // For frame ranges, we'll find all frames in the upload step
- // For single frames, we need to find the specific output file
- outputFile := ""
-
- // Only check for single output file if it's a single frame render
- if frameStart == frameEnd {
- // List all files in output directory to find what Blender actually created
- entries, err := os.ReadDir(outputDir)
- if err == nil {
- c.sendLog(taskID, types.LogLevelInfo, "Checking output directory for files...", "render_blender")
-
- // Try exact match first: frame_0155.png
- expectedFile := filepath.Join(outputDir, fmt.Sprintf("frame_%04d.%s", frameStart, strings.ToLower(renderFormat)))
- if _, err := os.Stat(expectedFile); err == nil {
- outputFile = expectedFile
- c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Found output file: %s", filepath.Base(expectedFile)), "render_blender")
- } else {
- // Try without zero padding: frame_155.png
- altFile := filepath.Join(outputDir, fmt.Sprintf("frame_%d.%s", frameStart, strings.ToLower(renderFormat)))
- if _, err := os.Stat(altFile); err == nil {
- outputFile = altFile
- c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Found output file: %s", filepath.Base(altFile)), "render_blender")
- } else {
- // Try just frame number: 0155.png or 155.png
- altFile2 := filepath.Join(outputDir, fmt.Sprintf("%04d.%s", frameStart, strings.ToLower(renderFormat)))
- if _, err := os.Stat(altFile2); err == nil {
- outputFile = altFile2
- c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Found output file: %s", filepath.Base(altFile2)), "render_blender")
- } else {
- // Search through all files for one containing the frame number
- for _, entry := range entries {
- if !entry.IsDir() {
- fileName := entry.Name()
- // Skip files that contain the literal pattern string (Blender bug)
- if strings.Contains(fileName, "%04d") || strings.Contains(fileName, "%d") {
- c.sendLog(taskID, types.LogLevelWarn, fmt.Sprintf("Skipping file with literal pattern: %s", fileName), "render_blender")
- continue
- }
- // Check if filename contains the frame number (with or without padding)
- frameStr := fmt.Sprintf("%d", frameStart)
- frameStrPadded := fmt.Sprintf("%04d", frameStart)
- if strings.Contains(fileName, frameStrPadded) ||
- (strings.Contains(fileName, frameStr) && strings.HasSuffix(strings.ToLower(fileName), strings.ToLower(renderFormat))) {
- outputFile = filepath.Join(outputDir, fileName)
- c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Found output file: %s", fileName), "render_blender")
- break
- }
- }
- }
- }
- }
- }
- }
-
- if outputFile == "" {
- // List all files in output directory for debugging
- entries, _ := os.ReadDir(outputDir)
- fileList := []string{}
- for _, entry := range entries {
- if !entry.IsDir() {
- fileList = append(fileList, entry.Name())
- }
- }
- expectedFile := filepath.Join(outputDir, fmt.Sprintf("frame_%04d.%s", frameStart, strings.ToLower(renderFormat)))
- errMsg := fmt.Sprintf("output file not found: %s\nFiles in output directory: %v",
- expectedFile, fileList)
- c.sendLog(taskID, types.LogLevelError, errMsg, "render_blender")
- c.sendStepUpdate(taskID, "render_blender", types.StepStatusFailed, errMsg)
- return errors.New(errMsg)
- }
- c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Blender render completed for frame %d", frameStart), "render_blender")
- } else {
- // Frame range - Blender renders multiple frames, we'll find them all in the upload step
- c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Blender render completed for frames %d-%d", frameStart, frameEnd), "render_blender")
- }
- c.sendStepUpdate(taskID, "render_blender", types.StepStatusCompleted, "")
-
- // Step: upload or upload_frames
- uploadStepName := "upload"
- if outputFormat == "EXR_264_MP4" || outputFormat == "EXR_AV1_MP4" {
- uploadStepName = "upload_frames"
- }
- c.sendStepUpdate(taskID, uploadStepName, types.StepStatusRunning, "")
-
- var outputPath string
- // If we have a frame range, find and upload all frames
- if frameStart != frameEnd {
- c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Uploading frames %d-%d...", frameStart, frameEnd), uploadStepName)
-
- // Find all rendered frames in the output directory
- var frameFiles []string
- entries, err := os.ReadDir(outputDir)
- if err == nil {
- for frame := frameStart; frame <= frameEnd; frame++ {
- // Try different naming patterns
- patterns := []string{
- fmt.Sprintf("frame_%04d.%s", frame, strings.ToLower(renderFormat)),
- fmt.Sprintf("frame_%d.%s", frame, strings.ToLower(renderFormat)),
- fmt.Sprintf("%04d.%s", frame, strings.ToLower(renderFormat)),
- fmt.Sprintf("%d.%s", frame, strings.ToLower(renderFormat)),
- }
-
- found := false
- for _, pattern := range patterns {
- framePath := filepath.Join(outputDir, pattern)
- if _, err := os.Stat(framePath); err == nil {
- frameFiles = append(frameFiles, framePath)
- found = true
- break
- }
- }
-
- // If not found with patterns, search through entries
- if !found {
- frameStr := fmt.Sprintf("%d", frame)
- frameStrPadded := fmt.Sprintf("%04d", frame)
- for _, entry := range entries {
- if entry.IsDir() {
- continue
- }
- fileName := entry.Name()
- // Skip files with literal pattern strings
- if strings.Contains(fileName, "%04d") || strings.Contains(fileName, "%d") {
- continue
- }
- // Check if filename contains the frame number
- fullPath := filepath.Join(outputDir, fileName)
- alreadyAdded := false
- for _, existing := range frameFiles {
- if existing == fullPath {
- alreadyAdded = true
- break
- }
- }
- if !alreadyAdded &&
- (strings.Contains(fileName, frameStrPadded) ||
- (strings.Contains(fileName, frameStr) && strings.HasSuffix(strings.ToLower(fileName), strings.ToLower(renderFormat)))) {
- frameFiles = append(frameFiles, fullPath)
- found = true
- break
- }
- }
- }
- }
- }
-
- if len(frameFiles) == 0 {
- errMsg := fmt.Sprintf("no frame files found for range %d-%d", frameStart, frameEnd)
- c.sendLog(taskID, types.LogLevelError, errMsg, uploadStepName)
- c.sendStepUpdate(taskID, uploadStepName, types.StepStatusFailed, errMsg)
- return errors.New(errMsg)
- }
-
- // Upload all frames
- uploadedCount := 0
- uploadedFiles := []string{}
- for i, frameFile := range frameFiles {
- fileName := filepath.Base(frameFile)
- c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Uploading frame %d/%d: %s", i+1, len(frameFiles), fileName), uploadStepName)
- uploadedPath, err := c.uploadFile(jobID, frameFile)
- if err != nil {
- errMsg := fmt.Sprintf("failed to upload frame %s: %v", fileName, err)
- c.sendLog(taskID, types.LogLevelError, errMsg, uploadStepName)
- c.sendStepUpdate(taskID, uploadStepName, types.StepStatusFailed, errMsg)
- return errors.New(errMsg)
- }
- uploadedCount++
- uploadedFiles = append(uploadedFiles, fileName)
- c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Uploaded frame %d/%d: %s -> %s", i+1, len(frameFiles), fileName, uploadedPath), uploadStepName)
- }
-
- c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Successfully uploaded %d frames: %v", uploadedCount, uploadedFiles), uploadStepName)
- c.sendStepUpdate(taskID, uploadStepName, types.StepStatusCompleted, "")
- outputPath = "" // Not used for frame ranges, frames are uploaded individually
- } else {
- // Single frame upload
- fileName := filepath.Base(outputFile)
- c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Uploading output file: %s", fileName), uploadStepName)
- outputPath, err = c.uploadFile(jobID, outputFile)
- if err != nil {
- errMsg := fmt.Sprintf("failed to upload output file %s: %v", fileName, err)
- c.sendLog(taskID, types.LogLevelError, errMsg, uploadStepName)
- c.sendStepUpdate(taskID, uploadStepName, types.StepStatusFailed, errMsg)
- return errors.New(errMsg)
- }
- c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Output file uploaded successfully: %s -> %s", fileName, outputPath), uploadStepName)
- c.sendStepUpdate(taskID, uploadStepName, types.StepStatusCompleted, "")
- }
-
- // Step: complete
- c.sendStepUpdate(taskID, "complete", types.StepStatusRunning, "")
- c.sendLog(taskID, types.LogLevelInfo, "Task completed successfully", "complete")
-
- // Mark task as complete
- if err := c.completeTask(taskID, outputPath, true, ""); err != nil {
- c.sendStepUpdate(taskID, "complete", types.StepStatusFailed, err.Error())
- return err
- }
- c.sendStepUpdate(taskID, "complete", types.StepStatusCompleted, "")
-
- return nil
- }
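processTask drives Blender entirely through its documented CLI: -b runs headless, -o sets an output pattern whose #### placeholder Blender expands to the zero-padded frame number, and -f renders a single frame (for ranges, -s and -e set the span and -a triggers the animation render). A minimal sketch of the single-frame case; the paths and package name are hypothetical.

package render

import (
	"fmt"
	"os/exec"
	"path/filepath"
)

// singleFrameCmd renders exactly one frame of blend into outDir as PNG.
func singleFrameCmd(blend, outDir string, frame int) *exec.Cmd {
	pattern := filepath.Join(outDir, "frame_####.png") // #### -> 0001, 0002, ...
	return exec.Command("blender",
		"-b", blend, // background (no UI)
		"-o", pattern, // output pattern
		"-f", fmt.Sprintf("%d", frame), // render this frame only
	)
}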
c.sendStepUpdate(taskID, "get_files", types.StepStatusFailed, err.Error()) - return fmt.Errorf("failed to get job files: %w", err) - } - - // Find all EXR frame files (MP4 is rendered as EXR for highest fidelity - 32-bit float HDR) - var exrFiles []map[string]interface{} - for _, file := range files { - fileType, _ := file["file_type"].(string) - fileName, _ := file["file_name"].(string) - // Check for both .exr and .EXR extensions - if fileType == "output" && (strings.HasSuffix(strings.ToLower(fileName), ".exr") || strings.HasSuffix(fileName, ".EXR")) { - exrFiles = append(exrFiles, file) - } - } - - if len(exrFiles) == 0 { - err := fmt.Errorf("no EXR frame files found for MP4 generation") - c.sendStepUpdate(taskID, "get_files", types.StepStatusFailed, err.Error()) - return err - } - - c.sendStepUpdate(taskID, "get_files", types.StepStatusCompleted, "") - c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Found %d EXR frames for video generation (highest fidelity - 32-bit HDR)", len(exrFiles)), "get_files") - - log.Printf("Generating MP4 for job %d from %d EXR frames", jobID, len(exrFiles)) - - // Step: download_frames - c.sendStepUpdate(taskID, "download_frames", types.StepStatusRunning, "") - c.sendLog(taskID, types.LogLevelInfo, "Downloading EXR frames...", "download_frames") - - // Download all EXR frames - var frameFiles []string - for _, file := range exrFiles { - fileName, _ := file["file_name"].(string) - framePath := filepath.Join(workDir, fileName) - if err := c.downloadFrameFile(jobID, fileName, framePath); err != nil { - log.Printf("Failed to download frame %s: %v", fileName, err) - continue - } - frameFiles = append(frameFiles, framePath) - } - - if len(frameFiles) == 0 { - err := fmt.Errorf("failed to download any frame files") - c.sendStepUpdate(taskID, "download_frames", types.StepStatusFailed, err.Error()) - return err - } - - // Sort frame files by name to ensure correct order - sort.Strings(frameFiles) - c.sendStepUpdate(taskID, "download_frames", types.StepStatusCompleted, "") - c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Downloaded %d frames", len(frameFiles)), "download_frames") - - // Step: generate_video - c.sendStepUpdate(taskID, "generate_video", types.StepStatusRunning, "") - - // Determine codec and pixel format based on output format - var codec string - var pixFmt string - var useAlpha bool - - if outputFormat == "EXR_AV1_MP4" { - codec = "libaom-av1" - pixFmt = "yuva420p" // AV1 with alpha channel - useAlpha = true - c.sendLog(taskID, types.LogLevelInfo, "Generating MP4 video with AV1 codec (with alpha channel)...", "generate_video") - } else { - // Default to H.264 for EXR_264_MP4 - codec = "libx264" - pixFmt = "yuv420p" // H.264 without alpha - useAlpha = false - c.sendLog(taskID, types.LogLevelInfo, "Generating MP4 video with H.264 codec...", "generate_video") - } - - // Generate MP4 using ffmpeg - outputMP4 := filepath.Join(workDir, fmt.Sprintf("output_%d.mp4", jobID)) - - // Use ffmpeg to combine EXR frames into MP4 - // Method 1: Using image sequence input (more reliable) - firstFrame := frameFiles[0] - // Extract frame number pattern (e.g., frame_2470.exr -> frame_%04d.exr) - baseName := filepath.Base(firstFrame) - // Find the numeric part and replace it with %04d pattern - // Use regex to find digits (positive only, negative frames not supported) after underscore and before extension - re := regexp.MustCompile(`_(\d+)\.`) - var pattern string - var startNumber int - frameNumStr := re.FindStringSubmatch(baseName) - if len(frameNumStr) > 1 { - 
// Replace the numeric part with %04d - pattern = re.ReplaceAllString(baseName, "_%04d.") - // Extract the starting frame number - fmt.Sscanf(frameNumStr[1], "%d", &startNumber) - } else { - // Fallback: try simple replacement - startNumber = extractFrameNumber(baseName) - pattern = strings.Replace(baseName, fmt.Sprintf("%d", startNumber), "%04d", 1) - } - // Pattern path should be in workDir where the frame files are downloaded - patternPath := filepath.Join(workDir, pattern) - - // Allocate a VAAPI device for this task (if available) - allocatedDevice := c.allocateVAAPIDevice(taskID) - defer c.releaseVAAPIDevice(taskID) // Always release the device when done - if allocatedDevice != "" { - c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Using VAAPI device: %s", allocatedDevice), "generate_video") - } else { - c.sendLog(taskID, types.LogLevelInfo, "No VAAPI device available, will use software encoding or other hardware", "generate_video") - } - - // Run ffmpeg to combine EXR frames into MP4 at 24 fps - // EXR is 32-bit float HDR format - FFmpeg will automatically tonemap to 8-bit/10-bit for video - // Use -start_number to tell ffmpeg the starting frame number - var cmd *exec.Cmd - var useHardware bool - - if outputFormat == "EXR_AV1_MP4" { - // Try AV1 hardware acceleration - cmd, err = c.buildFFmpegCommandAV1(allocatedDevice, useAlpha, "-y", "-start_number", fmt.Sprintf("%d", startNumber), - "-framerate", "24", "-i", patternPath, - "-r", "24", outputMP4) - if err == nil { - useHardware = true - c.sendLog(taskID, types.LogLevelInfo, "Using AV1 hardware acceleration", "generate_video") - } else { - c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("AV1 hardware acceleration not available, will use software: %v", err), "generate_video") - } - } else { - // Try H.264 hardware acceleration - if allocatedDevice != "" { - cmd, err = c.buildFFmpegCommand(allocatedDevice, "-y", "-start_number", fmt.Sprintf("%d", startNumber), - "-framerate", "24", "-i", patternPath, - "-r", "24", outputMP4) - if err == nil { - useHardware = true - } else { - allocatedDevice = "" // Fall back to software - } - } - } - - if !useHardware { - // Software encoding with HDR tonemapping - // Build video filter for HDR to SDR conversion - var vf string - if useAlpha { - // For AV1 with alpha: preserve alpha channel during tonemapping - vf = "zscale=t=linear:npl=100,format=gbrpf32le,zscale=p=bt709,tonemap=tonemap=hable:desat=0,zscale=t=bt709:m=bt709:r=tv,format=yuva420p" - } else { - // For H.264 without alpha: standard tonemapping - vf = "zscale=t=linear:npl=100,format=gbrpf32le,zscale=p=bt709,tonemap=tonemap=hable:desat=0,zscale=t=bt709:m=bt709:r=tv,format=yuv420p" - } - - // Build ffmpeg command with high-quality EXR input processing - cmd = exec.Command("ffmpeg", "-y", - "-f", "image2", // Force image sequence input format - "-start_number", fmt.Sprintf("%d", startNumber), - "-framerate", fmt.Sprintf("%.2f", frameRate), - "-i", patternPath, - "-vf", vf, - "-c:v", codec, "-pix_fmt", pixFmt, - "-r", fmt.Sprintf("%.2f", frameRate), - "-color_primaries", "bt709", // Ensure proper color primaries - "-color_trc", "bt709", // Ensure proper transfer characteristics - "-colorspace", "bt709", // Ensure proper color space - outputMP4) - - // Prepare codec-specific arguments - var codecArgs []string - if outputFormat == "EXR_AV1_MP4" { - // AV1 encoding options for maximum quality - codecArgs = []string{"-cpu-used", "1", "-crf", "15", "-b:v", "0", "-row-mt", "1", "-tiles", "4x4", "-lag-in-frames", "25", "-arnr-max-frames", 
"15", "-arnr-strength", "4"} - } else { - // H.264 encoding options for maximum quality - codecArgs = []string{"-preset", "veryslow", "-crf", "15", "-profile:v", "high", "-level", "5.2", "-tune", "film", "-keyint_min", "24", "-g", "240", "-bf", "2", "-refs", "4"} - } - - // Perform 2-pass encoding for optimal quality distribution - c.sendLog(taskID, types.LogLevelInfo, "Starting 2-pass video encoding for optimal quality...", "generate_video") - - // PASS 1: Analysis pass (collects statistics for better rate distribution) - c.sendLog(taskID, types.LogLevelInfo, "Pass 1/2: Analyzing video content for optimal encoding...", "generate_video") - pass1Args := append([]string{"-y", "-f", "image2", "-start_number", fmt.Sprintf("%d", startNumber), "-framerate", fmt.Sprintf("%.2f", frameRate), "-i", patternPath, "-vf", vf, "-c:v", codec, "-pix_fmt", pixFmt, "-r", fmt.Sprintf("%.2f", frameRate), "-color_primaries", "bt709", "-color_trc", "bt709", "-colorspace", "bt709"}, codecArgs...) - pass1Args = append(pass1Args, "-pass", "1", "-f", "null", "/dev/null") - - pass1Cmd := exec.Command("ffmpeg", pass1Args...) - pass1Cmd.Dir = workDir - pass1Err := pass1Cmd.Run() - if pass1Err != nil { - c.sendLog(taskID, types.LogLevelWarn, fmt.Sprintf("Pass 1 completed (warnings expected): %v", pass1Err), "generate_video") - } - - // PASS 2: Encoding pass (uses statistics from pass 1 for optimal quality) - c.sendLog(taskID, types.LogLevelInfo, "Pass 2/2: Encoding video with optimal quality distribution...", "generate_video") - cmd = exec.Command("ffmpeg", "-y", "-f", "image2", "-start_number", fmt.Sprintf("%d", startNumber), "-framerate", fmt.Sprintf("%.2f", frameRate), "-i", patternPath, "-vf", vf, "-c:v", codec, "-pix_fmt", pixFmt, "-r", fmt.Sprintf("%.2f", frameRate), "-color_primaries", "bt709", "-color_trc", "bt709", "-colorspace", "bt709") - cmd.Args = append(cmd.Args, codecArgs...) 
- cmd.Args = append(cmd.Args, "-pass", "2", outputMP4) - } - - // Create stdout and stderr pipes for streaming - stdoutPipe, err := cmd.StdoutPipe() - if err != nil { - errMsg := fmt.Sprintf("failed to create ffmpeg stdout pipe: %v", err) - c.sendLog(taskID, types.LogLevelError, errMsg, "generate_video") - c.sendStepUpdate(taskID, "generate_video", types.StepStatusFailed, errMsg) - return errors.New(errMsg) - } - - stderrPipe, err := cmd.StderrPipe() - if err != nil { - errMsg := fmt.Sprintf("failed to create ffmpeg stderr pipe: %v", err) - c.sendLog(taskID, types.LogLevelError, errMsg, "generate_video") - c.sendStepUpdate(taskID, "generate_video", types.StepStatusFailed, errMsg) - return errors.New(errMsg) - } - - cmd.Dir = workDir - - // Start the command - if err := cmd.Start(); err != nil { - errMsg := fmt.Sprintf("failed to start ffmpeg: %v", err) - c.sendLog(taskID, types.LogLevelError, errMsg, "generate_video") - c.sendStepUpdate(taskID, "generate_video", types.StepStatusFailed, errMsg) - return errors.New(errMsg) - } - - // Register process for cleanup on shutdown - c.processTracker.Track(taskID, cmd) - defer c.processTracker.Untrack(taskID) - - // Stream stdout line by line - stdoutDone := make(chan bool) - go func() { - defer close(stdoutDone) - scanner := bufio.NewScanner(stdoutPipe) - for scanner.Scan() { - line := scanner.Text() - if line != "" { - // Filter out common ffmpeg informational messages that aren't useful - if !strings.Contains(line, "Input #") && - !strings.Contains(line, "Duration:") && - !strings.Contains(line, "Stream mapping:") && - !strings.Contains(line, "Output #") && - !strings.Contains(line, "encoder") && - !strings.Contains(line, "fps=") && - !strings.Contains(line, "size=") && - !strings.Contains(line, "time=") && - !strings.Contains(line, "bitrate=") && - !strings.Contains(line, "speed=") { - c.sendLog(taskID, types.LogLevelInfo, line, "generate_video") - } - } - } - }() - - // Stream stderr line by line - stderrDone := make(chan bool) - go func() { - defer close(stderrDone) - scanner := bufio.NewScanner(stderrPipe) - for scanner.Scan() { - line := scanner.Text() - if line != "" { - // Filter out common ffmpeg informational messages and show only warnings/errors - if strings.Contains(line, "error") || - strings.Contains(line, "Error") || - strings.Contains(line, "failed") || - strings.Contains(line, "Failed") || - strings.Contains(line, "warning") || - strings.Contains(line, "Warning") { - c.sendLog(taskID, types.LogLevelWarn, line, "generate_video") - } else if !strings.Contains(line, "Input #") && - !strings.Contains(line, "Duration:") && - !strings.Contains(line, "Stream mapping:") && - !strings.Contains(line, "Output #") && - !strings.Contains(line, "encoder") && - !strings.Contains(line, "fps=") && - !strings.Contains(line, "size=") && - !strings.Contains(line, "time=") && - !strings.Contains(line, "bitrate=") && - !strings.Contains(line, "speed=") { - c.sendLog(taskID, types.LogLevelInfo, line, "generate_video") - } - } - } - }() - - // Wait for command to complete - err = cmd.Wait() - - // Wait for streaming goroutines to finish - <-stdoutDone - <-stderrDone - - if err != nil { - var errMsg string - if exitErr, ok := err.(*exec.ExitError); ok { - if exitErr.ExitCode() == 137 { - errMsg = "FFmpeg was killed due to excessive memory usage (OOM)" - } else { - errMsg = fmt.Sprintf("ffmpeg encoding failed: %v", err) - } - } else { - errMsg = fmt.Sprintf("ffmpeg encoding failed: %v", err) - } - // Check for size-related errors and provide helpful 
messages - if sizeErr := c.checkFFmpegSizeError(errMsg); sizeErr != nil { - c.sendLog(taskID, types.LogLevelError, sizeErr.Error(), "generate_video") - c.sendStepUpdate(taskID, "generate_video", types.StepStatusFailed, sizeErr.Error()) - return sizeErr - } - c.sendLog(taskID, types.LogLevelError, errMsg, "generate_video") - c.sendStepUpdate(taskID, "generate_video", types.StepStatusFailed, errMsg) - return errors.New(errMsg) - } - - // Check if MP4 was created - if _, err := os.Stat(outputMP4); os.IsNotExist(err) { - err := fmt.Errorf("MP4 file not created: %s", outputMP4) - c.sendStepUpdate(taskID, "generate_video", types.StepStatusFailed, err.Error()) - return err - } - - // Clean up 2-pass log files - _ = os.Remove(filepath.Join(workDir, "ffmpeg2pass-0.log")) - _ = os.Remove(filepath.Join(workDir, "ffmpeg2pass-0.log.mbtree")) - - c.sendStepUpdate(taskID, "generate_video", types.StepStatusCompleted, "") - c.sendLog(taskID, types.LogLevelInfo, "MP4 video generated with 2-pass encoding successfully", "generate_video") - - // Step: upload_video - c.sendStepUpdate(taskID, "upload_video", types.StepStatusRunning, "") - c.sendLog(taskID, types.LogLevelInfo, "Uploading MP4 video...", "upload_video") - - // Upload MP4 file - mp4Path, err := c.uploadFile(jobID, outputMP4) - if err != nil { - c.sendStepUpdate(taskID, "upload_video", types.StepStatusFailed, err.Error()) - return fmt.Errorf("failed to upload MP4: %w", err) - } - - c.sendStepUpdate(taskID, "upload_video", types.StepStatusCompleted, "") - c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Successfully uploaded MP4: %s", mp4Path), "upload_video") - - // Mark task as complete - if err := c.completeTask(taskID, mp4Path, true, ""); err != nil { - return err - } - - log.Printf("Successfully generated and uploaded MP4 for job %d: %s", jobID, mp4Path) - return nil -} - -// buildFFmpegCommand builds an ffmpeg command with hardware acceleration if available -// If device is provided (non-empty), it will be used for VAAPI encoding -// Returns the command and any error encountered during detection -func (c *Client) buildFFmpegCommand(device string, args ...string) (*exec.Cmd, error) { - // Try hardware encoders in order of preference - // Priority: NVENC (NVIDIA) > VideoToolbox (macOS) > VAAPI (Intel/AMD Linux) > AMF (AMD Windows) > software fallback - - // Check for NVIDIA NVENC - if c.checkEncoderAvailable("h264_nvenc") { - // Insert hardware encoding args before output file - outputIdx := len(args) - 1 - hwArgs := []string{"-c:v", "h264_nvenc", "-preset", "p4", "-b:v", "10M", "-maxrate", "12M", "-bufsize", "20M", "-pix_fmt", "yuv420p"} - newArgs := make([]string, 0, len(args)+len(hwArgs)) - newArgs = append(newArgs, args[:outputIdx]...) - newArgs = append(newArgs, hwArgs...) - newArgs = append(newArgs, args[outputIdx:]...) - return exec.Command("ffmpeg", newArgs...), nil - } - - // Check for VideoToolbox (macOS) - if c.checkEncoderAvailable("h264_videotoolbox") { - outputIdx := len(args) - 1 - hwArgs := []string{"-c:v", "h264_videotoolbox", "-b:v", "10M", "-pix_fmt", "yuv420p"} - newArgs := make([]string, 0, len(args)+len(hwArgs)) - newArgs = append(newArgs, args[:outputIdx]...) - newArgs = append(newArgs, hwArgs...) - newArgs = append(newArgs, args[outputIdx:]...) 
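The software path above is a classic ffmpeg two-pass encode: pass 1 writes only the stats log (output discarded via the null muxer), pass 2 reuses the log for the final file. A condensed sketch with a hypothetical input pattern and settings; the real command also carries the tonemap filter and color flags.

package videogen

import "os/exec"

// encodeTwoPass runs ffmpeg twice over an image sequence: the first pass
// produces ffmpeg2pass-0.log, the second consumes it and writes out.
func encodeTwoPass(pattern, out string) error {
	common := []string{"-y", "-f", "image2", "-framerate", "24",
		"-i", pattern, "-c:v", "libx264", "-crf", "15"}
	pass1 := append(append([]string{}, common...), "-pass", "1", "-f", "null", "/dev/null")
	if err := exec.Command("ffmpeg", pass1...).Run(); err != nil {
		return err // pass 1 failing usually means the input is unreadable
	}
	pass2 := append(append([]string{}, common...), "-pass", "2", out)
	return exec.Command("ffmpeg", pass2...).Run()
}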
-
- // buildFFmpegCommand builds an ffmpeg command with hardware acceleration if available
- // If device is provided (non-empty), it will be used for VAAPI encoding
- // Returns the command and any error encountered during detection
- func (c *Client) buildFFmpegCommand(device string, args ...string) (*exec.Cmd, error) {
- // Try hardware encoders in order of preference
- // Priority: NVENC (NVIDIA) > VideoToolbox (macOS) > VAAPI (Intel/AMD Linux) > AMF (AMD Windows) > software fallback
-
- // Check for NVIDIA NVENC
- if c.checkEncoderAvailable("h264_nvenc") {
- // Insert hardware encoding args before output file
- outputIdx := len(args) - 1
- hwArgs := []string{"-c:v", "h264_nvenc", "-preset", "p4", "-b:v", "10M", "-maxrate", "12M", "-bufsize", "20M", "-pix_fmt", "yuv420p"}
- newArgs := make([]string, 0, len(args)+len(hwArgs))
- newArgs = append(newArgs, args[:outputIdx]...)
- newArgs = append(newArgs, hwArgs...)
- newArgs = append(newArgs, args[outputIdx:]...)
- return exec.Command("ffmpeg", newArgs...), nil
- }
-
- // Check for VideoToolbox (macOS)
- if c.checkEncoderAvailable("h264_videotoolbox") {
- outputIdx := len(args) - 1
- hwArgs := []string{"-c:v", "h264_videotoolbox", "-b:v", "10M", "-pix_fmt", "yuv420p"}
- newArgs := make([]string, 0, len(args)+len(hwArgs))
- newArgs = append(newArgs, args[:outputIdx]...)
- newArgs = append(newArgs, hwArgs...)
- newArgs = append(newArgs, args[outputIdx:]...)
- return exec.Command("ffmpeg", newArgs...), nil
- }
-
- // Check for VAAPI (Intel/AMD on Linux)
- if c.checkEncoderAvailable("h264_vaapi") {
- // Use provided device if available, otherwise get the first available
- vaapiDevice := device
- if vaapiDevice == "" {
- vaapiDevice = c.getVAAPIDevice()
- }
-
- if vaapiDevice != "" {
- outputIdx := len(args) - 1
- hwArgs := []string{"-vaapi_device", vaapiDevice, "-vf", "format=nv12,hwupload", "-c:v", "h264_vaapi", "-b:v", "10M", "-pix_fmt", "yuv420p"}
- newArgs := make([]string, 0, len(args)+len(hwArgs))
- newArgs = append(newArgs, args[:outputIdx]...)
- newArgs = append(newArgs, hwArgs...)
- newArgs = append(newArgs, args[outputIdx:]...)
- return exec.Command("ffmpeg", newArgs...), nil
- }
- }
-
- // Check for AMF (AMD on Windows)
- if c.checkEncoderAvailable("h264_amf") {
- outputIdx := len(args) - 1
- hwArgs := []string{"-c:v", "h264_amf", "-quality", "balanced", "-b:v", "10M", "-pix_fmt", "yuv420p"}
- newArgs := make([]string, 0, len(args)+len(hwArgs))
- newArgs = append(newArgs, args[:outputIdx]...)
- newArgs = append(newArgs, hwArgs...)
- newArgs = append(newArgs, args[outputIdx:]...)
- return exec.Command("ffmpeg", newArgs...), nil
- }
-
- // Check for Intel Quick Sync (QSV)
- if c.checkEncoderAvailable("h264_qsv") {
- outputIdx := len(args) - 1
- hwArgs := []string{"-c:v", "h264_qsv", "-preset", "medium", "-b:v", "10M", "-pix_fmt", "yuv420p"}
- newArgs := make([]string, 0, len(args)+len(hwArgs))
- newArgs = append(newArgs, args[:outputIdx]...)
- newArgs = append(newArgs, hwArgs...)
- newArgs = append(newArgs, args[outputIdx:]...)
- return exec.Command("ffmpeg", newArgs...), nil
- }
-
- // No hardware acceleration available
- return nil, fmt.Errorf("no hardware encoder available for video encoding - falling back to software encoding which may be slower")
- }
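Every branch above relies on the same splice: the caller's final argument is assumed to be the output path, and encoder flags are inserted just before it. The trick in isolation, with a hypothetical argument list:

package main

import "fmt"

// spliceBeforeOutput inserts hwArgs just before the last element of args,
// which is assumed to be the ffmpeg output file.
func spliceBeforeOutput(args, hwArgs []string) []string {
	outputIdx := len(args) - 1
	newArgs := make([]string, 0, len(args)+len(hwArgs))
	newArgs = append(newArgs, args[:outputIdx]...)
	newArgs = append(newArgs, hwArgs...)
	newArgs = append(newArgs, args[outputIdx:]...)
	return newArgs
}

func main() {
	fmt.Println(spliceBeforeOutput(
		[]string{"-i", "in_%04d.exr", "out.mp4"},
		[]string{"-c:v", "h264_nvenc"},
	)) // [-i in_%04d.exr -c:v h264_nvenc out.mp4]
}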
-
- // buildFFmpegCommandAV1 builds an ffmpeg command with AV1 hardware acceleration if available
- // If device is provided (non-empty), it will be used for VAAPI encoding
- // useAlpha indicates if alpha channel should be preserved
- // Returns the command and any error encountered during detection
- func (c *Client) buildFFmpegCommandAV1(device string, useAlpha bool, args ...string) (*exec.Cmd, error) {
- // Try AV1 hardware encoders in order of preference
- // Priority: NVENC (NVIDIA) > QSV (Intel) > VAAPI (Intel/AMD Linux) > AMF (AMD Windows) > software fallback
- // Note: Hardware AV1 encoders may not support alpha, so we may need to fall back to software
-
- // Build HDR tonemapping filter for EXR input
- // Hardware encoders need the input to be tonemapped first
- var tonemapFilter string
- if useAlpha {
- tonemapFilter = "zscale=t=linear:npl=100,format=gbrpf32le,zscale=p=bt709,tonemap=tonemap=hable:desat=0,zscale=t=bt709:m=bt709:r=tv,format=yuva420p"
- } else {
- tonemapFilter = "zscale=t=linear:npl=100,format=gbrpf32le,zscale=p=bt709,tonemap=tonemap=hable:desat=0,zscale=t=bt709:m=bt709:r=tv,format=yuv420p"
- }
-
- // Check for NVIDIA NVENC AV1 (RTX 40 series and newer)
- if c.checkEncoderAvailable("av1_nvenc") {
- outputIdx := len(args) - 1
- // AV1 NVENC may support alpha, but let's use yuva420p only if useAlpha is true
- pixFmt := "yuv420p"
- if useAlpha {
- // Check if av1_nvenc supports alpha (it should on newer drivers)
- pixFmt = "yuva420p"
- }
- // Insert tonemapping filter and hardware encoding args before output file
- hwArgs := []string{"-vf", tonemapFilter, "-c:v", "av1_nvenc", "-preset", "p4", "-b:v", "10M", "-maxrate", "12M", "-bufsize", "20M", "-pix_fmt", pixFmt}
- newArgs := make([]string, 0, len(args)+len(hwArgs))
- newArgs = append(newArgs, args[:outputIdx]...)
- newArgs = append(newArgs, hwArgs...)
- newArgs = append(newArgs, args[outputIdx:]...)
- return exec.Command("ffmpeg", newArgs...), nil
- }
-
- // Check for Intel Quick Sync AV1 (Arc GPUs and newer)
- if c.checkEncoderAvailable("av1_qsv") {
- outputIdx := len(args) - 1
- pixFmt := "yuv420p"
- if useAlpha {
- // QSV AV1 may support alpha on newer hardware
- pixFmt = "yuva420p"
- }
- // Insert tonemapping filter and hardware encoding args
- hwArgs := []string{"-vf", tonemapFilter, "-c:v", "av1_qsv", "-preset", "medium", "-b:v", "10M", "-pix_fmt", pixFmt}
- newArgs := make([]string, 0, len(args)+len(hwArgs))
- newArgs = append(newArgs, args[:outputIdx]...)
- newArgs = append(newArgs, hwArgs...)
- newArgs = append(newArgs, args[outputIdx:]...)
- return exec.Command("ffmpeg", newArgs...), nil
- }
-
- // Check for VAAPI AV1 (Intel/AMD on Linux, newer hardware)
- if c.checkEncoderAvailable("av1_vaapi") {
- // Use provided device if available, otherwise get the first available
- vaapiDevice := device
- if vaapiDevice == "" {
- vaapiDevice = c.getVAAPIDevice()
- }
-
- if vaapiDevice != "" {
- outputIdx := len(args) - 1
- pixFmt := "yuv420p"
- vaapiFilter := tonemapFilter
- if useAlpha {
- // VAAPI AV1 may support alpha on newer hardware
- // Note: VAAPI may need format conversion before hwupload
- pixFmt = "yuva420p"
- }
- // For VAAPI, we need to tonemap first, then convert format and upload to hardware
- vaapiFilter = vaapiFilter + ",format=nv12,hwupload"
- hwArgs := []string{"-vaapi_device", vaapiDevice, "-vf", vaapiFilter, "-c:v", "av1_vaapi", "-b:v", "10M", "-pix_fmt", pixFmt}
- newArgs := make([]string, 0, len(args)+len(hwArgs))
- newArgs = append(newArgs, args[:outputIdx]...)
- newArgs = append(newArgs, hwArgs...)
- newArgs = append(newArgs, args[outputIdx:]...)
- return exec.Command("ffmpeg", newArgs...), nil
- }
- }
-
- // Check for AMD AMF AV1 (newer AMD GPUs)
- if c.checkEncoderAvailable("av1_amf") {
- outputIdx := len(args) - 1
- pixFmt := "yuv420p"
- if useAlpha {
- // AMF AV1 may support alpha on newer hardware
- pixFmt = "yuva420p"
- }
- // Insert tonemapping filter and hardware encoding args
- hwArgs := []string{"-vf", tonemapFilter, "-c:v", "av1_amf", "-quality", "balanced", "-b:v", "10M", "-pix_fmt", pixFmt}
- newArgs := make([]string, 0, len(args)+len(hwArgs))
- newArgs = append(newArgs, args[:outputIdx]...)
- newArgs = append(newArgs, hwArgs...)
- newArgs = append(newArgs, args[outputIdx:]...)
- return exec.Command("ffmpeg", newArgs...), nil
- }
-
- // No AV1 hardware acceleration available
- return nil, fmt.Errorf("no AV1 hardware encoder available - falling back to software AV1 encoding which may be slower")
- }
-
- // probeAllHardwareAccelerators probes ffmpeg for all available hardware acceleration methods
- // Returns a map of hwaccel method -> true/false
- func (c *Client) probeAllHardwareAccelerators() map[string]bool {
- hwaccels := make(map[string]bool)
-
- cmd := exec.Command("ffmpeg", "-hide_banner", "-hwaccels")
- output, err := cmd.CombinedOutput()
- if err != nil {
- log.Printf("Failed to probe hardware accelerators: %v", err)
- return hwaccels
- }
-
- // Parse output - hwaccels are listed one per line after "Hardware acceleration methods:"
- outputStr := string(output)
- lines := strings.Split(outputStr, "\n")
- inHwaccelsSection := false
-
- for _, line := range lines {
- line = strings.TrimSpace(line)
- if strings.Contains(line, "Hardware acceleration methods:") {
- inHwaccelsSection = true
- continue
- }
- if inHwaccelsSection {
- if line == "" {
- break
- }
- // Each hwaccel is on its own line
- hwaccel := strings.TrimSpace(line)
- if hwaccel != "" {
- hwaccels[hwaccel] = true
- }
- }
- }
-
- return hwaccels
- }
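For reference, the output this parses is a "Hardware acceleration methods:" header followed by one method per line; exact entries vary by ffmpeg build. A self-contained version of the parse against a sample string:

package main

import (
	"fmt"
	"strings"
)

// parseHwaccels extracts the method names listed after the header line.
func parseHwaccels(out string) []string {
	var methods []string
	seen := false
	for _, line := range strings.Split(out, "\n") {
		line = strings.TrimSpace(line)
		if strings.Contains(line, "Hardware acceleration methods:") {
			seen = true
			continue
		}
		if seen && line != "" {
			methods = append(methods, line)
		}
	}
	return methods
}

func main() {
	sample := "Hardware acceleration methods:\nvdpau\ncuda\nvaapi\n"
	fmt.Println(parseHwaccels(sample)) // [vdpau cuda vaapi]
}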
-
- // probeAllHardwareEncoders probes ffmpeg for all available hardware encoders
- // Returns a map of encoder name -> true/false
- func (c *Client) probeAllHardwareEncoders() map[string]bool {
- encoders := make(map[string]bool)
-
- cmd := exec.Command("ffmpeg", "-hide_banner", "-encoders")
- output, err := cmd.CombinedOutput()
- if err != nil {
- log.Printf("Failed to probe encoders: %v", err)
- return encoders
- }
-
- // Parse output - encoders are listed with format: " V..... h264_nvenc"
- outputStr := string(output)
- lines := strings.Split(outputStr, "\n")
- inEncodersSection := false
-
- // Common hardware encoder patterns
- hwPatterns := []string{
- "_nvenc", "_vaapi", "_qsv", "_videotoolbox", "_amf", "_v4l2m2m", "_omx", "_mediacodec",
- }
-
- for _, line := range lines {
- line = strings.TrimSpace(line)
- if strings.Contains(line, "Encoders:") || strings.Contains(line, "Codecs:") {
- inEncodersSection = true
- continue
- }
- if inEncodersSection {
- // Encoder lines typically look like: " V..... h264_nvenc H.264 / AVC / MPEG-4 AVC (NVIDIA NVENC)"
- // Split by whitespace and check if any part matches hardware patterns
- parts := strings.Fields(line)
- for _, part := range parts {
- for _, pattern := range hwPatterns {
- if strings.Contains(part, pattern) {
- encoders[part] = true
- break
- }
- }
- }
- }
- }
-
- return encoders
- }
-
- // checkEncoderAvailable checks if an ffmpeg encoder is available and actually usable
- func (c *Client) checkEncoderAvailable(encoder string) bool {
- // Check cache first
- c.hwAccelCacheMu.RLock()
- if cached, ok := c.hwAccelCache[encoder]; ok {
- c.hwAccelCacheMu.RUnlock()
- return cached
- }
- c.hwAccelCacheMu.RUnlock()
-
- // Initialize cache if needed
- c.hwAccelCacheMu.Lock()
- if c.hwAccelCache == nil {
- c.hwAccelCache = make(map[string]bool)
- }
- c.hwAccelCacheMu.Unlock()
-
- // First check if encoder is listed in encoders output
- cmd := exec.Command("ffmpeg", "-hide_banner", "-encoders")
- output, err := cmd.CombinedOutput()
- if err != nil {
- c.hwAccelCacheMu.Lock()
- c.hwAccelCache[encoder] = false
- c.hwAccelCacheMu.Unlock()
- return false
- }
-
- encoderOutput := string(output)
- // Check for exact encoder name (more reliable than just contains)
- encoderPattern := regexp.MustCompile(`\b` + regexp.QuoteMeta(encoder) + `\b`)
- if !encoderPattern.MatchString(encoderOutput) {
- // Also try case-insensitive and without exact word boundary
- if !strings.Contains(strings.ToLower(encoderOutput), strings.ToLower(encoder)) {
- c.hwAccelCacheMu.Lock()
- c.hwAccelCache[encoder] = false
- c.hwAccelCacheMu.Unlock()
- return false
- }
- }
-
- // Check hardware acceleration methods that might be needed
- hwaccelCmd := exec.Command("ffmpeg", "-hide_banner", "-hwaccels")
- hwaccelOutput, err := hwaccelCmd.CombinedOutput()
- hwaccelStr := ""
- if err == nil {
- hwaccelStr = string(hwaccelOutput)
- }
-
- // Encoder-specific detection and testing
- var available bool
- switch encoder {
- case "h264_nvenc", "hevc_nvenc":
- // NVENC - check for CUDA/NVENC support
- hasCuda := strings.Contains(hwaccelStr, "cuda") || strings.Contains(hwaccelStr, "cuvid")
- if hasCuda {
- available = c.testNVENCEncoder()
- } else {
- // Some builds have NVENC without CUDA hwaccel, still test
- available = c.testNVENCEncoder()
- }
- case "h264_vaapi", "hevc_vaapi":
- // VAAPI needs device setup
- // Check if encoder is listed first (more reliable than hwaccels check)
- hasVAAPI := strings.Contains(hwaccelStr, "vaapi")
- if hasVAAPI {
- available = c.testVAAPIEncoder()
- } else {
- // Even if hwaccels doesn't show vaapi, the encoder might still work
- // Try testing anyway (some builds have the encoder but not the hwaccel method)
- log.Printf("VAAPI not in hwaccels list, but encoder found - testing anyway")
- available = c.testVAAPIEncoder()
- }
- case "h264_qsv", "hevc_qsv":
- // QSV needs specific setup
- hasQSV := strings.Contains(hwaccelStr, "qsv")
- if hasQSV {
- available = c.testQSVEncoder()
- } else {
- available = false
- }
- case "h264_videotoolbox", "hevc_videotoolbox":
- // VideoToolbox on macOS
- hasVideoToolbox := strings.Contains(hwaccelStr, "videotoolbox")
- if hasVideoToolbox {
- available = c.testVideoToolboxEncoder()
- } else {
- available = false
- }
- case "h264_amf", "hevc_amf":
- // AMF on Windows
- hasAMF := strings.Contains(hwaccelStr, "d3d11va") || strings.Contains(hwaccelStr, "dxva2")
- if hasAMF {
- available = c.testAMFEncoder()
- } else {
- available = false
- }
- case "h264_v4l2m2m", "hevc_v4l2m2m":
- // V4L2 M2M (Video4Linux2 Memory-to-Memory) on Linux
- available = c.testV4L2M2MEncoder()
- case "h264_omx", "hevc_omx":
- // OpenMAX on Raspberry Pi
- available = c.testOMXEncoder()
- case "h264_mediacodec", "hevc_mediacodec":
- // MediaCodec on Android
- available = c.testMediaCodecEncoder()
- default:
- // Generic test for other encoders
- available = c.testGenericEncoder(encoder)
- }
-
- // Cache the result
- c.hwAccelCacheMu.Lock()
- c.hwAccelCache[encoder] = available
- c.hwAccelCacheMu.Unlock()
-
- return available
- }
-
- // testNVENCEncoder tests NVIDIA NVENC encoder
- func (c *Client) testNVENCEncoder() bool {
- // Test with a simple encode
- testCmd := exec.Command("ffmpeg",
- "-f", "lavfi",
- "-i", "color=c=black:s=64x64:d=0.1",
- "-c:v", "h264_nvenc",
- "-preset", "p1",
- "-frames:v", "1",
- "-f", "null",
- "-",
- )
- testCmd.Stdout = nil
- testCmd.Stderr = nil
- err := testCmd.Run()
- return err == nil
- }
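The test*Encoder helpers that follow all share one probe idea: encode a single lavfi-generated black frame to the null muxer, so success proves the encoder actually initializes rather than merely appearing in -encoders. The generic form, as a standalone sketch:

package main

import (
	"fmt"
	"os/exec"
)

// encoderWorks attempts a one-frame encode with the given ffmpeg encoder.
func encoderWorks(encoder string) bool {
	cmd := exec.Command("ffmpeg",
		"-f", "lavfi", // synthetic input, no file needed
		"-i", "color=c=black:s=64x64:d=0.1",
		"-c:v", encoder,
		"-frames:v", "1", // one frame is enough to exercise init
		"-f", "null", "-", // discard the output
	)
	return cmd.Run() == nil
}

func main() {
	fmt.Println("h264_nvenc usable:", encoderWorks("h264_nvenc"))
}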
-
- // testVAAPIEncoder tests VAAPI encoder and finds all available devices
- func (c *Client) testVAAPIEncoder() bool {
- // First, find all available VAAPI devices
- devices := c.findVAAPIDevices()
- if len(devices) == 0 {
- log.Printf("VAAPI test failed: No devices found")
- return false
- }
-
- // Test with each device until one works
- for _, device := range devices {
- log.Printf("Testing VAAPI device: %s", device)
-
- // Try multiple test approaches with proper parameters
- testCommands := [][]string{
- // Standard test with proper size and bitrate
- {"-vaapi_device", device, "-f", "lavfi", "-i", "color=c=black:s=1920x1080:d=0.1", "-vf", "format=nv12,hwupload", "-c:v", "h264_vaapi", "-b:v", "1M", "-frames:v", "1", "-f", "null", "-"},
- // Try with smaller but still reasonable size
- {"-vaapi_device", device, "-f", "lavfi", "-i", "color=c=black:s=640x480:d=0.1", "-vf", "format=nv12,hwupload", "-c:v", "h264_vaapi", "-b:v", "1M", "-frames:v", "1", "-f", "null", "-"},
- // Try with minimum reasonable size
- {"-vaapi_device", device, "-f", "lavfi", "-i", "color=c=black:s=64x64:d=0.1", "-vf", "format=nv12,hwupload", "-c:v", "h264_vaapi", "-b:v", "1M", "-frames:v", "1", "-f", "null", "-"},
- }
-
- for i, testArgs := range testCommands {
- testCmd := exec.Command("ffmpeg", testArgs...)
- var stderr bytes.Buffer
- testCmd.Stdout = nil
- testCmd.Stderr = &stderr
- err := testCmd.Run()
- if err == nil {
- log.Printf("VAAPI device %s works with test method %d", device, i+1)
- return true
- }
- // Log error for debugging but continue trying
- if i == 0 {
- log.Printf("VAAPI device %s test failed (method %d): %v, stderr: %s", device, i+1, err, stderr.String())
- }
- }
- }
-
- log.Printf("VAAPI test failed: All devices failed all test methods")
- return false
- }
-
- // findVAAPIDevices finds all available VAAPI render devices
- func (c *Client) findVAAPIDevices() []string {
- // Check cache first
- c.vaapiDevicesMu.RLock()
- if len(c.vaapiDevices) > 0 {
- // Verify devices still exist
- validDevices := make([]string, 0, len(c.vaapiDevices))
- for _, device := range c.vaapiDevices {
- if _, err := os.Stat(device); err == nil {
- validDevices = append(validDevices, device)
- }
- }
- if len(validDevices) > 0 {
- c.vaapiDevicesMu.RUnlock()
- // Update cache if some devices were removed
- if len(validDevices) != len(c.vaapiDevices) {
- c.vaapiDevicesMu.Lock()
- c.vaapiDevices = validDevices
- c.vaapiDevicesMu.Unlock()
- }
- return validDevices
- }
- }
- c.vaapiDevicesMu.RUnlock()
-
- log.Printf("Discovering VAAPI devices...")
-
- // Build list of potential device paths
- deviceCandidates := []string{}
-
- // First, check /dev/dri for render nodes (preferred)
- if entries, err := os.ReadDir("/dev/dri"); err == nil {
- log.Printf("Found %d entries in /dev/dri", len(entries))
- for _, entry := range entries {
- if strings.HasPrefix(entry.Name(), "renderD") {
- devPath := filepath.Join("/dev/dri", entry.Name())
- deviceCandidates = append(deviceCandidates, devPath)
- log.Printf("Found render node: %s", devPath)
- } else if strings.HasPrefix(entry.Name(), "card") {
- // Also try card devices as fallback
- devPath := filepath.Join("/dev/dri", entry.Name())
- deviceCandidates = append(deviceCandidates, devPath)
- log.Printf("Found card device: %s", devPath)
- }
- }
- } else {
- log.Printf("Failed to read /dev/dri: %v", err)
- }
-
- // Also try common device paths as fallback
- commonDevices := []string{
- "/dev/dri/renderD128",
- "/dev/dri/renderD129",
- "/dev/dri/renderD130",
- "/dev/dri/renderD131",
- "/dev/dri/renderD132",
- "/dev/dri/card0",
- "/dev/dri/card1",
- "/dev/dri/card2",
- }
- for _, dev := range commonDevices {
- // Only add if not already in candidates
- found := false
- for _, candidate := range deviceCandidates {
- if candidate == dev {
- found = true
- break
- }
- }
- if !found {
- deviceCandidates = append(deviceCandidates, dev)
- }
- }
-
- log.Printf("Testing %d device candidates for VAAPI", len(deviceCandidates))
-
- // Test each device and collect working ones
- workingDevices := []string{}
- for _, device := range deviceCandidates {
- if _, err := os.Stat(device); err != nil {
- log.Printf("Device %s does not exist, skipping", device)
- continue
- }
-
- log.Printf("Testing VAAPI device: %s", device)
-
- // Try multiple test methods with proper frame sizes and bitrate
- // VAAPI encoders require minimum frame sizes and bitrate parameters
- testMethods := [][]string{
- // Standard test with proper size and bitrate
- {"-vaapi_device", device, "-f", "lavfi", "-i", "color=c=black:s=1920x1080:d=0.1", "-vf", "format=nv12,hwupload", "-c:v", "h264_vaapi", "-b:v", "1M", "-frames:v", "1", "-f", "null", "-"},
- // Try with smaller but still reasonable size
- {"-vaapi_device", device, "-f", "lavfi", "-i", "color=c=black:s=640x480:d=0.1", "-vf", "format=nv12,hwupload", "-c:v", "h264_vaapi", "-b:v", "1M", "-frames:v", "1", "-f", "null", "-"},
- // Try with minimum reasonable size
- {"-vaapi_device", device, "-f", "lavfi", "-i", "color=c=black:s=64x64:d=0.1", "-vf", "format=nv12,hwupload", "-c:v", "h264_vaapi", "-b:v", "1M", "-frames:v", "1", "-f", "null", "-"},
- }
-
- deviceWorks := false
- for i, testArgs := range testMethods {
- testCmd := exec.Command("ffmpeg", testArgs...)
- var stderr bytes.Buffer
- testCmd.Stdout = nil
- testCmd.Stderr = &stderr
- err := testCmd.Run()
- if err == nil {
- log.Printf("VAAPI device %s works (method %d)", device, i+1)
- workingDevices = append(workingDevices, device)
- deviceWorks = true
- break
- }
- if i == 0 {
- // Log first failure for debugging
- log.Printf("VAAPI device %s test failed (method %d): %v", device, i+1, err)
- if stderr.Len() > 0 {
- log.Printf(" stderr: %s", strings.TrimSpace(stderr.String()))
- }
- }
- }
-
- if !deviceWorks {
- log.Printf("VAAPI device %s failed all test methods", device)
- }
- }
-
- log.Printf("Found %d working VAAPI device(s): %v", len(workingDevices), workingDevices)
-
- // Cache all working devices
- c.vaapiDevicesMu.Lock()
- c.vaapiDevices = workingDevices
- c.vaapiDevicesMu.Unlock()
-
- return workingDevices
- }
-
- // getVAAPIDevice returns the first available VAAPI device, or empty string if none
- func (c *Client) getVAAPIDevice() string {
- devices := c.findVAAPIDevices()
- if len(devices) > 0 {
- return devices[0]
- }
- return ""
- }
-
- // allocateVAAPIDevice allocates an available VAAPI device to a task
- // Returns the device path, or empty string if no device is available
- func (c *Client) allocateVAAPIDevice(taskID int64) string {
- c.allocatedDevicesMu.Lock()
- defer c.allocatedDevicesMu.Unlock()
-
- // Initialize map if needed
- if c.allocatedDevices == nil {
- c.allocatedDevices = make(map[int64]string)
- }
-
- // Get all available devices
- allDevices := c.findVAAPIDevices()
- if len(allDevices) == 0 {
- return ""
- }
-
- // Find which devices are currently allocated
- allocatedSet := make(map[string]bool)
- for _, allocatedDevice := range c.allocatedDevices {
- allocatedSet[allocatedDevice] = true
- }
-
- // Find the first available (not allocated) device
- for _, device := range allDevices {
- if !allocatedSet[device] {
- c.allocatedDevices[taskID] = device
- log.Printf("Allocated VAAPI device %s to task %d", device, taskID)
- return device
- }
- }
-
- // All devices are in use
- log.Printf("No available VAAPI devices for task %d (all %d devices in use)", taskID, len(allDevices))
- return ""
- }
-
- // releaseVAAPIDevice releases a VAAPI device allocated to a task
- func (c *Client) releaseVAAPIDevice(taskID int64) {
- c.allocatedDevicesMu.Lock()
- defer c.allocatedDevicesMu.Unlock()
-
- if c.allocatedDevices == nil {
- return
- }
-
- if device, ok := c.allocatedDevices[taskID]; ok {
- delete(c.allocatedDevices, taskID)
- log.Printf("Released VAAPI device %s from task %d", device, taskID)
- }
- }
"color=c=black:s=64x64:d=0.1", - "-c:v", "h264_videotoolbox", - "-frames:v", "1", - "-f", "null", - "-", - ) - testCmd.Stdout = nil - testCmd.Stderr = nil - err := testCmd.Run() - return err == nil -} - -// testAMFEncoder tests AMD AMF encoder -func (c *Client) testAMFEncoder() bool { - testCmd := exec.Command("ffmpeg", - "-f", "lavfi", - "-i", "color=c=black:s=64x64:d=0.1", - "-c:v", "h264_amf", - "-quality", "balanced", - "-frames:v", "1", - "-f", "null", - "-", - ) - testCmd.Stdout = nil - testCmd.Stderr = nil - err := testCmd.Run() - return err == nil -} - -// testV4L2M2MEncoder tests V4L2 M2M encoder (Video4Linux2 Memory-to-Memory) -func (c *Client) testV4L2M2MEncoder() bool { - testCmd := exec.Command("ffmpeg", - "-f", "lavfi", - "-i", "color=c=black:s=64x64:d=0.1", - "-c:v", "h264_v4l2m2m", - "-frames:v", "1", - "-f", "null", - "-", - ) - testCmd.Stdout = nil - testCmd.Stderr = nil - err := testCmd.Run() - return err == nil -} - -// testOMXEncoder tests OpenMAX encoder (Raspberry Pi) -func (c *Client) testOMXEncoder() bool { - testCmd := exec.Command("ffmpeg", - "-f", "lavfi", - "-i", "color=c=black:s=64x64:d=0.1", - "-c:v", "h264_omx", - "-frames:v", "1", - "-f", "null", - "-", - ) - testCmd.Stdout = nil - testCmd.Stderr = nil - err := testCmd.Run() - return err == nil -} - -// testMediaCodecEncoder tests MediaCodec encoder (Android) -func (c *Client) testMediaCodecEncoder() bool { - testCmd := exec.Command("ffmpeg", - "-f", "lavfi", - "-i", "color=c=black:s=64x64:d=0.1", - "-c:v", "h264_mediacodec", - "-frames:v", "1", - "-f", "null", - "-", - ) - testCmd.Stdout = nil - testCmd.Stderr = nil - err := testCmd.Run() - return err == nil -} - -// testGenericEncoder tests a generic encoder -func (c *Client) testGenericEncoder(encoder string) bool { - testCmd := exec.Command("ffmpeg", - "-f", "lavfi", - "-i", "color=c=black:s=64x64:d=0.1", - "-c:v", encoder, - "-frames:v", "1", - "-f", "null", - "-", - ) - testCmd.Stdout = nil - testCmd.Stderr = nil - err := testCmd.Run() - return err == nil -} - -// generateMP4WithConcat uses ffmpeg concat demuxer as fallback -// device parameter is optional - if provided, it will be used for VAAPI encoding -func (c *Client) generateMP4WithConcat(taskID int, frameFiles []string, outputMP4, workDir string, device string, outputFormat string, codec string, pixFmt string, useAlpha bool, useHardware bool, frameRate float64) error { - // Create file list for ffmpeg concat demuxer - listFile := filepath.Join(workDir, "frames.txt") - listFileHandle, err := os.Create(listFile) - if err != nil { - return fmt.Errorf("failed to create list file: %w", err) - } - - for _, frameFile := range frameFiles { - absPath, _ := filepath.Abs(frameFile) - fmt.Fprintf(listFileHandle, "file '%s'\n", absPath) - } - listFileHandle.Close() - - // Build video filter for HDR to SDR conversion - var vf string - if useAlpha { - // For AV1 with alpha: preserve alpha channel during tonemapping - vf = "zscale=t=linear:npl=100,format=gbrpf32le,zscale=p=bt709,tonemap=tonemap=hable:desat=0,zscale=t=bt709:m=bt709:r=tv,format=yuva420p" - } else { - // For H.264 without alpha: standard tonemapping - vf = "zscale=t=linear:npl=100,format=gbrpf32le,zscale=p=bt709,tonemap=tonemap=hable:desat=0,zscale=t=bt709:m=bt709:r=tv,format=yuv420p" - } - - // Run ffmpeg with concat demuxer - // EXR frames are 32-bit float HDR - FFmpeg will tonemap automatically - var cmd *exec.Cmd - - if useHardware { - if outputFormat == "EXR_AV1_MP4" { - // Try AV1 hardware acceleration - cmd, err = 
c.buildFFmpegCommandAV1(device, useAlpha, "-f", "concat", "-safe", "0", "-i", listFile, - "-r", "24", "-y", outputMP4) - if err != nil { - useHardware = false // Fall back to software - } - } else { - // Try H.264 hardware acceleration - if device != "" { - cmd, err = c.buildFFmpegCommand(device, "-f", "concat", "-safe", "0", "-i", listFile, - "-r", "24", "-y", outputMP4) - if err != nil { - useHardware = false // Fall back to software - } - } - } - } - - if !useHardware { - // Software encoding with HDR tonemapping - 2-pass for optimal quality - var codecArgs []string - if outputFormat == "EXR_AV1_MP4" { - codecArgs = []string{"-cpu-used", "1", "-crf", "15", "-b:v", "0", "-row-mt", "1", "-tiles", "4x4", "-lag-in-frames", "25", "-arnr-max-frames", "15", "-arnr-strength", "4"} - } else { - codecArgs = []string{"-preset", "veryslow", "-crf", "15", "-profile:v", "high", "-level", "5.2", "-tune", "film", "-keyint_min", "24", "-g", "240", "-bf", "2", "-refs", "4"} - } - - // PASS 1: Analysis pass - pass1Args := append([]string{"-f", "concat", "-safe", "0", "-i", listFile, "-vf", vf, "-c:v", codec, "-pix_fmt", pixFmt, "-r", fmt.Sprintf("%.2f", frameRate)}, codecArgs...) - pass1Args = append(pass1Args, "-pass", "1", "-f", "null", "/dev/null") - pass1Cmd := exec.Command("ffmpeg", pass1Args...) - pass1Cmd.Dir = workDir - _ = pass1Cmd.Run() // Ignore errors for pass 1 - - // PASS 2: Encoding pass - cmd = exec.Command("ffmpeg", "-f", "concat", "-safe", "0", "-i", listFile, "-vf", vf, "-c:v", codec, "-pix_fmt", pixFmt, "-r", fmt.Sprintf("%.2f", frameRate)) - cmd.Args = append(cmd.Args, codecArgs...) - cmd.Args = append(cmd.Args, "-pass", "2", "-y", outputMP4) - } - - // Create stdout and stderr pipes for streaming - stdoutPipe, err := cmd.StdoutPipe() - if err != nil { - return fmt.Errorf("failed to create ffmpeg stdout pipe: %w", err) - } - - stderrPipe, err := cmd.StderrPipe() - if err != nil { - return fmt.Errorf("failed to create ffmpeg stderr pipe: %w", err) - } - - cmd.Dir = workDir - - // Start the command - if err := cmd.Start(); err != nil { - return fmt.Errorf("failed to start ffmpeg: %w", err) - } - - // Stream stdout line by line (minimal logging for concat method) - stdoutDone := make(chan bool) - go func() { - defer close(stdoutDone) - scanner := bufio.NewScanner(stdoutPipe) - for scanner.Scan() { - line := scanner.Text() - if line != "" { - // Only log actual errors/warnings for concat method - if strings.Contains(line, "error") || - strings.Contains(line, "Error") || - strings.Contains(line, "failed") || - strings.Contains(line, "Failed") { - log.Printf("FFmpeg concat stdout: %s", line) - } - } - } - }() - - // Stream stderr line by line - stderrDone := make(chan bool) - go func() { - defer close(stderrDone) - scanner := bufio.NewScanner(stderrPipe) - for scanner.Scan() { - line := scanner.Text() - if line != "" { - // Log warnings and errors for concat method - if strings.Contains(line, "error") || - strings.Contains(line, "Error") || - strings.Contains(line, "failed") || - strings.Contains(line, "Failed") || - strings.Contains(line, "warning") || - strings.Contains(line, "Warning") { - log.Printf("FFmpeg concat stderr: %s", line) - } - } - } - }() - - // Wait for command to complete - err = cmd.Wait() - - // Wait for streaming goroutines to finish - <-stdoutDone - <-stderrDone - - if err != nil { - var errMsg string - if exitErr, ok := err.(*exec.ExitError); ok { - if exitErr.ExitCode() == 137 { - errMsg = "FFmpeg was killed due to excessive memory usage (OOM)" - } else { - errMsg = 
fmt.Sprintf("ffmpeg concat failed: %v", err) - } - } else { - errMsg = fmt.Sprintf("ffmpeg concat failed: %v", err) - } - // Check for size-related errors - if sizeErr := c.checkFFmpegSizeError(errMsg); sizeErr != nil { - return sizeErr - } - c.sendLog(int64(taskID), types.LogLevelError, errMsg, "generate_video") - c.sendStepUpdate(int64(taskID), "generate_video", types.StepStatusFailed, errMsg) - return errors.New(errMsg) - } - - if _, err := os.Stat(outputMP4); os.IsNotExist(err) { - return fmt.Errorf("MP4 file not created: %s", outputMP4) - } - - // Clean up 2-pass log files - _ = os.Remove(filepath.Join(workDir, "ffmpeg2pass-0.log")) - _ = os.Remove(filepath.Join(workDir, "ffmpeg2pass-0.log.mbtree")) - - return nil -} - -// checkFFmpegSizeError checks ffmpeg output for size-related errors and returns a helpful error message -func (c *Client) checkFFmpegSizeError(output string) error { - outputLower := strings.ToLower(output) - - // Check for hardware encoding size constraints - if strings.Contains(outputLower, "hardware does not support encoding at size") { - // Extract size constraints if available - constraintsMatch := regexp.MustCompile(`constraints:\s*width\s+(\d+)-(\d+)\s+height\s+(\d+)-(\d+)`).FindStringSubmatch(output) - if len(constraintsMatch) == 5 { - return fmt.Errorf("video frame size is outside hardware encoder limits. Hardware requires: width %s-%s, height %s-%s. Please adjust your render resolution to fit within these constraints", - constraintsMatch[1], constraintsMatch[2], constraintsMatch[3], constraintsMatch[4]) - } - return fmt.Errorf("video frame size is outside hardware encoder limits. Please adjust your render resolution") - } - - // Check for invalid picture size - if strings.Contains(outputLower, "picture size") && strings.Contains(outputLower, "is invalid") { - sizeMatch := regexp.MustCompile(`picture size\s+(\d+)x(\d+)`).FindStringSubmatch(output) - if len(sizeMatch) == 3 { - return fmt.Errorf("invalid video frame size: %sx%s. Frame dimensions are too large or invalid", sizeMatch[1], sizeMatch[2]) - } - return fmt.Errorf("invalid video frame size. Frame dimensions are too large or invalid") - } - - // Check for encoder parameter errors mentioning width/height - if strings.Contains(outputLower, "error while opening encoder") && - (strings.Contains(outputLower, "width") || strings.Contains(outputLower, "height") || strings.Contains(outputLower, "size")) { - // Try to extract the actual size if mentioned - sizeMatch := regexp.MustCompile(`at size\s+(\d+)x(\d+)`).FindStringSubmatch(output) - if len(sizeMatch) == 3 { - return fmt.Errorf("hardware encoder cannot encode frame size %sx%s. The frame dimensions may be too small, too large, or not supported by the hardware encoder", sizeMatch[1], sizeMatch[2]) - } - return fmt.Errorf("hardware encoder error: frame size may be invalid. Common issues: frame too small (minimum usually 128x128) or too large (maximum varies by hardware)") - } - - // Check for general size-related errors - if strings.Contains(outputLower, "invalid") && - (strings.Contains(outputLower, "width") || strings.Contains(outputLower, "height") || strings.Contains(outputLower, "dimension")) { - return fmt.Errorf("invalid frame dimensions detected. 
Please check your render resolution settings") - } - - return nil -} - -// extractFrameNumber extracts frame number from filename like "frame_0001.exr" or "frame_0001.png" -func extractFrameNumber(filename string) int { - parts := strings.Split(filepath.Base(filename), "_") - if len(parts) < 2 { - return 0 - } - framePart := strings.Split(parts[1], ".")[0] - var frameNum int - fmt.Sscanf(framePart, "%d", &frameNum) - return frameNum -} - -// getJobFiles gets job files from manager -func (c *Client) getJobFiles(jobID int64) ([]map[string]interface{}, error) { - path := fmt.Sprintf("/api/runner/jobs/%d/files", jobID) - resp, err := c.doSignedRequest("GET", path, nil, fmt.Sprintf("runner_id=%d", c.runnerID)) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("failed to get job files: %s", string(body)) - } - - var files []map[string]interface{} - if err := json.NewDecoder(resp.Body).Decode(&files); err != nil { - return nil, err - } - - return files, nil -} - -// getJobMetadata gets job metadata from manager -func (c *Client) getJobMetadata(jobID int64) (*types.BlendMetadata, error) { - path := fmt.Sprintf("/api/runner/jobs/%d/metadata", jobID) - resp, err := c.doSignedRequest("GET", path, nil, fmt.Sprintf("runner_id=%d", c.runnerID)) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - if resp.StatusCode == http.StatusNotFound { - return nil, nil // No metadata found, not an error - } - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("failed to get job metadata: %s", string(body)) - } - - var metadata types.BlendMetadata - if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil { - return nil, err - } - - return &metadata, nil -} - -// downloadFrameFile downloads a frame file for MP4 generation -func (c *Client) downloadFrameFile(jobID int64, fileName, destPath string) error { - // URL encode the fileName to handle special characters in filenames - encodedFileName := url.PathEscape(fileName) - path := fmt.Sprintf("/api/runner/files/%d/%s", jobID, encodedFileName) - // Use long-running client for file downloads (no timeout) - EXR files can be large - resp, err := c.doSignedRequestLong("GET", path, nil, fmt.Sprintf("runner_id=%d", c.runnerID)) - if err != nil { - return err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return fmt.Errorf("download failed: %s", string(body)) - } - - file, err := os.Create(destPath) - if err != nil { - return err - } - defer file.Close() - - _, err = io.Copy(file, resp.Body) - return err -} - -// downloadFile downloads a file from the manager to a directory (preserves filename only) -func (c *Client) downloadFile(filePath, destDir string) error { - fileName := filepath.Base(filePath) - destPath := filepath.Join(destDir, fileName) - return c.downloadFileToPath(filePath, destPath) -} - -// downloadFileToPath downloads a file from the manager to a specific path (preserves directory structure) -func (c *Client) downloadFileToPath(filePath, destPath string) error { - // Extract job ID and relative path from storage path - // Path format: storage/jobs/{jobID}/{relativePath} - parts := strings.Split(strings.TrimPrefix(filePath, "./"), "/") - if len(parts) < 3 { - return fmt.Errorf("invalid file path format: %s", filePath) - } - - // Find job ID in path (look for "jobs" directory) - jobID := "" - var relPathParts 
[]string - foundJobs := false - for i, part := range parts { - if part == "jobs" && i+1 < len(parts) { - jobID = parts[i+1] - foundJobs = true - if i+2 < len(parts) { - relPathParts = parts[i+2:] - } - break - } - } - - if !foundJobs || jobID == "" { - return fmt.Errorf("could not extract job ID from path: %s", filePath) - } - - // Build download path - preserve relative path structure - downloadPath := fmt.Sprintf("/api/runner/files/%s", jobID) - if len(relPathParts) > 0 { - // URL encode each path component - for _, part := range relPathParts { - downloadPath += "/" + part - } - } else { - // Fallback to filename only - downloadPath += "/" + filepath.Base(filePath) - } - - // Use long-running client for file downloads (no timeout) - resp, err := c.doSignedRequestLong("GET", downloadPath, nil, fmt.Sprintf("runner_id=%d", c.runnerID)) - if err != nil { - return fmt.Errorf("failed to download file: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return fmt.Errorf("download failed: %s", string(body)) - } - - // Ensure destination directory exists - destDir := filepath.Dir(destPath) - if err := os.MkdirAll(destDir, 0755); err != nil { - return fmt.Errorf("failed to create destination directory: %w", err) - } - - file, err := os.Create(destPath) - if err != nil { - return fmt.Errorf("failed to create destination file: %w", err) - } - defer file.Close() - - _, err = io.Copy(file, resp.Body) - return err -} - -// uploadFile uploads a file to the manager -func (c *Client) uploadFile(jobID int64, filePath string) (string, error) { - file, err := os.Open(filePath) - if err != nil { - return "", fmt.Errorf("failed to open file: %w", err) - } - defer file.Close() - - // Create multipart form - var buf bytes.Buffer - formWriter := multipart.NewWriter(&buf) - - part, err := formWriter.CreateFormFile("file", filepath.Base(filePath)) - if err != nil { - return "", fmt.Errorf("failed to create form file: %w", err) - } - - _, err = io.Copy(part, file) - if err != nil { - return "", fmt.Errorf("failed to copy file data: %w", err) - } - - formWriter.Close() - - // Upload file with shared secret - path := fmt.Sprintf("/api/runner/files/%d/upload?runner_id=%d", jobID, c.runnerID) - url := fmt.Sprintf("%s%s", c.managerURL, path) - req, err := http.NewRequest("POST", url, &buf) - if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", formWriter.FormDataContentType()) - req.Header.Set("Authorization", "Bearer "+c.apiKey) - - // Use long-running client for file uploads (no timeout) - resp, err := c.longRunningClient.Do(req) - if err != nil { - return "", fmt.Errorf("failed to upload file: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusCreated { - body, _ := io.ReadAll(resp.Body) - return "", fmt.Errorf("upload failed: %s", string(body)) - } - - var result struct { - FilePath string `json:"file_path"` - FileName string `json:"file_name"` - } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return "", fmt.Errorf("failed to decode response: %w", err) - } - - return result.FilePath, nil -} - -// getContextCacheKey generates a cache key for a job's context -func (c *Client) getContextCacheKey(jobID int64) string { - // Use job ID as the cache key (context is regenerated when job files change) - return fmt.Sprintf("job_%d", jobID) -} - -// getContextCachePath returns the path to a cached context file -func (c *Client) 
getContextCachePath(cacheKey string) string { - cacheDir := filepath.Join(c.getWorkspaceDir(), "cache", "contexts") - os.MkdirAll(cacheDir, 0755) - return filepath.Join(cacheDir, cacheKey+".tar") -} - -// isContextCacheValid checks if a cached context file exists and is not expired (1 hour TTL) -func (c *Client) isContextCacheValid(cachePath string) bool { - info, err := os.Stat(cachePath) - if err != nil { - return false - } - // Check if file is less than 1 hour old - return time.Since(info.ModTime()) < time.Hour -} - -// downloadJobContext downloads the job context tar, using cache if available -func (c *Client) downloadJobContext(jobID int64, destPath string) error { - cacheKey := c.getContextCacheKey(jobID) - cachePath := c.getContextCachePath(cacheKey) - - // Check cache first - if c.isContextCacheValid(cachePath) { - log.Printf("Using cached context for job %d", jobID) - // Copy from cache to destination - src, err := os.Open(cachePath) - if err != nil { - log.Printf("Failed to open cached context, will download: %v", err) - } else { - defer src.Close() - dst, err := os.Create(destPath) - if err != nil { - return fmt.Errorf("failed to create destination file: %w", err) - } - defer dst.Close() - _, err = io.Copy(dst, src) - if err == nil { - return nil - } - log.Printf("Failed to copy cached context, will download: %v", err) - } - } - - // Download from manager - use long-running client (no timeout) for large context files - path := fmt.Sprintf("/api/runner/jobs/%d/context.tar", jobID) - resp, err := c.doSignedRequestLong("GET", path, nil, fmt.Sprintf("runner_id=%d", c.runnerID)) - if err != nil { - return fmt.Errorf("failed to download context: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return fmt.Errorf("context download failed: %s", string(body)) - } - - // Create temporary file first - tmpPath := destPath + ".tmp" - tmpFile, err := os.Create(tmpPath) - if err != nil { - return fmt.Errorf("failed to create temporary file: %w", err) - } - defer tmpFile.Close() - defer os.Remove(tmpPath) - - // Stream download to temporary file - _, err = io.Copy(tmpFile, resp.Body) - if err != nil { - return fmt.Errorf("failed to download context: %w", err) - } - tmpFile.Close() - - // Move to final destination - if err := os.Rename(tmpPath, destPath); err != nil { - return fmt.Errorf("failed to move context to destination: %w", err) - } - - // Update cache - cacheDir := filepath.Dir(cachePath) - os.MkdirAll(cacheDir, 0755) - if err := os.Link(destPath, cachePath); err != nil { - // If link fails (e.g., cross-filesystem), copy instead - src, err := os.Open(destPath) - if err == nil { - defer src.Close() - dst, err := os.Create(cachePath) - if err == nil { - defer dst.Close() - io.Copy(dst, src) - } - } - } - - return nil -} - -// extractTar extracts a tar file to the destination directory -func (c *Client) extractTar(tarPath, destDir string) error { - // Open the tar file - file, err := os.Open(tarPath) - if err != nil { - return fmt.Errorf("failed to open tar file: %w", err) - } - defer file.Close() - - // Create tar reader - tarReader := tar.NewReader(file) - - // Extract files - for { - header, err := tarReader.Next() - if err == io.EOF { - break - } - if err != nil { - return fmt.Errorf("failed to read tar header: %w", err) - } - - // Sanitize path to prevent directory traversal - targetPath := filepath.Join(destDir, header.Name) - if !strings.HasPrefix(filepath.Clean(targetPath), 
filepath.Clean(destDir)+string(os.PathSeparator)) { - return fmt.Errorf("invalid file path in tar: %s", header.Name) - } - - // Handle directories - if header.Typeflag == tar.TypeDir { - if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil { - return fmt.Errorf("failed to create directory: %w", err) - } - continue - } - - // Handle regular files - if header.Typeflag == tar.TypeReg { - // Create parent directories - if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil { - return fmt.Errorf("failed to create parent directory: %w", err) - } - - // Create file - outFile, err := os.Create(targetPath) - if err != nil { - return fmt.Errorf("failed to create file: %w", err) - } - - // Copy file contents - if _, err := io.Copy(outFile, tarReader); err != nil { - outFile.Close() - return fmt.Errorf("failed to extract file: %w", err) - } - - // Set file permissions - if err := os.Chmod(targetPath, os.FileMode(header.Mode)); err != nil { - outFile.Close() - return fmt.Errorf("failed to set file permissions: %w", err) - } - - outFile.Close() - } - } - - return nil -} - -// cleanupExpiredContextCache removes context cache files older than 1 hour -func (c *Client) cleanupExpiredContextCache() { - cacheDir := filepath.Join(c.getWorkspaceDir(), "cache", "contexts") - entries, err := os.ReadDir(cacheDir) - if err != nil { - return - } - - now := time.Now() - for _, entry := range entries { - if entry.IsDir() { - continue - } - info, err := entry.Info() - if err != nil { - continue - } - if now.Sub(info.ModTime()) > time.Hour { - cachePath := filepath.Join(cacheDir, entry.Name()) - os.Remove(cachePath) - log.Printf("Removed expired context cache: %s", entry.Name()) - } - } -} - -// processMetadataTask processes a metadata extraction task -func (c *Client) processMetadataTask(task map[string]interface{}, jobID int64, inputFiles []interface{}) (err error) { - taskID := int64(task["id"].(float64)) - - // Create temporary job workspace for metadata extraction within runner workspace - workDir := filepath.Join(c.getWorkspaceDir(), fmt.Sprintf("job-%d-metadata-%d", jobID, taskID)) - if mkdirErr := os.MkdirAll(workDir, 0755); mkdirErr != nil { - return fmt.Errorf("failed to create work directory: %w", mkdirErr) - } - - // Guaranteed cleanup even on panic - defer func() { - if cleanupErr := os.RemoveAll(workDir); cleanupErr != nil { - log.Printf("Warning: Failed to cleanup work directory %s: %v", workDir, cleanupErr) - } - }() - - // Panic recovery for this task - defer func() { - if r := recover(); r != nil { - log.Printf("Metadata extraction task %d panicked: %v", taskID, r) - err = fmt.Errorf("metadata extraction task panicked: %v", r) - } - }() - - c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Starting metadata extraction task: job %d", jobID), "") - log.Printf("Processing metadata extraction task %d for job %d", taskID, jobID) - - // Step: download - c.sendStepUpdate(taskID, "download", types.StepStatusRunning, "") - c.sendLog(taskID, types.LogLevelInfo, "Downloading job context...", "download") - - // Download context tar - contextPath := filepath.Join(workDir, "context.tar") - if err := c.downloadJobContext(jobID, contextPath); err != nil { - c.sendStepUpdate(taskID, "download", types.StepStatusFailed, err.Error()) - return fmt.Errorf("failed to download context: %w", err) - } - - // Extract context tar - c.sendLog(taskID, types.LogLevelInfo, "Extracting context...", "download") - if err := c.extractTar(contextPath, workDir); err != nil { - c.sendStepUpdate(taskID, 
"download", types.StepStatusFailed, err.Error()) - return fmt.Errorf("failed to extract context: %w", err) - } - - // Find .blend file in extracted contents - blendFile := "" - err = filepath.Walk(workDir, func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".blend") { - // Check it's not a Blender save file (.blend1, .blend2, etc.) - lower := strings.ToLower(info.Name()) - idx := strings.LastIndex(lower, ".blend") - if idx != -1 { - suffix := lower[idx+len(".blend"):] - // If there are digits after .blend, it's a save file - isSaveFile := false - if len(suffix) > 0 { - isSaveFile = true - for _, r := range suffix { - if r < '0' || r > '9' { - isSaveFile = false - break - } - } - } - if !isSaveFile { - blendFile = path - return filepath.SkipAll // Stop walking once we find a blend file - } - } - } - return nil - }) - - if err != nil { - c.sendStepUpdate(taskID, "download", types.StepStatusFailed, err.Error()) - return fmt.Errorf("failed to find blend file: %w", err) - } - - if blendFile == "" { - err := fmt.Errorf("no .blend file found in context - the uploaded context archive must contain at least one .blend file to render") - c.sendStepUpdate(taskID, "download", types.StepStatusFailed, err.Error()) - return err - } - - c.sendStepUpdate(taskID, "download", types.StepStatusCompleted, "") - c.sendLog(taskID, types.LogLevelInfo, "Context downloaded and extracted successfully", "download") - - // Step: extract_metadata - c.sendStepUpdate(taskID, "extract_metadata", types.StepStatusRunning, "") - c.sendLog(taskID, types.LogLevelInfo, "Extracting metadata from blend file...", "extract_metadata") - - // Create Python script to extract metadata - scriptPath := filepath.Join(workDir, "extract_metadata.py") - scriptContent := `import bpy -import json -import sys - -# Make all file paths relative to the blend file location FIRST -# This must be done immediately after file load, before any other operations -# to prevent Blender from trying to access external files with absolute paths -try: - bpy.ops.file.make_paths_relative() - print("Made all file paths relative to blend file") -except Exception as e: - print(f"Warning: Could not make paths relative: {e}") - -# Check for missing addons that the blend file requires -# Blender marks missing addons with "_missing" suffix in preferences -missing_files_info = { - "checked": False, - "has_missing": False, - "missing_files": [], - "missing_addons": [] -} - -try: - missing = [] - for mod in bpy.context.preferences.addons: - if mod.module.endswith("_missing"): - missing.append(mod.module.rsplit("_", 1)[0]) - - missing_files_info["checked"] = True - if missing: - missing_files_info["has_missing"] = True - missing_files_info["missing_addons"] = missing - print("Missing add-ons required by this .blend:") - for name in missing: - print(" -", name) - else: - print("No missing add-ons detected – file is headless-safe") -except Exception as e: - print(f"Warning: Could not check for missing addons: {e}") - missing_files_info["error"] = str(e) - -# Get scene -scene = bpy.context.scene - -# Extract frame range from scene settings -frame_start = scene.frame_start -frame_end = scene.frame_end - -# Also check for actual animation range (keyframes) -# Find the earliest and latest keyframes across all objects -animation_start = None -animation_end = None - -for obj in scene.objects: - if obj.animation_data and obj.animation_data.action: - action = 
obj.animation_data.action - if action.fcurves: - for fcurve in action.fcurves: - if fcurve.keyframe_points: - for keyframe in fcurve.keyframe_points: - frame = int(keyframe.co[0]) - if animation_start is None or frame < animation_start: - animation_start = frame - if animation_end is None or frame > animation_end: - animation_end = frame - -# Use animation range if available, otherwise use scene frame range -# If scene range seems wrong (start == end), prefer animation range -if animation_start is not None and animation_end is not None: - if frame_start == frame_end or (animation_start < frame_start or animation_end > frame_end): - # Use animation range if scene range is invalid or animation extends beyond it - frame_start = animation_start - frame_end = animation_end - -# Extract render settings -render = scene.render -resolution_x = render.resolution_x -resolution_y = render.resolution_y -engine = scene.render.engine.upper() - -# Determine output format from file format -output_format = render.image_settings.file_format - -# Extract engine-specific settings -engine_settings = {} - -if engine == 'CYCLES': - cycles = scene.cycles - engine_settings = { - "samples": getattr(cycles, 'samples', 128), - "use_denoising": getattr(cycles, 'use_denoising', False), - "denoising_radius": getattr(cycles, 'denoising_radius', 0), - "denoising_strength": getattr(cycles, 'denoising_strength', 0.0), - "device": getattr(cycles, 'device', 'CPU'), - "use_adaptive_sampling": getattr(cycles, 'use_adaptive_sampling', False), - "adaptive_threshold": getattr(cycles, 'adaptive_threshold', 0.01) if getattr(cycles, 'use_adaptive_sampling', False) else 0.01, - "use_fast_gi": getattr(cycles, 'use_fast_gi', False), - "light_tree": getattr(cycles, 'use_light_tree', False), - "use_light_linking": getattr(cycles, 'use_light_linking', False), - "caustics_reflective": getattr(cycles, 'caustics_reflective', False), - "caustics_refractive": getattr(cycles, 'caustics_refractive', False), - "blur_glossy": getattr(cycles, 'blur_glossy', 0.0), - "max_bounces": getattr(cycles, 'max_bounces', 12), - "diffuse_bounces": getattr(cycles, 'diffuse_bounces', 4), - "glossy_bounces": getattr(cycles, 'glossy_bounces', 4), - "transmission_bounces": getattr(cycles, 'transmission_bounces', 12), - "volume_bounces": getattr(cycles, 'volume_bounces', 0), - "transparent_max_bounces": getattr(cycles, 'transparent_max_bounces', 8), - "film_transparent": getattr(cycles, 'film_transparent', False), - "use_layer_samples": getattr(cycles, 'use_layer_samples', False), - } -elif engine == 'EEVEE' or engine == 'EEVEE_NEXT': - eevee = scene.eevee - engine_settings = { - "taa_render_samples": getattr(eevee, 'taa_render_samples', 64), - "use_bloom": getattr(eevee, 'use_bloom', False), - "bloom_threshold": getattr(eevee, 'bloom_threshold', 0.8), - "bloom_intensity": getattr(eevee, 'bloom_intensity', 0.05), - "bloom_radius": getattr(eevee, 'bloom_radius', 6.5), - "use_ssr": getattr(eevee, 'use_ssr', True), - "use_ssr_refraction": getattr(eevee, 'use_ssr_refraction', False), - "ssr_quality": getattr(eevee, 'ssr_quality', 'MEDIUM'), - "use_ssao": getattr(eevee, 'use_ssao', True), - "ssao_quality": getattr(eevee, 'ssao_quality', 'MEDIUM'), - "ssao_distance": getattr(eevee, 'ssao_distance', 0.2), - "ssao_factor": getattr(eevee, 'ssao_factor', 1.0), - "use_soft_shadows": getattr(eevee, 'use_soft_shadows', True), - "use_shadow_high_bitdepth": getattr(eevee, 'use_shadow_high_bitdepth', True), - "use_volumetric": getattr(eevee, 'use_volumetric', False), - 
"volumetric_tile_size": getattr(eevee, 'volumetric_tile_size', '8'), - "volumetric_samples": getattr(eevee, 'volumetric_samples', 64), - "volumetric_start": getattr(eevee, 'volumetric_start', 0.0), - "volumetric_end": getattr(eevee, 'volumetric_end', 100.0), - "use_volumetric_lights": getattr(eevee, 'use_volumetric_lights', True), - "use_volumetric_shadows": getattr(eevee, 'use_volumetric_shadows', True), - "use_gtao": getattr(eevee, 'use_gtao', False), - "gtao_quality": getattr(eevee, 'gtao_quality', 'MEDIUM'), - "use_overscan": getattr(eevee, 'use_overscan', False), - } -else: - # For other engines, extract basic samples if available - engine_settings = { - "samples": getattr(scene, 'samples', 128) if hasattr(scene, 'samples') else 128 - } - -# Extract scene info -camera_count = len([obj for obj in scene.objects if obj.type == 'CAMERA']) -object_count = len(scene.objects) -material_count = len(bpy.data.materials) - -# Build metadata dictionary -metadata = { - "frame_start": frame_start, - "frame_end": frame_end, - "render_settings": { - "resolution_x": resolution_x, - "resolution_y": resolution_y, - "output_format": output_format, - "engine": engine.lower(), - "engine_settings": engine_settings - }, - "scene_info": { - "camera_count": camera_count, - "object_count": object_count, - "material_count": material_count - }, - "missing_files_info": missing_files_info -} - -# Output as JSON -print(json.dumps(metadata)) -sys.stdout.flush() -` - - if err := os.WriteFile(scriptPath, []byte(scriptContent), 0644); err != nil { - c.sendStepUpdate(taskID, "extract_metadata", types.StepStatusFailed, err.Error()) - return fmt.Errorf("failed to create extraction script: %w", err) - } - - // Execute Blender with Python script - // Note: disable_execution flag is not applied to metadata extraction for safety - cmd := exec.Command("blender", "-b", blendFile, "--python", scriptPath) - cmd.Dir = workDir - - // Capture stdout and stderr separately for line-by-line streaming - stdoutPipe, err := cmd.StdoutPipe() - if err != nil { - errMsg := fmt.Sprintf("failed to create stdout pipe: %v", err) - c.sendLog(taskID, types.LogLevelError, errMsg, "extract_metadata") - c.sendStepUpdate(taskID, "extract_metadata", types.StepStatusFailed, errMsg) - return errors.New(errMsg) - } - - stderrPipe, err := cmd.StderrPipe() - if err != nil { - errMsg := fmt.Sprintf("failed to create stderr pipe: %v", err) - c.sendLog(taskID, types.LogLevelError, errMsg, "extract_metadata") - c.sendStepUpdate(taskID, "extract_metadata", types.StepStatusFailed, errMsg) - return errors.New(errMsg) - } - - // Buffer to collect stdout for JSON parsing - var stdoutBuffer bytes.Buffer - - // Start the command - if err := cmd.Start(); err != nil { - errMsg := fmt.Sprintf("failed to start blender: %v", err) - c.sendLog(taskID, types.LogLevelError, errMsg, "extract_metadata") - c.sendStepUpdate(taskID, "extract_metadata", types.StepStatusFailed, errMsg) - return errors.New(errMsg) - } - - // Register process for cleanup on shutdown - c.processTracker.Track(taskID, cmd) - defer c.processTracker.Untrack(taskID) - - // Stream stdout line by line and collect for JSON parsing - stdoutDone := make(chan bool) - go func() { - defer close(stdoutDone) - scanner := bufio.NewScanner(stdoutPipe) - for scanner.Scan() { - line := scanner.Text() - stdoutBuffer.WriteString(line) - stdoutBuffer.WriteString("\n") - if line != "" { - shouldFilter, logLevel := shouldFilterBlenderLog(line) - if !shouldFilter { - c.sendLog(taskID, logLevel, line, "extract_metadata") - } - } 
- } - }() - - // Stream stderr line by line - stderrDone := make(chan bool) - go func() { - defer close(stderrDone) - scanner := bufio.NewScanner(stderrPipe) - for scanner.Scan() { - line := scanner.Text() - if line != "" { - shouldFilter, logLevel := shouldFilterBlenderLog(line) - if !shouldFilter { - // Use the filtered log level, but if it's still WARN, keep it as WARN - if logLevel == types.LogLevelInfo { - logLevel = types.LogLevelWarn - } - c.sendLog(taskID, logLevel, line, "extract_metadata") - } - } - } - }() - - // Wait for command to complete - err = cmd.Wait() - - // Wait for streaming goroutines to finish - <-stdoutDone - <-stderrDone - if err != nil { - var errMsg string - if exitErr, ok := err.(*exec.ExitError); ok { - if exitErr.ExitCode() == 137 { - errMsg = "Blender metadata extraction was killed due to excessive memory usage (OOM)" - } else { - errMsg = fmt.Sprintf("blender metadata extraction failed: %v", err) - } - } else { - errMsg = fmt.Sprintf("blender metadata extraction failed: %v", err) - } - c.sendLog(taskID, types.LogLevelError, errMsg, "extract_metadata") - c.sendStepUpdate(taskID, "extract_metadata", types.StepStatusFailed, errMsg) - return errors.New(errMsg) - } - - // Parse output (metadata is printed to stdout) - metadataJSON := strings.TrimSpace(stdoutBuffer.String()) - // Extract JSON from output (Blender may print other stuff) - jsonStart := strings.Index(metadataJSON, "{") - jsonEnd := strings.LastIndex(metadataJSON, "}") - if jsonStart == -1 || jsonEnd == -1 || jsonEnd <= jsonStart { - errMsg := "failed to extract JSON from Blender output" - c.sendLog(taskID, types.LogLevelError, errMsg, "extract_metadata") - c.sendStepUpdate(taskID, "extract_metadata", types.StepStatusFailed, errMsg) - return errors.New(errMsg) - } - metadataJSON = metadataJSON[jsonStart : jsonEnd+1] - - var metadata types.BlendMetadata - if err := json.Unmarshal([]byte(metadataJSON), &metadata); err != nil { - errMsg := fmt.Sprintf("Failed to parse metadata JSON: %v", err) - c.sendLog(taskID, types.LogLevelError, errMsg, "extract_metadata") - c.sendStepUpdate(taskID, "extract_metadata", types.StepStatusFailed, errMsg) - return errors.New(errMsg) - } - - c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Metadata extracted: frames %d-%d, resolution %dx%d", - metadata.FrameStart, metadata.FrameEnd, metadata.RenderSettings.ResolutionX, metadata.RenderSettings.ResolutionY), "extract_metadata") - c.sendStepUpdate(taskID, "extract_metadata", types.StepStatusCompleted, "") - - // Step: submit_metadata - c.sendStepUpdate(taskID, "submit_metadata", types.StepStatusRunning, "") - c.sendLog(taskID, types.LogLevelInfo, "Submitting metadata to manager...", "submit_metadata") - - // Submit metadata to manager - if err := c.submitMetadata(jobID, metadata); err != nil { - errMsg := fmt.Sprintf("Failed to submit metadata: %v", err) - c.sendLog(taskID, types.LogLevelError, errMsg, "submit_metadata") - c.sendStepUpdate(taskID, "submit_metadata", types.StepStatusFailed, errMsg) - return errors.New(errMsg) - } - - c.sendStepUpdate(taskID, "submit_metadata", types.StepStatusCompleted, "") - c.sendLog(taskID, types.LogLevelInfo, "Metadata extraction completed successfully", "") - - // Mark task as complete - c.sendTaskComplete(taskID, "", true, "") - return nil -} - -// submitMetadata submits extracted metadata to the manager -func (c *Client) submitMetadata(jobID int64, metadata types.BlendMetadata) error { - metadataJSON, err := json.Marshal(metadata) - if err != nil { - return fmt.Errorf("failed to 
marshal metadata: %w", err)
-	}
-
-	path := fmt.Sprintf("/api/runner/jobs/%d/metadata?runner_id=%d", jobID, c.runnerID)
-	url := fmt.Sprintf("%s%s", c.managerURL, path)
-	req, err := http.NewRequest("POST", url, bytes.NewReader(metadataJSON))
-	if err != nil {
-		return fmt.Errorf("failed to create request: %w", err)
-	}
-
-	req.Header.Set("Content-Type", "application/json")
-	req.Header.Set("Authorization", "Bearer "+c.apiKey)
-
-	resp, err := c.httpClient.Do(req)
-	if err != nil {
-		return fmt.Errorf("failed to submit metadata: %w", err)
-	}
-	defer resp.Body.Close()
-
-	if resp.StatusCode != http.StatusOK {
-		body, _ := io.ReadAll(resp.Body)
-		return fmt.Errorf("metadata submission failed: %s", string(body))
-	}
-
-	return nil
-}
-
-// completeTask marks a task as complete via WebSocket (or HTTP fallback)
-func (c *Client) completeTask(taskID int64, outputPath string, success bool, errorMsg string) error {
-	return c.sendTaskComplete(taskID, outputPath, success, errorMsg)
-}
-
-// sendTaskComplete sends task completion via WebSocket
-func (c *Client) sendTaskComplete(taskID int64, outputPath string, success bool, errorMsg string) error {
-	c.wsConnMu.RLock()
-	conn := c.wsConn
-	c.wsConnMu.RUnlock()
-
-	if conn != nil {
-		// Serialize all WebSocket writes to prevent concurrent write panics
-		c.wsWriteMu.Lock()
-		defer c.wsWriteMu.Unlock()
-
-		msg := map[string]interface{}{
-			"type": "task_complete",
-			"data": map[string]interface{}{
-				"task_id":     taskID,
-				"output_path": outputPath,
-				"success":     success,
-				"error":       errorMsg,
-			},
-			"timestamp": time.Now().Unix(),
-		}
-		if err := conn.WriteJSON(msg); err != nil {
-			return fmt.Errorf("failed to send task completion: %w", err)
-		}
-		return nil
-	}
-	return fmt.Errorf("WebSocket not connected, cannot complete task")
-}
diff --git a/internal/runner/encoding/encoder.go b/internal/runner/encoding/encoder.go
new file mode 100644
index 0000000..dfdae9e
--- /dev/null
+++ b/internal/runner/encoding/encoder.go
@@ -0,0 +1,71 @@
+// Package encoding handles video encoding with software encoders.
+package encoding
+
+import (
+	"os/exec"
+)
+
+// Encoder represents a video encoder.
+type Encoder interface {
+	Name() string
+	Codec() string
+	Available() bool
+	BuildCommand(config *EncodeConfig) *exec.Cmd
+}
+
+// EncodeConfig holds configuration for video encoding.
+type EncodeConfig struct {
+	InputPattern string  // Input file pattern (e.g., "frame_%04d.exr")
+	OutputPath   string  // Output file path
+	StartFrame   int     // Starting frame number
+	FrameRate    float64 // Frame rate
+	WorkDir      string  // Working directory
+	UseAlpha     bool    // Whether to preserve alpha channel
+	TwoPass      bool    // Whether to use 2-pass encoding
+	SourceFormat string  // Source format: "exr" or "png" (defaults to "exr")
+	PreserveHDR  bool    // Whether to preserve HDR range for EXR (uses HLG with bt709 primaries)
+}
+
+// Selector selects the software encoder.
+type Selector struct {
+	h264Encoders []Encoder
+	av1Encoders  []Encoder
+	vp9Encoders  []Encoder
+}
+
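The `Encoder`/`EncodeConfig` pair above replaces the per-vendor hardware probe functions deleted from the old client. For review purposes, here is a minimal sketch of how calling code might drive this API, including the two-pass dance (`BuildPass1Command` is defined further down in `encoders.go`; the paths are illustrative, and note that pass 1 needs the concrete `*SoftwareEncoder` type because it is not part of the `Encoder` interface):

```go
package main

import (
	"log"

	"jiggablend/internal/runner/encoding"
)

func main() {
	cfg := &encoding.EncodeConfig{
		InputPattern: "frame_%04d.exr", // rendered frames on disk
		OutputPath:   "preview.mp4",
		StartFrame:   1,
		FrameRate:    24,
		WorkDir:      "/tmp/job-42", // hypothetical job workspace
		TwoPass:      true,
		PreserveHDR:  false, // tonemap EXR to SDR, like the PNG path
	}

	enc := encoding.NewSelector().SelectH264()

	// Pass 1 writes rate-control stats into WorkDir and discards video output.
	if sw, ok := enc.(*encoding.SoftwareEncoder); ok && cfg.TwoPass {
		if err := sw.BuildPass1Command(cfg).Run(); err != nil {
			log.Printf("pass 1 failed: %v", err) // pass 2 may still succeed
		}
	}

	// Pass 2 (or the only pass) produces the actual file.
	if err := enc.BuildCommand(cfg).Run(); err != nil {
		log.Fatalf("encode failed: %v", err)
	}
}
```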
+// NewSelector creates a new encoder selector with software encoders.
+func NewSelector() *Selector {
+	s := &Selector{}
+	s.detectEncoders()
+	return s
+}
+
+func (s *Selector) detectEncoders() {
+	// Use software encoding only - reliable and avoids hardware-specific colorspace issues
+	s.h264Encoders = []Encoder{
+		&SoftwareEncoder{codec: "libx264"},
+	}
+
+	s.av1Encoders = []Encoder{
+		&SoftwareEncoder{codec: "libaom-av1"},
+	}
+
+	s.vp9Encoders = []Encoder{
+		&SoftwareEncoder{codec: "libvpx-vp9"},
+	}
+}
+
+// SelectH264 returns the software H.264 encoder.
+func (s *Selector) SelectH264() Encoder {
+	return &SoftwareEncoder{codec: "libx264"}
+}
+
+// SelectAV1 returns the software AV1 encoder.
+func (s *Selector) SelectAV1() Encoder {
+	return &SoftwareEncoder{codec: "libaom-av1"}
+}
+
+// SelectVP9 returns the software VP9 encoder.
+func (s *Selector) SelectVP9() Encoder {
+	return &SoftwareEncoder{codec: "libvpx-vp9"}
+}
diff --git a/internal/runner/encoding/encoders.go b/internal/runner/encoding/encoders.go
new file mode 100644
index 0000000..49106f7
--- /dev/null
+++ b/internal/runner/encoding/encoders.go
@@ -0,0 +1,270 @@
+package encoding
+
+import (
+	"fmt"
+	"log"
+	"os/exec"
+	"strconv"
+	"strings"
+)
+
+const (
+	// CRFH264 is the Constant Rate Factor for H.264 encoding (lower = higher quality, range 0-51)
+	CRFH264 = 15
+	// CRFAV1 is the Constant Rate Factor for AV1 encoding (lower = higher quality, range 0-63)
+	CRFAV1 = 30
+	// CRFVP9 is the Constant Rate Factor for VP9 encoding (lower = higher quality, range 0-63)
+	CRFVP9 = 30
+)
+
+// tonemapFilter returns the zscale filter chain for HDR-preserving EXR input:
+// linear RGB (gbrpf32le) -> sRGB -> HLG (arib-std-b67) with bt2020
+// primaries/matrix, packed as 10-bit YUV.
+// Converting through sRGB first keeps the color appearance consistent with
+// the PNG path and avoids the red tint produced by a direct linear-to-HLG
+// conversion, while the HLG transfer preserves the HDR range and still
+// displays acceptably on SDR screens.
+func tonemapFilter(useAlpha bool) string {
+	// zscale takes numeric H.273 values:
+	//   transfer:  8=linear, 13=sRGB (IEC 61966-2-1), 18=arib-std-b67 (HLG)
+	//   primaries: 1=bt709, 9=bt2020
+	//   matrix:    0=gbr (RGB input), 1=bt709, 9=bt2020nc
+	filter := "format=gbrpf32le,zscale=transferin=8:transfer=13:primariesin=1:primaries=1:matrixin=0:matrix=1:rangein=full:range=full,zscale=transferin=13:transfer=18:primariesin=1:primaries=9:matrixin=1:matrix=9:rangein=full:range=full"
+	if useAlpha {
+		return filter + ",format=yuva420p10le"
+	}
+	return filter + ",format=yuv420p10le"
+}
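Note that `tonemapFilter` is unexported and, within this diff, is not referenced by `BuildCommand` below, which assembles its own inline variant of the chain targeting bt709 rather than bt2020; only the unit tests exercise it directly. A throwaway in-package snippet like this (hypothetical, not part of the diff) prints both variants for eyeballing against the inline strings:

```go
package encoding

import (
	"fmt"
	"testing"
)

// TestPrintTonemapFilter dumps the generated zscale chains; run with
// `go test -run PrintTonemapFilter -v ./internal/runner/encoding`.
func TestPrintTonemapFilter(t *testing.T) {
	fmt.Println("opaque:", tonemapFilter(false)) // ...,format=yuv420p10le
	fmt.Println("alpha: ", tonemapFilter(true))  // ...,format=yuva420p10le
}
```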
+// SoftwareEncoder implements software encoding (libx264, libaom-av1, libvpx-vp9).
+type SoftwareEncoder struct {
+	codec string
+}
+
+func (e *SoftwareEncoder) Name() string  { return "software" }
+func (e *SoftwareEncoder) Codec() string { return e.codec }
+
+func (e *SoftwareEncoder) Available() bool {
+	return true // Software encoding is always available
+}
+
+func (e *SoftwareEncoder) BuildCommand(config *EncodeConfig) *exec.Cmd {
+	// Use HDR pixel formats for EXR, SDR for PNG
+	var pixFmt string
+	var colorPrimaries, colorTrc, colorspace string
+	if config.SourceFormat == "png" {
+		// PNG: SDR format
+		pixFmt = "yuv420p"
+		if config.UseAlpha {
+			pixFmt = "yuva420p"
+		}
+		colorPrimaries = "bt709"
+		colorTrc = "bt709"
+		colorspace = "bt709"
+	} else {
+		// EXR: Use HDR encoding if PreserveHDR is true, otherwise SDR (like PNG)
+		if config.PreserveHDR {
+			// HDR: Use HLG transfer with bt709 primaries to preserve HDR range while matching PNG color
+			pixFmt = "yuv420p10le" // 10-bit to preserve HDR range
+			if config.UseAlpha {
+				pixFmt = "yuva420p10le"
+			}
+			colorPrimaries = "bt709"  // bt709 primaries to match PNG color appearance
+			colorTrc = "arib-std-b67" // HLG transfer function - preserves HDR range, works on SDR displays
+			colorspace = "bt709"      // bt709 colorspace to match PNG
+		} else {
+			// SDR: Treat as SDR (like PNG) - encode as bt709
+			pixFmt = "yuv420p"
+			if config.UseAlpha {
+				pixFmt = "yuva420p"
+			}
+			colorPrimaries = "bt709"
+			colorTrc = "bt709"
+			colorspace = "bt709"
+		}
+	}
+
+	var codecArgs []string
+	switch e.codec {
+	case "libaom-av1":
+		codecArgs = []string{"-crf", strconv.Itoa(CRFAV1), "-b:v", "0", "-tiles", "2x2", "-g", "240"}
+	case "libvpx-vp9":
+		// VP9 supports alpha and HDR, use good quality settings
+		codecArgs = []string{"-crf", strconv.Itoa(CRFVP9), "-b:v", "0", "-row-mt", "1", "-g", "240"}
+	default:
+		// H.264: Use High 10 profile for HDR EXR (10-bit), High profile for SDR
+		if config.SourceFormat != "png" && config.PreserveHDR {
+			codecArgs = []string{"-preset", "veryslow", "-crf", strconv.Itoa(CRFH264), "-profile:v", "high10", "-level", "5.2", "-tune", "film", "-keyint_min", "24", "-g", "240", "-bf", "2", "-refs", "4"}
+		} else {
+			codecArgs = []string{"-preset", "veryslow", "-crf", strconv.Itoa(CRFH264), "-profile:v", "high", "-level", "5.2", "-tune", "film", "-keyint_min", "24", "-g", "240", "-bf", "2", "-refs", "4"}
+		}
+	}
+
+	args := []string{
+		"-y",
+		"-f", "image2",
+		"-start_number", fmt.Sprintf("%d", config.StartFrame),
+		"-framerate", fmt.Sprintf("%.2f", config.FrameRate),
+		"-i", config.InputPattern,
+		"-c:v", e.codec,
+		"-pix_fmt", pixFmt,
+		"-r", fmt.Sprintf("%.2f", config.FrameRate),
+		"-color_primaries", colorPrimaries,
+		"-color_trc", colorTrc,
+		"-colorspace", colorspace,
+		"-color_range", "tv",
+	}
+
+	// Add video filter for EXR: convert linear RGB based on HDR setting
+	// PNG doesn't need any filter as it's already in sRGB
+	if config.SourceFormat != "png" {
+		var vf string
+		if config.PreserveHDR {
+			// HDR: Convert linear RGB -> sRGB -> HLG with bt709 primaries
+			// This preserves HDR range while matching PNG color appearance
+			vf = "format=gbrpf32le,zscale=transferin=8:transfer=13:primariesin=1:primaries=1:matrixin=0:matrix=1:rangein=full:range=full,zscale=transferin=13:transfer=18:primariesin=1:primaries=1:matrixin=1:matrix=1:rangein=full:range=full"
+			if config.UseAlpha {
+				vf += ",format=yuva420p10le"
+			} else {
+				vf += ",format=yuv420p10le"
+			}
+		} else {
+			// SDR: Convert linear RGB (EXR) to sRGB (bt709) - simple conversion like Krita does
+			// zscale: linear (8) -> sRGB (13) with bt709 primaries/matrix
vf = "format=gbrpf32le,zscale=transferin=8:transfer=13:primariesin=1:primaries=1:matrixin=0:matrix=1:rangein=full:range=full" + if config.UseAlpha { + vf += ",format=yuva420p" + } else { + vf += ",format=yuv420p" + } + } + args = append(args, "-vf", vf) + } + args = append(args, codecArgs...) + + if config.TwoPass { + // For 2-pass, this builds pass 2 command + args = append(args, "-pass", "2") + } + + args = append(args, config.OutputPath) + + if config.TwoPass { + log.Printf("Build Software Pass 2 command: ffmpeg %s", strings.Join(args, " ")) + } else { + log.Printf("Build Software command: ffmpeg %s", strings.Join(args, " ")) + } + cmd := exec.Command("ffmpeg", args...) + cmd.Dir = config.WorkDir + return cmd +} + +// BuildPass1Command builds the first pass command for 2-pass encoding. +func (e *SoftwareEncoder) BuildPass1Command(config *EncodeConfig) *exec.Cmd { + // Use HDR pixel formats for EXR, SDR for PNG + var pixFmt string + var colorPrimaries, colorTrc, colorspace string + if config.SourceFormat == "png" { + // PNG: SDR format + pixFmt = "yuv420p" + if config.UseAlpha { + pixFmt = "yuva420p" + } + colorPrimaries = "bt709" + colorTrc = "bt709" + colorspace = "bt709" + } else { + // EXR: Use HDR encoding if PreserveHDR is true, otherwise SDR (like PNG) + if config.PreserveHDR { + // HDR: Use HLG transfer with bt709 primaries to preserve HDR range while matching PNG color + pixFmt = "yuv420p10le" // 10-bit to preserve HDR range + if config.UseAlpha { + pixFmt = "yuva420p10le" + } + colorPrimaries = "bt709" // bt709 primaries to match PNG color appearance + colorTrc = "arib-std-b67" // HLG transfer function - preserves HDR range, works on SDR displays + colorspace = "bt709" // bt709 colorspace to match PNG + } else { + // SDR: Treat as SDR (like PNG) - encode as bt709 + pixFmt = "yuv420p" + if config.UseAlpha { + pixFmt = "yuva420p" + } + colorPrimaries = "bt709" + colorTrc = "bt709" + colorspace = "bt709" + } + } + + var codecArgs []string + switch e.codec { + case "libaom-av1": + codecArgs = []string{"-crf", strconv.Itoa(CRFAV1), "-b:v", "0", "-tiles", "2x2", "-g", "240"} + case "libvpx-vp9": + // VP9 supports alpha and HDR, use good quality settings + codecArgs = []string{"-crf", strconv.Itoa(CRFVP9), "-b:v", "0", "-row-mt", "1", "-g", "240"} + default: + // H.264: Use High 10 profile for HDR EXR (10-bit), High profile for SDR + if config.SourceFormat != "png" && config.PreserveHDR { + codecArgs = []string{"-preset", "veryslow", "-crf", strconv.Itoa(CRFH264), "-profile:v", "high10", "-level", "5.2", "-tune", "film", "-keyint_min", "24", "-g", "240", "-bf", "2", "-refs", "4"} + } else { + codecArgs = []string{"-preset", "veryslow", "-crf", strconv.Itoa(CRFH264), "-profile:v", "high", "-level", "5.2", "-tune", "film", "-keyint_min", "24", "-g", "240", "-bf", "2", "-refs", "4"} + } + } + + args := []string{ + "-y", + "-f", "image2", + "-start_number", fmt.Sprintf("%d", config.StartFrame), + "-framerate", fmt.Sprintf("%.2f", config.FrameRate), + "-i", config.InputPattern, + "-c:v", e.codec, + "-pix_fmt", pixFmt, + "-r", fmt.Sprintf("%.2f", config.FrameRate), + "-color_primaries", colorPrimaries, + "-color_trc", colorTrc, + "-colorspace", colorspace, + "-color_range", "tv", + } + + // Add video filter for EXR: convert linear RGB based on HDR setting + // PNG doesn't need any filter as it's already in sRGB + if config.SourceFormat != "png" { + var vf string + if config.PreserveHDR { + // HDR: Convert linear RGB -> sRGB -> HLG with bt709 primaries + // This preserves HDR range while 
matching PNG color appearance + vf = "format=gbrpf32le,zscale=transferin=8:transfer=13:primariesin=1:primaries=1:matrixin=0:matrix=1:rangein=full:range=full,zscale=transferin=13:transfer=18:primariesin=1:primaries=1:matrixin=1:matrix=1:rangein=full:range=full" + if config.UseAlpha { + vf += ",format=yuva420p10le" + } else { + vf += ",format=yuv420p10le" + } + } else { + // SDR: Convert linear RGB (EXR) to sRGB (bt709) - simple conversion like Krita does + // zscale: linear (8) -> sRGB (13) with bt709 primaries/matrix + vf = "format=gbrpf32le,zscale=transferin=8:transfer=13:primariesin=1:primaries=1:matrixin=0:matrix=1:rangein=full:range=full" + if config.UseAlpha { + vf += ",format=yuva420p" + } else { + vf += ",format=yuv420p" + } + } + args = append(args, "-vf", vf) + } + + args = append(args, codecArgs...) + args = append(args, "-pass", "1", "-f", "null", "/dev/null") + + log.Printf("Build Software Pass 1 command: ffmpeg %s", strings.Join(args, " ")) + cmd := exec.Command("ffmpeg", args...) + cmd.Dir = config.WorkDir + return cmd +} diff --git a/internal/runner/encoding/encoders_test.go b/internal/runner/encoding/encoders_test.go new file mode 100644 index 0000000..9c0c917 --- /dev/null +++ b/internal/runner/encoding/encoders_test.go @@ -0,0 +1,980 @@ +package encoding + +import ( + "os" + "os/exec" + "path/filepath" + "strings" + "testing" +) + +func TestSoftwareEncoder_BuildCommand_H264_EXR(t *testing.T) { + encoder := &SoftwareEncoder{codec: "libx264"} + config := &EncodeConfig{ + InputPattern: "frame_%04d.exr", + OutputPath: "output.mp4", + StartFrame: 1, + FrameRate: 24.0, + WorkDir: "/tmp", + UseAlpha: false, + TwoPass: true, + SourceFormat: "exr", + } + + cmd := encoder.BuildCommand(config) + if cmd == nil { + t.Fatal("BuildCommand returned nil") + } + + if !strings.Contains(cmd.Path, "ffmpeg") { + t.Errorf("Expected command path to contain 'ffmpeg', got '%s'", cmd.Path) + } + + if cmd.Dir != "/tmp" { + t.Errorf("Expected work dir '/tmp', got '%s'", cmd.Dir) + } + + args := cmd.Args[1:] // Skip "ffmpeg" + argsStr := strings.Join(args, " ") + + // Check required arguments + checks := []struct { + name string + expected string + }{ + {"-y flag", "-y"}, + {"image2 format", "-f image2"}, + {"start number", "-start_number 1"}, + {"framerate", "-framerate 24.00"}, + {"input pattern", "-i frame_%04d.exr"}, + {"codec", "-c:v libx264"}, + {"pixel format", "-pix_fmt yuv420p"}, // EXR now treated as SDR (like PNG) + {"frame rate", "-r 24.00"}, + {"color primaries", "-color_primaries bt709"}, // EXR now uses bt709 (SDR) + {"color trc", "-color_trc bt709"}, // EXR now uses bt709 (SDR) + {"colorspace", "-colorspace bt709"}, + {"color range", "-color_range tv"}, + {"video filter", "-vf"}, + {"preset", "-preset veryslow"}, + {"crf", "-crf 15"}, + {"profile", "-profile:v high"}, // EXR now uses high profile (SDR) + {"pass 2", "-pass 2"}, + {"output path", "output.mp4"}, + } + + for _, check := range checks { + if !strings.Contains(argsStr, check.expected) { + t.Errorf("Missing expected argument: %s", check.expected) + } + } + + // Verify filter is present for EXR (linear RGB to sRGB conversion, like Krita does) + if !strings.Contains(argsStr, "format=gbrpf32le") { + t.Error("Expected format conversion filter for EXR source, but not found") + } + if !strings.Contains(argsStr, "zscale=transferin=8:transfer=13") { + t.Error("Expected linear to sRGB conversion for EXR source, but not found") + } +} + +func TestSoftwareEncoder_BuildCommand_H264_PNG(t *testing.T) { + encoder := &SoftwareEncoder{codec: 
"libx264"} + config := &EncodeConfig{ + InputPattern: "frame_%04d.png", + OutputPath: "output.mp4", + StartFrame: 1, + FrameRate: 24.0, + WorkDir: "/tmp", + UseAlpha: false, + TwoPass: true, + SourceFormat: "png", + } + + cmd := encoder.BuildCommand(config) + args := cmd.Args[1:] + argsStr := strings.Join(args, " ") + + // PNG should NOT have video filter + if strings.Contains(argsStr, "-vf") { + t.Error("PNG source should not have video filter, but -vf was found") + } + + // Should still have all other required args + if !strings.Contains(argsStr, "-c:v libx264") { + t.Error("Missing codec argument") + } +} + +func TestSoftwareEncoder_BuildCommand_AV1_WithAlpha(t *testing.T) { + encoder := &SoftwareEncoder{codec: "libaom-av1"} + config := &EncodeConfig{ + InputPattern: "frame_%04d.exr", + OutputPath: "output.mp4", + StartFrame: 100, + FrameRate: 30.0, + WorkDir: "/tmp", + UseAlpha: true, + TwoPass: true, + SourceFormat: "exr", + } + + cmd := encoder.BuildCommand(config) + args := cmd.Args[1:] + argsStr := strings.Join(args, " ") + + // Check alpha-specific settings + if !strings.Contains(argsStr, "-pix_fmt yuva420p") { + t.Error("Expected yuva420p pixel format for alpha, but not found") + } + + // Check AV1-specific arguments + av1Checks := []string{ + "-c:v libaom-av1", + "-crf 30", + "-b:v 0", + "-tiles 2x2", + "-g 240", + } + + for _, check := range av1Checks { + if !strings.Contains(argsStr, check) { + t.Errorf("Missing AV1 argument: %s", check) + } + } + + // Check tonemap filter includes alpha format + if !strings.Contains(argsStr, "format=yuva420p") { + t.Error("Expected tonemap filter to output yuva420p for alpha, but not found") + } +} + +func TestSoftwareEncoder_BuildCommand_VP9(t *testing.T) { + encoder := &SoftwareEncoder{codec: "libvpx-vp9"} + config := &EncodeConfig{ + InputPattern: "frame_%04d.exr", + OutputPath: "output.webm", + StartFrame: 1, + FrameRate: 24.0, + WorkDir: "/tmp", + UseAlpha: true, + TwoPass: true, + SourceFormat: "exr", + } + + cmd := encoder.BuildCommand(config) + args := cmd.Args[1:] + argsStr := strings.Join(args, " ") + + // Check VP9-specific arguments + vp9Checks := []string{ + "-c:v libvpx-vp9", + "-crf 30", + "-b:v 0", + "-row-mt 1", + "-g 240", + } + + for _, check := range vp9Checks { + if !strings.Contains(argsStr, check) { + t.Errorf("Missing VP9 argument: %s", check) + } + } +} + +func TestSoftwareEncoder_BuildPass1Command(t *testing.T) { + encoder := &SoftwareEncoder{codec: "libx264"} + config := &EncodeConfig{ + InputPattern: "frame_%04d.exr", + OutputPath: "output.mp4", + StartFrame: 1, + FrameRate: 24.0, + WorkDir: "/tmp", + UseAlpha: false, + TwoPass: true, + SourceFormat: "exr", + } + + cmd := encoder.BuildPass1Command(config) + args := cmd.Args[1:] + argsStr := strings.Join(args, " ") + + // Pass 1 should have -pass 1 and output to null + if !strings.Contains(argsStr, "-pass 1") { + t.Error("Pass 1 command should include '-pass 1'") + } + + if !strings.Contains(argsStr, "-f null") { + t.Error("Pass 1 command should include '-f null'") + } + + if !strings.Contains(argsStr, "/dev/null") { + t.Error("Pass 1 command should output to /dev/null") + } + + // Should NOT have output path + if strings.Contains(argsStr, "output.mp4") { + t.Error("Pass 1 command should not include output path") + } +} + +func TestSoftwareEncoder_BuildPass1Command_AV1(t *testing.T) { + encoder := &SoftwareEncoder{codec: "libaom-av1"} + config := &EncodeConfig{ + InputPattern: "frame_%04d.exr", + OutputPath: "output.mp4", + StartFrame: 1, + FrameRate: 24.0, + 
WorkDir: "/tmp", + UseAlpha: false, + TwoPass: true, + SourceFormat: "exr", + } + + cmd := encoder.BuildPass1Command(config) + args := cmd.Args[1:] + argsStr := strings.Join(args, " ") + + // Pass 1 should have -pass 1 and output to null + if !strings.Contains(argsStr, "-pass 1") { + t.Error("Pass 1 command should include '-pass 1'") + } + + if !strings.Contains(argsStr, "-f null") { + t.Error("Pass 1 command should include '-f null'") + } + + if !strings.Contains(argsStr, "/dev/null") { + t.Error("Pass 1 command should output to /dev/null") + } + + // Check AV1-specific arguments in pass 1 + av1Checks := []string{ + "-c:v libaom-av1", + "-crf 30", + "-b:v 0", + "-tiles 2x2", + "-g 240", + } + + for _, check := range av1Checks { + if !strings.Contains(argsStr, check) { + t.Errorf("Missing AV1 argument in pass 1: %s", check) + } + } +} + +func TestSoftwareEncoder_BuildPass1Command_VP9(t *testing.T) { + encoder := &SoftwareEncoder{codec: "libvpx-vp9"} + config := &EncodeConfig{ + InputPattern: "frame_%04d.exr", + OutputPath: "output.webm", + StartFrame: 1, + FrameRate: 24.0, + WorkDir: "/tmp", + UseAlpha: false, + TwoPass: true, + SourceFormat: "exr", + } + + cmd := encoder.BuildPass1Command(config) + args := cmd.Args[1:] + argsStr := strings.Join(args, " ") + + // Pass 1 should have -pass 1 and output to null + if !strings.Contains(argsStr, "-pass 1") { + t.Error("Pass 1 command should include '-pass 1'") + } + + if !strings.Contains(argsStr, "-f null") { + t.Error("Pass 1 command should include '-f null'") + } + + if !strings.Contains(argsStr, "/dev/null") { + t.Error("Pass 1 command should output to /dev/null") + } + + // Check VP9-specific arguments in pass 1 + vp9Checks := []string{ + "-c:v libvpx-vp9", + "-crf 30", + "-b:v 0", + "-row-mt 1", + "-g 240", + } + + for _, check := range vp9Checks { + if !strings.Contains(argsStr, check) { + t.Errorf("Missing VP9 argument in pass 1: %s", check) + } + } +} + +func TestSoftwareEncoder_BuildCommand_NoTwoPass(t *testing.T) { + encoder := &SoftwareEncoder{codec: "libx264"} + config := &EncodeConfig{ + InputPattern: "frame_%04d.exr", + OutputPath: "output.mp4", + StartFrame: 1, + FrameRate: 24.0, + WorkDir: "/tmp", + UseAlpha: false, + TwoPass: false, + SourceFormat: "exr", + } + + cmd := encoder.BuildCommand(config) + args := cmd.Args[1:] + argsStr := strings.Join(args, " ") + + // Should NOT have -pass flag when TwoPass is false + if strings.Contains(argsStr, "-pass") { + t.Error("Command should not include -pass flag when TwoPass is false") + } +} + +func TestSelector_SelectH264(t *testing.T) { + selector := NewSelector() + encoder := selector.SelectH264() + + if encoder == nil { + t.Fatal("SelectH264 returned nil") + } + + if encoder.Codec() != "libx264" { + t.Errorf("Expected codec 'libx264', got '%s'", encoder.Codec()) + } + + if encoder.Name() != "software" { + t.Errorf("Expected name 'software', got '%s'", encoder.Name()) + } +} + +func TestSelector_SelectAV1(t *testing.T) { + selector := NewSelector() + encoder := selector.SelectAV1() + + if encoder == nil { + t.Fatal("SelectAV1 returned nil") + } + + if encoder.Codec() != "libaom-av1" { + t.Errorf("Expected codec 'libaom-av1', got '%s'", encoder.Codec()) + } +} + +func TestSelector_SelectVP9(t *testing.T) { + selector := NewSelector() + encoder := selector.SelectVP9() + + if encoder == nil { + t.Fatal("SelectVP9 returned nil") + } + + if encoder.Codec() != "libvpx-vp9" { + t.Errorf("Expected codec 'libvpx-vp9', got '%s'", encoder.Codec()) + } +} + +func TestTonemapFilter_WithAlpha(t 
*testing.T) { + filter := tonemapFilter(true) + + // Filter should convert from gbrpf32le to yuva420p10le with proper colorspace conversion + if !strings.Contains(filter, "yuva420p10le") { + t.Error("Tonemap filter with alpha should output yuva420p10le format for HDR") + } + + if !strings.Contains(filter, "gbrpf32le") { + t.Error("Tonemap filter should start with gbrpf32le format") + } + + // Should use zscale for colorspace conversion from linear RGB to bt2020 YUV + if !strings.Contains(filter, "zscale") { + t.Error("Tonemap filter should use zscale for colorspace conversion") + } + + // Check for HLG transfer function (numeric value 18 or string arib-std-b67) + if !strings.Contains(filter, "transfer=18") && !strings.Contains(filter, "transfer=arib-std-b67") { + t.Error("Tonemap filter should use HLG transfer function (18 or arib-std-b67)") + } +} + +func TestTonemapFilter_WithoutAlpha(t *testing.T) { + filter := tonemapFilter(false) + + // Filter should convert from gbrpf32le to yuv420p10le with proper colorspace conversion + if !strings.Contains(filter, "yuv420p10le") { + t.Error("Tonemap filter without alpha should output yuv420p10le format for HDR") + } + + if strings.Contains(filter, "yuva420p") { + t.Error("Tonemap filter without alpha should not output yuva420p format") + } + + if !strings.Contains(filter, "gbrpf32le") { + t.Error("Tonemap filter should start with gbrpf32le format") + } + + // Should use zscale for colorspace conversion from linear RGB to bt2020 YUV + if !strings.Contains(filter, "zscale") { + t.Error("Tonemap filter should use zscale for colorspace conversion") + } + + // Check for HLG transfer function (numeric value 18 or string arib-std-b67) + if !strings.Contains(filter, "transfer=18") && !strings.Contains(filter, "transfer=arib-std-b67") { + t.Error("Tonemap filter should use HLG transfer function (18 or arib-std-b67)") + } +} + +func TestSoftwareEncoder_Available(t *testing.T) { + encoder := &SoftwareEncoder{codec: "libx264"} + if !encoder.Available() { + t.Error("Software encoder should always be available") + } +} + +func TestEncodeConfig_DefaultSourceFormat(t *testing.T) { + config := &EncodeConfig{ + InputPattern: "frame_%04d.exr", + OutputPath: "output.mp4", + StartFrame: 1, + FrameRate: 24.0, + WorkDir: "/tmp", + UseAlpha: false, + TwoPass: false, + // SourceFormat not set, should default to empty string (treated as exr) + } + + encoder := &SoftwareEncoder{codec: "libx264"} + cmd := encoder.BuildCommand(config) + args := strings.Join(cmd.Args[1:], " ") + + // Should still have tonemap filter when SourceFormat is empty (defaults to exr behavior) + if !strings.Contains(args, "-vf") { + t.Error("Empty SourceFormat should default to EXR behavior with tonemap filter") + } +} + +func TestCommandOrder(t *testing.T) { + encoder := &SoftwareEncoder{codec: "libx264"} + config := &EncodeConfig{ + InputPattern: "frame_%04d.exr", + OutputPath: "output.mp4", + StartFrame: 1, + FrameRate: 24.0, + WorkDir: "/tmp", + UseAlpha: false, + TwoPass: true, + SourceFormat: "exr", + } + + cmd := encoder.BuildCommand(config) + args := cmd.Args[1:] + + // Verify argument order: input should come before codec + inputIdx := -1 + codecIdx := -1 + vfIdx := -1 + + for i, arg := range args { + if arg == "-i" && i+1 < len(args) && args[i+1] == "frame_%04d.exr" { + inputIdx = i + } + if arg == "-c:v" && i+1 < len(args) && args[i+1] == "libx264" { + codecIdx = i + } + if arg == "-vf" { + vfIdx = i + } + } + + if inputIdx == -1 { + t.Fatal("Input pattern not found in command") + } + if 
codecIdx == -1 { + t.Fatal("Codec not found in command") + } + if vfIdx == -1 { + t.Fatal("Video filter not found in command") + } + + // Input should come before codec + if inputIdx >= codecIdx { + t.Error("Input pattern should come before codec in command") + } + + // Video filter should come after input (order: input -> codec -> colorspace -> filter -> codec args) + // In practice, the filter comes after codec and colorspace metadata but before codec-specific args + if vfIdx <= inputIdx { + t.Error("Video filter should come after input") + } +} + +func TestCommand_ColorspaceMetadata(t *testing.T) { + encoder := &SoftwareEncoder{codec: "libx264"} + config := &EncodeConfig{ + InputPattern: "frame_%04d.exr", + OutputPath: "output.mp4", + StartFrame: 1, + FrameRate: 24.0, + WorkDir: "/tmp", + UseAlpha: false, + TwoPass: false, + SourceFormat: "exr", + PreserveHDR: false, // SDR encoding + } + + cmd := encoder.BuildCommand(config) + args := cmd.Args[1:] + argsStr := strings.Join(args, " ") + + // Verify all SDR colorspace metadata is present for EXR (SDR encoding) + colorspaceArgs := []string{ + "-color_primaries bt709", // EXR uses bt709 (SDR) + "-color_trc bt709", // EXR uses bt709 (SDR) + "-colorspace bt709", + "-color_range tv", + } + + for _, arg := range colorspaceArgs { + if !strings.Contains(argsStr, arg) { + t.Errorf("Missing colorspace metadata: %s", arg) + } + } + + // Verify SDR pixel format + if !strings.Contains(argsStr, "-pix_fmt yuv420p") { + t.Error("SDR encoding should use yuv420p pixel format") + } + + // Verify H.264 high profile (not high10) + if !strings.Contains(argsStr, "-profile:v high") { + t.Error("SDR encoding should use high profile") + } + if strings.Contains(argsStr, "-profile:v high10") { + t.Error("SDR encoding should not use high10 profile") + } +} + +func TestCommand_HDR_ColorspaceMetadata(t *testing.T) { + encoder := &SoftwareEncoder{codec: "libx264"} + config := &EncodeConfig{ + InputPattern: "frame_%04d.exr", + OutputPath: "output.mp4", + StartFrame: 1, + FrameRate: 24.0, + WorkDir: "/tmp", + UseAlpha: false, + TwoPass: false, + SourceFormat: "exr", + PreserveHDR: true, // HDR encoding + } + + cmd := encoder.BuildCommand(config) + args := cmd.Args[1:] + argsStr := strings.Join(args, " ") + + // Verify all HDR colorspace metadata is present for EXR (HDR encoding) + colorspaceArgs := []string{ + "-color_primaries bt709", // bt709 primaries to match PNG color appearance + "-color_trc arib-std-b67", // HLG transfer function for HDR/SDR compatibility + "-colorspace bt709", // bt709 colorspace to match PNG + "-color_range tv", + } + + for _, arg := range colorspaceArgs { + if !strings.Contains(argsStr, arg) { + t.Errorf("Missing HDR colorspace metadata: %s", arg) + } + } + + // Verify HDR pixel format (10-bit) + if !strings.Contains(argsStr, "-pix_fmt yuv420p10le") { + t.Error("HDR encoding should use yuv420p10le pixel format") + } + + // Verify H.264 high10 profile (for 10-bit) + if !strings.Contains(argsStr, "-profile:v high10") { + t.Error("HDR encoding should use high10 profile") + } + + // Verify HDR filter chain (linear -> sRGB -> HLG) + if !strings.Contains(argsStr, "-vf") { + t.Fatal("HDR encoding should have video filter") + } + vfIdx := -1 + for i, arg := range args { + if arg == "-vf" && i+1 < len(args) { + vfIdx = i + 1 + break + } + } + if vfIdx == -1 { + t.Fatal("Video filter not found") + } + filter := args[vfIdx] + if !strings.Contains(filter, "transfer=18") { + t.Error("HDR filter should convert to HLG (transfer=18)") + } + if 
!strings.Contains(filter, "yuv420p10le") { + t.Error("HDR filter should output yuv420p10le format") + } +} + +// Integration tests using example files +func TestIntegration_Encode_EXR_H264(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + // Check if example file exists + exampleDir := filepath.Join("..", "..", "..", "examples") + exrFile := filepath.Join(exampleDir, "frame_0800.exr") + if _, err := os.Stat(exrFile); os.IsNotExist(err) { + t.Skipf("Example file not found: %s", exrFile) + } + + // Get absolute paths + workspaceRoot, err := filepath.Abs(filepath.Join("..", "..", "..")) + if err != nil { + t.Fatalf("Failed to get workspace root: %v", err) + } + exampleDirAbs, err := filepath.Abs(exampleDir) + if err != nil { + t.Fatalf("Failed to get example directory: %v", err) + } + tmpDir := filepath.Join(workspaceRoot, "tmp") + if err := os.MkdirAll(tmpDir, 0755); err != nil { + t.Fatalf("Failed to create tmp directory: %v", err) + } + + encoder := &SoftwareEncoder{codec: "libx264"} + config := &EncodeConfig{ + InputPattern: filepath.Join(exampleDirAbs, "frame_%04d.exr"), + OutputPath: filepath.Join(tmpDir, "test_exr_h264.mp4"), + StartFrame: 800, + FrameRate: 24.0, + WorkDir: tmpDir, + UseAlpha: false, + TwoPass: false, // Use single pass for faster testing + SourceFormat: "exr", + } + + // Build and run command + cmd := encoder.BuildCommand(config) + if cmd == nil { + t.Fatal("BuildCommand returned nil") + } + + // Capture stderr to see what went wrong + output, err := cmd.CombinedOutput() + if err != nil { + t.Errorf("FFmpeg command failed: %v\nCommand output: %s", err, string(output)) + return + } + + // Verify output file was created + if _, err := os.Stat(config.OutputPath); os.IsNotExist(err) { + t.Errorf("Output file was not created: %s\nCommand output: %s", config.OutputPath, string(output)) + } else { + t.Logf("Successfully created output file: %s", config.OutputPath) + // Verify file has content + info, _ := os.Stat(config.OutputPath) + if info.Size() == 0 { + t.Errorf("Output file was created but is empty\nCommand output: %s", string(output)) + } else { + t.Logf("Output file size: %d bytes", info.Size()) + } + } +} + +func TestIntegration_Encode_PNG_H264(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + // Check if example file exists + exampleDir := filepath.Join("..", "..", "..", "examples") + pngFile := filepath.Join(exampleDir, "frame_0800.png") + if _, err := os.Stat(pngFile); os.IsNotExist(err) { + t.Skipf("Example file not found: %s", pngFile) + } + + // Get absolute paths + workspaceRoot, err := filepath.Abs(filepath.Join("..", "..", "..")) + if err != nil { + t.Fatalf("Failed to get workspace root: %v", err) + } + exampleDirAbs, err := filepath.Abs(exampleDir) + if err != nil { + t.Fatalf("Failed to get example directory: %v", err) + } + tmpDir := filepath.Join(workspaceRoot, "tmp") + if err := os.MkdirAll(tmpDir, 0755); err != nil { + t.Fatalf("Failed to create tmp directory: %v", err) + } + + encoder := &SoftwareEncoder{codec: "libx264"} + config := &EncodeConfig{ + InputPattern: filepath.Join(exampleDirAbs, "frame_%04d.png"), + OutputPath: filepath.Join(tmpDir, "test_png_h264.mp4"), + StartFrame: 800, + FrameRate: 24.0, + WorkDir: tmpDir, + UseAlpha: false, + TwoPass: false, // Use single pass for faster testing + SourceFormat: "png", + } + + // Build and run command + cmd := encoder.BuildCommand(config) + if cmd == nil { + t.Fatal("BuildCommand returned nil") + } + + 
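+	// Unlike linear-light EXR input, PNG frames are already gamma-encoded
+	// sRGB, so no -vf conversion chain should be generated for them.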
// Verify no video filter is used for PNG + argsStr := strings.Join(cmd.Args, " ") + if strings.Contains(argsStr, "-vf") { + t.Error("PNG encoding should not use video filter, but -vf was found in command") + } + + // Run the command + cmdOutput, err := cmd.CombinedOutput() + if err != nil { + t.Errorf("FFmpeg command failed: %v\nCommand output: %s", err, string(cmdOutput)) + return + } + + // Verify output file was created + if _, err := os.Stat(config.OutputPath); os.IsNotExist(err) { + t.Errorf("Output file was not created: %s\nCommand output: %s", config.OutputPath, string(cmdOutput)) + } else { + t.Logf("Successfully created output file: %s", config.OutputPath) + info, _ := os.Stat(config.OutputPath) + if info.Size() == 0 { + t.Error("Output file was created but is empty") + } else { + t.Logf("Output file size: %d bytes", info.Size()) + } + } +} + +func TestIntegration_Encode_EXR_VP9(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + // Check if example file exists + exampleDir := filepath.Join("..", "..", "..", "examples") + exrFile := filepath.Join(exampleDir, "frame_0800.exr") + if _, err := os.Stat(exrFile); os.IsNotExist(err) { + t.Skipf("Example file not found: %s", exrFile) + } + + // Check if VP9 encoder is available + checkCmd := exec.Command("ffmpeg", "-hide_banner", "-encoders") + checkOutput, err := checkCmd.CombinedOutput() + if err != nil || !strings.Contains(string(checkOutput), "libvpx-vp9") { + t.Skip("VP9 encoder (libvpx-vp9) not available in ffmpeg") + } + + // Get absolute paths + workspaceRoot, err := filepath.Abs(filepath.Join("..", "..", "..")) + if err != nil { + t.Fatalf("Failed to get workspace root: %v", err) + } + exampleDirAbs, err := filepath.Abs(exampleDir) + if err != nil { + t.Fatalf("Failed to get example directory: %v", err) + } + tmpDir := filepath.Join(workspaceRoot, "tmp") + if err := os.MkdirAll(tmpDir, 0755); err != nil { + t.Fatalf("Failed to create tmp directory: %v", err) + } + + encoder := &SoftwareEncoder{codec: "libvpx-vp9"} + config := &EncodeConfig{ + InputPattern: filepath.Join(exampleDirAbs, "frame_%04d.exr"), + OutputPath: filepath.Join(tmpDir, "test_exr_vp9.webm"), + StartFrame: 800, + FrameRate: 24.0, + WorkDir: tmpDir, + UseAlpha: false, + TwoPass: false, // Use single pass for faster testing + SourceFormat: "exr", + } + + // Build and run command + cmd := encoder.BuildCommand(config) + if cmd == nil { + t.Fatal("BuildCommand returned nil") + } + + // Capture stderr to see what went wrong + output, err := cmd.CombinedOutput() + if err != nil { + t.Errorf("FFmpeg command failed: %v\nCommand output: %s", err, string(output)) + return + } + + // Verify output file was created + if _, err := os.Stat(config.OutputPath); os.IsNotExist(err) { + t.Errorf("Output file was not created: %s\nCommand output: %s", config.OutputPath, string(output)) + } else { + t.Logf("Successfully created output file: %s", config.OutputPath) + // Verify file has content + info, _ := os.Stat(config.OutputPath) + if info.Size() == 0 { + t.Errorf("Output file was created but is empty\nCommand output: %s", string(output)) + } else { + t.Logf("Output file size: %d bytes", info.Size()) + } + } +} + +func TestIntegration_Encode_EXR_AV1(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + // Check if example file exists + exampleDir := filepath.Join("..", "..", "..", "examples") + exrFile := filepath.Join(exampleDir, "frame_0800.exr") + if _, err := os.Stat(exrFile); 
os.IsNotExist(err) { + t.Skipf("Example file not found: %s", exrFile) + } + + // Check if AV1 encoder is available + checkCmd := exec.Command("ffmpeg", "-hide_banner", "-encoders") + output, err := checkCmd.CombinedOutput() + if err != nil || !strings.Contains(string(output), "libaom-av1") { + t.Skip("AV1 encoder (libaom-av1) not available in ffmpeg") + } + + // Get absolute paths + workspaceRoot, err := filepath.Abs(filepath.Join("..", "..", "..")) + if err != nil { + t.Fatalf("Failed to get workspace root: %v", err) + } + exampleDirAbs, err := filepath.Abs(exampleDir) + if err != nil { + t.Fatalf("Failed to get example directory: %v", err) + } + tmpDir := filepath.Join(workspaceRoot, "tmp") + if err := os.MkdirAll(tmpDir, 0755); err != nil { + t.Fatalf("Failed to create tmp directory: %v", err) + } + + encoder := &SoftwareEncoder{codec: "libaom-av1"} + config := &EncodeConfig{ + InputPattern: filepath.Join(exampleDirAbs, "frame_%04d.exr"), + OutputPath: filepath.Join(tmpDir, "test_exr_av1.mp4"), + StartFrame: 800, + FrameRate: 24.0, + WorkDir: tmpDir, + UseAlpha: false, + TwoPass: false, + SourceFormat: "exr", + } + + // Build and run command + cmd := encoder.BuildCommand(config) + cmdOutput, err := cmd.CombinedOutput() + if err != nil { + t.Errorf("FFmpeg command failed: %v\nCommand output: %s", err, string(cmdOutput)) + return + } + + // Verify output file was created + if _, err := os.Stat(config.OutputPath); os.IsNotExist(err) { + t.Errorf("Output file was not created: %s\nCommand output: %s", config.OutputPath, string(cmdOutput)) + } else { + t.Logf("Successfully created AV1 output file: %s", config.OutputPath) + info, _ := os.Stat(config.OutputPath) + if info.Size() == 0 { + t.Errorf("Output file was created but is empty\nCommand output: %s", string(cmdOutput)) + } else { + t.Logf("Output file size: %d bytes", info.Size()) + } + } +} + +func TestIntegration_Encode_EXR_VP9_WithAlpha(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + // Check if example file exists + exampleDir := filepath.Join("..", "..", "..", "examples") + exrFile := filepath.Join(exampleDir, "frame_0800.exr") + if _, err := os.Stat(exrFile); os.IsNotExist(err) { + t.Skipf("Example file not found: %s", exrFile) + } + + // Check if VP9 encoder is available + checkCmd := exec.Command("ffmpeg", "-hide_banner", "-encoders") + output, err := checkCmd.CombinedOutput() + if err != nil || !strings.Contains(string(output), "libvpx-vp9") { + t.Skip("VP9 encoder (libvpx-vp9) not available in ffmpeg") + } + + // Get absolute paths + workspaceRoot, err := filepath.Abs(filepath.Join("..", "..", "..")) + if err != nil { + t.Fatalf("Failed to get workspace root: %v", err) + } + exampleDirAbs, err := filepath.Abs(exampleDir) + if err != nil { + t.Fatalf("Failed to get example directory: %v", err) + } + tmpDir := filepath.Join(workspaceRoot, "tmp") + if err := os.MkdirAll(tmpDir, 0755); err != nil { + t.Fatalf("Failed to create tmp directory: %v", err) + } + + encoder := &SoftwareEncoder{codec: "libvpx-vp9"} + config := &EncodeConfig{ + InputPattern: filepath.Join(exampleDirAbs, "frame_%04d.exr"), + OutputPath: filepath.Join(tmpDir, "test_exr_vp9_alpha.webm"), + StartFrame: 800, + FrameRate: 24.0, + WorkDir: tmpDir, + UseAlpha: true, // Test with alpha + TwoPass: false, // Use single pass for faster testing + SourceFormat: "exr", + } + + // Build and run command + cmd := encoder.BuildCommand(config) + if cmd == nil { + t.Fatal("BuildCommand returned nil") + } + + // Capture stderr to 
see what went wrong + cmdOutput, err := cmd.CombinedOutput() + if err != nil { + t.Errorf("FFmpeg command failed: %v\nCommand output: %s", err, string(cmdOutput)) + return + } + + // Verify output file was created + if _, err := os.Stat(config.OutputPath); os.IsNotExist(err) { + t.Errorf("Output file was not created: %s\nCommand output: %s", config.OutputPath, string(cmdOutput)) + } else { + t.Logf("Successfully created VP9 output file with alpha: %s", config.OutputPath) + info, _ := os.Stat(config.OutputPath) + if info.Size() == 0 { + t.Errorf("Output file was created but is empty\nCommand output: %s", string(cmdOutput)) + } else { + t.Logf("Output file size: %d bytes", info.Size()) + } + } +} + +// Helper function to copy files +func copyFile(src, dst string) error { + data, err := os.ReadFile(src) + if err != nil { + return err + } + return os.WriteFile(dst, data, 0644) +} diff --git a/internal/runner/runner.go b/internal/runner/runner.go new file mode 100644 index 0000000..1d7fa9d --- /dev/null +++ b/internal/runner/runner.go @@ -0,0 +1,361 @@ +// Package runner provides the Jiggablend render runner. +package runner + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "log" + "net" + "os" + "os/exec" + "strings" + "sync" + "time" + + "jiggablend/internal/runner/api" + "jiggablend/internal/runner/blender" + "jiggablend/internal/runner/encoding" + "jiggablend/internal/runner/tasks" + "jiggablend/internal/runner/workspace" + "jiggablend/pkg/executils" + "jiggablend/pkg/types" +) + +// Runner is the main render runner. +type Runner struct { + id int64 + name string + hostname string + + manager *api.ManagerClient + workspace *workspace.Manager + blender *blender.Manager + encoder *encoding.Selector + processes *executils.ProcessTracker + + processors map[string]tasks.Processor + stopChan chan struct{} + + fingerprint string + fingerprintMu sync.RWMutex +} + +// New creates a new runner. +func New(managerURL, name, hostname string) *Runner { + manager := api.NewManagerClient(managerURL) + + r := &Runner{ + name: name, + hostname: hostname, + manager: manager, + processes: executils.NewProcessTracker(), + stopChan: make(chan struct{}), + processors: make(map[string]tasks.Processor), + } + + // Generate fingerprint + r.generateFingerprint() + + return r +} + +// CheckRequiredTools verifies that required external tools are available. +func (r *Runner) CheckRequiredTools() error { + if err := exec.Command("zstd", "--version").Run(); err != nil { + return fmt.Errorf("zstd not found - required for compressed blend file support. Install with: apt install zstd") + } + log.Printf("Found zstd for compressed blend file support") + + if err := exec.Command("xvfb-run", "--help").Run(); err != nil { + return fmt.Errorf("xvfb-run not found - required for headless Blender rendering. Install with: apt install xvfb") + } + log.Printf("Found xvfb-run for headless rendering without -b option") + return nil +} + +var cachedCapabilities map[string]interface{} = nil + +// ProbeCapabilities detects hardware capabilities. +func (r *Runner) ProbeCapabilities() map[string]interface{} { + if cachedCapabilities != nil { + return cachedCapabilities + } + + caps := make(map[string]interface{}) + + // Check for ffmpeg and probe encoding capabilities + if err := exec.Command("ffmpeg", "-version").Run(); err == nil { + caps["ffmpeg"] = true + } else { + caps["ffmpeg"] = false + } + + cachedCapabilities = caps + return caps +} + +// Register registers the runner with the manager. 
+func (r *Runner) Register(apiKey string) (int64, error) {
+	caps := r.ProbeCapabilities()
+
+	id, err := r.manager.Register(r.name, r.hostname, caps, apiKey, r.GetFingerprint())
+	if err != nil {
+		return 0, err
+	}
+
+	r.id = id
+
+	// Initialize workspace after registration
+	r.workspace = workspace.NewManager(r.name)
+
+	// Initialize blender manager
+	r.blender = blender.NewManager(r.manager, r.workspace.BaseDir())
+
+	// Initialize encoder selector
+	r.encoder = encoding.NewSelector()
+
+	// Register task processors
+	r.processors["render"] = tasks.NewRenderProcessor()
+	r.processors["encode"] = tasks.NewEncodeProcessor()
+
+	return id, nil
+}
+
+// Start starts the job polling loop.
+func (r *Runner) Start(pollInterval time.Duration) {
+	log.Printf("Starting job polling loop (interval: %v)", pollInterval)
+
+	for {
+		select {
+		case <-r.stopChan:
+			log.Printf("Stopping job polling loop")
+			return
+		default:
+		}
+
+		log.Printf("Polling for next job (runner ID: %d)", r.id)
+		job, err := r.manager.PollNextJob()
+		if err != nil {
+			log.Printf("Error polling for job: %v", err)
+			time.Sleep(pollInterval)
+			continue
+		}
+
+		if job == nil {
+			log.Printf("No job available, sleeping for %v", pollInterval)
+			time.Sleep(pollInterval)
+			continue
+		}
+
+		log.Printf("Received job assignment: task=%d, job=%d, type=%s",
+			job.Task.TaskID, job.Task.JobID, job.Task.TaskType)
+
+		if err := r.executeJob(job); err != nil {
+			log.Printf("Error processing job: %v", err)
+		}
+	}
+}
+
+// Stop stops the runner.
+func (r *Runner) Stop() {
+	close(r.stopChan)
+}
+
+// KillAllProcesses kills all running processes.
+func (r *Runner) KillAllProcesses() {
+	log.Printf("Killing all running processes...")
+	killedCount := r.processes.KillAll()
+
+	// Device-pool cleanup for allocated encoder devices is handled internally,
+	// so there is nothing extra to release here.
+
+	log.Printf("Killed %d process(es)", killedCount)
+}
+
+// Cleanup removes the workspace directory.
+func (r *Runner) Cleanup() {
+	if r.workspace != nil {
+		r.workspace.Cleanup()
+	}
+}
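+
+// Usage sketch of the Runner lifecycle (hypothetical caller, assuming a
+// reachable manager URL and a valid API key; the real wiring lives in the
+// jiggablend runner command):
+//
+//	r := runner.New("http://localhost:8080", "runner-1", "host-1")
+//	if err := r.CheckRequiredTools(); err != nil {
+//		log.Fatal(err)
+//	}
+//	r.Cleanup()                   // remove orphaned workspace directories
+//	id, err := r.Register(apiKey) // also initializes workspace, blender, encoder
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	log.Printf("registered as runner %d", id)
+//	go r.Start(5 * time.Second)   // poll for jobs until Stop is called
+//	defer func() { r.Stop(); r.KillAllProcesses(); r.Cleanup() }()
+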
+// executeJob handles a job using a per-job WebSocket connection.
+func (r *Runner) executeJob(job *api.NextJobResponse) (err error) {
+	// Recover from panics to prevent runner process crashes during task execution
+	defer func() {
+		if rec := recover(); rec != nil {
+			log.Printf("Task execution panicked: %v", rec)
+			err = fmt.Errorf("task execution panicked: %v", rec)
+		}
+	}()
+
+	// Connect to job WebSocket (no runnerID needed - authentication handles it)
+	jobConn := api.NewJobConnection()
+	if err := jobConn.Connect(r.manager.GetBaseURL(), job.JobPath, job.JobToken); err != nil {
+		return fmt.Errorf("failed to connect job WebSocket: %w", err)
+	}
+	defer jobConn.Close()
+
+	log.Printf("Job WebSocket authenticated for task %d", job.Task.TaskID)
+
+	// Create task context
+	workDir := r.workspace.JobDir(job.Task.JobID)
+	ctx := tasks.NewContext(
+		job.Task.TaskID,
+		job.Task.JobID,
+		job.Task.JobName,
+		job.Task.Frame,
+		job.Task.TaskType,
+		workDir,
+		job.JobToken,
+		job.Task.Metadata,
+		r.manager,
+		jobConn,
+		r.workspace,
+		r.blender,
+		r.encoder,
+		r.processes,
+	)
+
+	ctx.Info(fmt.Sprintf("Task assignment received (job: %d, type: %s)",
+		job.Task.JobID, job.Task.TaskType))
+
+	// Get processor for task type
+	processor, ok := r.processors[job.Task.TaskType]
+	if !ok {
+		return fmt.Errorf("unknown task type: %s", job.Task.TaskType)
+	}
+
+	// Process the task
+	var processErr error
+	switch job.Task.TaskType {
+	case "render":
+		// Render tasks need an explicit upload step: the frames are written
+		// locally and are not uploaded by the render task itself, so we do it
+		// manually here. TODO: consider making this work like the encode task.
+		// Download context
+		contextPath := job.JobPath + "/context.tar"
+		if err := r.downloadContext(job.Task.JobID, contextPath, job.JobToken); err != nil {
+			jobConn.Log(job.Task.TaskID, types.LogLevelError, fmt.Sprintf("Failed to download context: %v", err))
+			jobConn.Complete(job.Task.TaskID, false, fmt.Errorf("failed to download context: %v", err))
+			return fmt.Errorf("failed to download context: %w", err)
+		}
+		processErr = processor.Process(ctx)
+		if processErr == nil {
+			processErr = r.uploadOutputs(ctx, job)
+		}
+	case "encode":
+		// Encode tasks upload the finished video themselves, so no separate
+		// upload step is needed.
+		processErr = processor.Process(ctx)
+	default:
+		return fmt.Errorf("unknown task type: %s", job.Task.TaskType)
+	}
+
+	if processErr != nil {
+		ctx.Error(fmt.Sprintf("Task failed: %v", processErr))
+		ctx.Complete(false, processErr)
+		return processErr
+	}
+
+	ctx.Complete(true, nil)
+	return nil
+}
+
+func (r *Runner) downloadContext(jobID int64, contextPath, jobToken string) error {
+	reader, err := r.manager.DownloadContext(contextPath, jobToken)
+	if err != nil {
+		return err
+	}
+	defer reader.Close()
+
+	jobDir := r.workspace.JobDir(jobID)
+	return workspace.ExtractTar(reader, jobDir)
+}
+
+func (r *Runner) uploadOutputs(ctx *tasks.Context, job *api.NextJobResponse) error {
+	outputDir := ctx.WorkDir + "/output"
+	uploadPath := fmt.Sprintf("/api/runner/jobs/%d/upload", job.Task.JobID)
+
+	entries, err := os.ReadDir(outputDir)
+	if err != nil {
+		return fmt.Errorf("failed to read output directory: %w", err)
+	}
+
+	for _, entry := range entries {
+		if entry.IsDir() {
+			continue
+		}
+		filePath := outputDir + "/" + entry.Name()
+		if err := r.manager.UploadFile(uploadPath, job.JobToken, filePath); err != nil {
+			log.Printf("Failed to upload %s: %v", filePath, err)
+		} else {
+			ctx.OutputUploaded(entry.Name())
+		}
+	}
+
+	return nil
+}
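+
+// For illustration, the fingerprint computed below is equivalent to this
+// sketch (hypothetical inputs; only the components actually present on the
+// machine are hashed):
+//
+//	h := sha256.New()
+//	for _, c := range []string{hostname, machineID, productUUID, macAddr} {
+//		h.Write([]byte(c))
+//		h.Write([]byte{0}) // NUL separator keeps component boundaries unambiguous
+//	}
+//	fingerprint := hex.EncodeToString(h.Sum(nil))
+
+// generateFingerprint creates a unique hardware fingerprint.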
+func (r *Runner) generateFingerprint() { + r.fingerprintMu.Lock() + defer r.fingerprintMu.Unlock() + + var components []string + components = append(components, r.hostname) + + if machineID, err := os.ReadFile("/etc/machine-id"); err == nil { + components = append(components, strings.TrimSpace(string(machineID))) + } + + if productUUID, err := os.ReadFile("/sys/class/dmi/id/product_uuid"); err == nil { + components = append(components, strings.TrimSpace(string(productUUID))) + } + + if macAddr, err := r.getMACAddress(); err == nil { + components = append(components, macAddr) + } + + if len(components) <= 1 { + components = append(components, fmt.Sprintf("%d", os.Getpid())) + components = append(components, fmt.Sprintf("%d", time.Now().Unix())) + } + + h := sha256.New() + for _, comp := range components { + h.Write([]byte(comp)) + h.Write([]byte{0}) + } + + r.fingerprint = hex.EncodeToString(h.Sum(nil)) +} + +func (r *Runner) getMACAddress() (string, error) { + interfaces, err := net.Interfaces() + if err != nil { + return "", err + } + + for _, iface := range interfaces { + if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 { + continue + } + if len(iface.HardwareAddr) == 0 { + continue + } + return iface.HardwareAddr.String(), nil + } + + return "", fmt.Errorf("no suitable network interface found") +} + +// GetFingerprint returns the runner's hardware fingerprint. +func (r *Runner) GetFingerprint() string { + r.fingerprintMu.RLock() + defer r.fingerprintMu.RUnlock() + return r.fingerprint +} + +// GetID returns the runner ID. +func (r *Runner) GetID() int64 { + return r.id +} diff --git a/internal/runner/tasks/encode.go b/internal/runner/tasks/encode.go new file mode 100644 index 0000000..06ebd8d --- /dev/null +++ b/internal/runner/tasks/encode.go @@ -0,0 +1,588 @@ +package tasks + +import ( + "bufio" + "errors" + "fmt" + "log" + "math" + "os" + "os/exec" + "path/filepath" + "regexp" + "sort" + "strings" + + "jiggablend/internal/runner/encoding" +) + +// EncodeProcessor handles encode tasks. +type EncodeProcessor struct{} + +// NewEncodeProcessor creates a new encode processor. +func NewEncodeProcessor() *EncodeProcessor { + return &EncodeProcessor{} +} + +// Process executes an encode task. 
+func (p *EncodeProcessor) Process(ctx *Context) error {
+	ctx.Info(fmt.Sprintf("Starting encode task: job %d", ctx.JobID))
+	log.Printf("Processing encode task %d for job %d", ctx.TaskID, ctx.JobID)
+
+	// Create temporary work directory
+	workDir, err := ctx.Workspace.CreateVideoDir(ctx.JobID)
+	if err != nil {
+		return fmt.Errorf("failed to create work directory: %w", err)
+	}
+	defer func() {
+		if err := ctx.Workspace.CleanupVideoDir(ctx.JobID); err != nil {
+			log.Printf("Warning: Failed to cleanup encode work directory: %v", err)
+		}
+	}()
+
+	// Get output format and frame rate
+	outputFormat := ctx.GetOutputFormat()
+	if outputFormat == "" {
+		outputFormat = "EXR_264_MP4"
+	}
+	frameRate := ctx.GetFrameRate()
+
+	ctx.Info(fmt.Sprintf("Encode: detected output format '%s'", outputFormat))
+	ctx.Info(fmt.Sprintf("Encode: using frame rate %.2f fps", frameRate))
+
+	// Get job files
+	files, err := ctx.Manager.GetJobFiles(ctx.JobID)
+	if err != nil {
+		ctx.Error(fmt.Sprintf("Failed to get job files: %v", err))
+		return fmt.Errorf("failed to get job files: %w", err)
+	}
+
+	ctx.Info(fmt.Sprintf("GetJobFiles returned %d total files for job %d", len(files), ctx.JobID))
+
+	// Log all files for debugging
+	for _, file := range files {
+		ctx.Info(fmt.Sprintf("File: %s (type: %s, size: %d)", file.FileName, file.FileType, file.FileSize))
+	}
+
+	// Video output formats always render their intermediate frames as EXR
+	// (see the render task), so the encode source is always EXR.
+	sourceFormat := "exr"
+	fileExt := ".exr"
+
+	// Find and deduplicate EXR frame files
+	frameFileSet := make(map[string]bool)
+	var frameFilesList []string
+	for _, file := range files {
+		if file.FileType == "output" && strings.HasSuffix(strings.ToLower(file.FileName), fileExt) {
+			// Deduplicate by filename
+			if !frameFileSet[file.FileName] {
+				frameFileSet[file.FileName] = true
+				frameFilesList = append(frameFilesList, file.FileName)
+			}
+		}
+	}
+
+	if len(frameFilesList) == 0 {
+		// Log why no files matched (deduplicate for error reporting)
+		outputFileSet := make(map[string]bool)
+		frameFilesOtherTypeSet := make(map[string]bool)
+		var outputFiles []string
+		var frameFilesOtherType []string
+
+		for _, file := range files {
+			if file.FileType == "output" {
+				if !outputFileSet[file.FileName] {
+					outputFileSet[file.FileName] = true
+					outputFiles = append(outputFiles, file.FileName)
+				}
+			}
+			if strings.HasSuffix(strings.ToLower(file.FileName), fileExt) {
+				key := fmt.Sprintf("%s (type: %s)", file.FileName, file.FileType)
+				if !frameFilesOtherTypeSet[key] {
+					frameFilesOtherTypeSet[key] = true
+					frameFilesOtherType = append(frameFilesOtherType, key)
+				}
+			}
+		}
+		ctx.Error(fmt.Sprintf("no %s frame files found for encode: found %d total files, %d unique output files, %d unique %s files (with other types)", strings.ToUpper(fileExt[1:]), len(files), len(outputFiles), len(frameFilesOtherType), strings.ToUpper(fileExt[1:])))
+		if len(outputFiles) > 0 {
+			ctx.Error(fmt.Sprintf("Output files found: %v", outputFiles))
+		}
+		if len(frameFilesOtherType) > 0 {
+			ctx.Error(fmt.Sprintf("%s files with wrong type: %v", strings.ToUpper(fileExt[1:]), frameFilesOtherType))
+		}
+		err := fmt.Errorf("no %s frame files found for encode", strings.ToUpper(fileExt[1:]))
+		return err
+	}
+
+	ctx.Info(fmt.Sprintf("Found %d %s frames for encode", len(frameFilesList), strings.ToUpper(fileExt[1:])))
+
+	// Download frames
+	ctx.Info(fmt.Sprintf("Downloading %d %s frames for encode...", len(frameFilesList), strings.ToUpper(fileExt[1:])))
+
+	var frameFiles []string
+	for i, fileName := range frameFilesList {
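+		// Download frames one at a time; individual failures are logged and
+		// skipped so that a partial frame set can still be encoded.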
ctx.Info(fmt.Sprintf("Downloading frame %d/%d: %s", i+1, len(frameFilesList), fileName)) + framePath := filepath.Join(workDir, fileName) + if err := ctx.Manager.DownloadFrame(ctx.JobID, fileName, framePath); err != nil { + ctx.Error(fmt.Sprintf("Failed to download %s frame %s: %v", strings.ToUpper(fileExt[1:]), fileName, err)) + log.Printf("Failed to download %s frame for encode %s: %v", strings.ToUpper(fileExt[1:]), fileName, err) + continue + } + ctx.Info(fmt.Sprintf("Successfully downloaded frame %d/%d: %s", i+1, len(frameFilesList), fileName)) + frameFiles = append(frameFiles, framePath) + } + + if len(frameFiles) == 0 { + err := fmt.Errorf("failed to download any %s frames for encode", strings.ToUpper(fileExt[1:])) + ctx.Error(err.Error()) + return err + } + + sort.Strings(frameFiles) + ctx.Info(fmt.Sprintf("Downloaded %d frames", len(frameFiles))) + + // Check if EXR files have alpha channel and HDR content (only for EXR source format) + hasAlpha := false + hasHDR := false + if sourceFormat == "exr" { + // Check first frame for alpha channel and HDR using ffprobe + firstFrame := frameFiles[0] + hasAlpha = detectAlphaChannel(ctx, firstFrame) + if hasAlpha { + ctx.Info("Detected alpha channel in EXR files") + } else { + ctx.Info("No alpha channel detected in EXR files") + } + + hasHDR = detectHDR(ctx, firstFrame) + if hasHDR { + ctx.Info("Detected HDR content in EXR files") + } else { + ctx.Info("No HDR content detected in EXR files (SDR range)") + } + } + + // Generate video + // Use alpha if: + // 1. User explicitly enabled it OR source has alpha channel AND + // 2. Codec supports alpha (AV1 or VP9) + preserveAlpha := ctx.ShouldPreserveAlpha() + useAlpha := (preserveAlpha || hasAlpha) && (outputFormat == "EXR_AV1_MP4" || outputFormat == "EXR_VP9_WEBM") + if (preserveAlpha || hasAlpha) && outputFormat == "EXR_264_MP4" { + ctx.Warn("Alpha channel requested/detected but H.264 does not support alpha. 
Consider using EXR_AV1_MP4 or EXR_VP9_WEBM to preserve alpha.") + } + if preserveAlpha && !hasAlpha { + ctx.Warn("Alpha preservation requested but no alpha channel detected in EXR files.") + } + if useAlpha { + if preserveAlpha && hasAlpha { + ctx.Info("Alpha preservation enabled: Using alpha channel encoding") + } else if hasAlpha { + ctx.Info("Alpha channel detected - automatically enabling alpha encoding") + } + } + var outputExt string + switch outputFormat { + case "EXR_VP9_WEBM": + outputExt = "webm" + ctx.Info("Encoding WebM video with VP9 codec (with alpha channel and HDR support)...") + case "EXR_AV1_MP4": + outputExt = "mp4" + ctx.Info("Encoding MP4 video with AV1 codec (with alpha channel)...") + default: + outputExt = "mp4" + ctx.Info("Encoding MP4 video with H.264 codec...") + } + + outputVideo := filepath.Join(workDir, fmt.Sprintf("output_%d.%s", ctx.JobID, outputExt)) + + // Build input pattern + firstFrame := frameFiles[0] + baseName := filepath.Base(firstFrame) + re := regexp.MustCompile(`_(\d+)\.`) + var pattern string + var startNumber int + frameNumStr := re.FindStringSubmatch(baseName) + if len(frameNumStr) > 1 { + pattern = re.ReplaceAllString(baseName, "_%04d.") + fmt.Sscanf(frameNumStr[1], "%d", &startNumber) + } else { + startNumber = extractFrameNumber(baseName) + pattern = strings.Replace(baseName, fmt.Sprintf("%d", startNumber), "%04d", 1) + } + patternPath := filepath.Join(workDir, pattern) + + // Select encoder and build command (software encoding only) + var encoder encoding.Encoder + switch outputFormat { + case "EXR_AV1_MP4": + encoder = ctx.Encoder.SelectAV1() + case "EXR_VP9_WEBM": + encoder = ctx.Encoder.SelectVP9() + default: + encoder = ctx.Encoder.SelectH264() + } + + ctx.Info(fmt.Sprintf("Using encoder: %s (%s)", encoder.Name(), encoder.Codec())) + + // All software encoders use 2-pass for optimal quality + ctx.Info("Starting 2-pass encode for optimal quality...") + + // Pass 1 + ctx.Info("Pass 1/2: Analyzing content for optimal encode...") + softEncoder := encoder.(*encoding.SoftwareEncoder) + // Use HDR if: user explicitly enabled it OR HDR content was detected + preserveHDR := (ctx.ShouldPreserveHDR() || hasHDR) && sourceFormat == "exr" + if hasHDR && !ctx.ShouldPreserveHDR() { + ctx.Info("HDR content detected - automatically enabling HDR preservation") + } + pass1Cmd := softEncoder.BuildPass1Command(&encoding.EncodeConfig{ + InputPattern: patternPath, + OutputPath: outputVideo, + StartFrame: startNumber, + FrameRate: frameRate, + WorkDir: workDir, + UseAlpha: useAlpha, + TwoPass: true, + SourceFormat: sourceFormat, + PreserveHDR: preserveHDR, + }) + if err := pass1Cmd.Run(); err != nil { + ctx.Warn(fmt.Sprintf("Pass 1 completed (warnings expected): %v", err)) + } + + // Pass 2 + ctx.Info("Pass 2/2: Encoding with optimal quality...") + + preserveHDR = (ctx.ShouldPreserveHDR() || hasHDR) && sourceFormat == "exr" + if preserveHDR { + if hasHDR && !ctx.ShouldPreserveHDR() { + ctx.Info("HDR preservation enabled (auto-detected): Using HLG transfer with bt709 primaries") + } else { + ctx.Info("HDR preservation enabled: Using HLG transfer with bt709 primaries") + } + } + + config := &encoding.EncodeConfig{ + InputPattern: patternPath, + OutputPath: outputVideo, + StartFrame: startNumber, + FrameRate: frameRate, + WorkDir: workDir, + UseAlpha: useAlpha, + TwoPass: true, // Software encoding always uses 2-pass for quality + SourceFormat: sourceFormat, + PreserveHDR: preserveHDR, + } + + cmd := encoder.BuildCommand(config) + if cmd == nil { + return 
errors.New("failed to build encode command") + } + + // Set up pipes + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + return fmt.Errorf("failed to create stdout pipe: %w", err) + } + + stderrPipe, err := cmd.StderrPipe() + if err != nil { + return fmt.Errorf("failed to create stderr pipe: %w", err) + } + + if err := cmd.Start(); err != nil { + return fmt.Errorf("failed to start encode command: %w", err) + } + + ctx.Processes.Track(ctx.TaskID, cmd) + defer ctx.Processes.Untrack(ctx.TaskID) + + // Stream stdout + stdoutDone := make(chan bool) + go func() { + defer close(stdoutDone) + scanner := bufio.NewScanner(stdoutPipe) + for scanner.Scan() { + line := scanner.Text() + if line != "" { + ctx.Info(line) + } + } + }() + + // Stream stderr + stderrDone := make(chan bool) + go func() { + defer close(stderrDone) + scanner := bufio.NewScanner(stderrPipe) + for scanner.Scan() { + line := scanner.Text() + if line != "" { + ctx.Warn(line) + } + } + }() + + err = cmd.Wait() + <-stdoutDone + <-stderrDone + + if err != nil { + var errMsg string + if exitErr, ok := err.(*exec.ExitError); ok { + if exitErr.ExitCode() == 137 { + errMsg = "FFmpeg was killed due to excessive memory usage (OOM)" + } else { + errMsg = fmt.Sprintf("ffmpeg encoding failed: %v", err) + } + } else { + errMsg = fmt.Sprintf("ffmpeg encoding failed: %v", err) + } + + if sizeErr := checkFFmpegSizeError(errMsg); sizeErr != nil { + ctx.Error(sizeErr.Error()) + return sizeErr + } + + ctx.Error(errMsg) + return errors.New(errMsg) + } + + // Verify output + if _, err := os.Stat(outputVideo); os.IsNotExist(err) { + err := fmt.Errorf("video %s file not created: %s", outputExt, outputVideo) + ctx.Error(err.Error()) + return err + } + + // Clean up 2-pass log files + os.Remove(filepath.Join(workDir, "ffmpeg2pass-0.log")) + os.Remove(filepath.Join(workDir, "ffmpeg2pass-0.log.mbtree")) + + ctx.Info(fmt.Sprintf("%s video encoded successfully", strings.ToUpper(outputExt))) + + // Upload video + ctx.Info(fmt.Sprintf("Uploading encoded %s video...", strings.ToUpper(outputExt))) + + uploadPath := fmt.Sprintf("/api/runner/jobs/%d/upload", ctx.JobID) + if err := ctx.Manager.UploadFile(uploadPath, ctx.JobToken, outputVideo); err != nil { + ctx.Error(fmt.Sprintf("Failed to upload %s: %v", strings.ToUpper(outputExt), err)) + return fmt.Errorf("failed to upload %s: %w", strings.ToUpper(outputExt), err) + } + + ctx.Info(fmt.Sprintf("Successfully uploaded %s: %s", strings.ToUpper(outputExt), filepath.Base(outputVideo))) + + log.Printf("Successfully generated and uploaded %s for job %d: %s", strings.ToUpper(outputExt), ctx.JobID, filepath.Base(outputVideo)) + return nil +} + +// detectAlphaChannel checks if an EXR file has an alpha channel using ffprobe +func detectAlphaChannel(ctx *Context, filePath string) bool { + // Use ffprobe to check pixel format and stream properties + // EXR files with alpha will have formats like gbrapf32le (RGBA) vs gbrpf32le (RGB) + cmd := exec.Command("ffprobe", + "-v", "error", + "-select_streams", "v:0", + "-show_entries", "stream=pix_fmt:stream=codec_name", + "-of", "default=noprint_wrappers=1", + filePath, + ) + + output, err := cmd.Output() + if err != nil { + // If ffprobe fails, assume no alpha (conservative approach) + ctx.Warn(fmt.Sprintf("Failed to detect alpha channel in %s: %v", filepath.Base(filePath), err)) + return false + } + + outputStr := string(output) + // Check pixel format - EXR with alpha typically has 'a' in the format name (e.g., gbrapf32le) + // Also check for formats that explicitly indicate 
alpha
+	hasAlpha := strings.Contains(outputStr, "pix_fmt=gbrap") ||
+		strings.Contains(outputStr, "pix_fmt=rgba") ||
+		strings.Contains(outputStr, "pix_fmt=yuva") ||
+		strings.Contains(outputStr, "pix_fmt=abgr")
+
+	if hasAlpha {
+		ctx.Info(fmt.Sprintf("Detected alpha channel in EXR file: %s", filepath.Base(filePath)))
+	}
+
+	return hasAlpha
+}
+
+// detectHDR checks if an EXR file contains HDR content using ffprobe
+func detectHDR(ctx *Context, filePath string) bool {
+	// First, check if the pixel format supports HDR (32-bit float)
+	cmd := exec.Command("ffprobe",
+		"-v", "error",
+		"-select_streams", "v:0",
+		"-show_entries", "stream=pix_fmt",
+		"-of", "default=noprint_wrappers=1:nokey=1",
+		filePath,
+	)
+
+	output, err := cmd.Output()
+	if err != nil {
+		// If ffprobe fails, assume no HDR (conservative approach)
+		ctx.Warn(fmt.Sprintf("Failed to detect HDR in %s: %v", filepath.Base(filePath), err))
+		return false
+	}
+
+	pixFmt := strings.TrimSpace(string(output))
+	// EXR files with 32-bit float format (gbrpf32le, gbrapf32le) can contain HDR
+	isFloat32 := strings.Contains(pixFmt, "f32") || strings.Contains(pixFmt, "f32le")
+
+	if !isFloat32 {
+		// Not a float format, definitely not HDR
+		return false
+	}
+
+	// For 32-bit float EXR, sample pixels to check if values exceed the SDR
+	// range (> 1.0). Note: ffmpeg's signalstats filter works in YUV rather
+	// than linear RGB, so we sample linear RGB pixels directly instead.
+	return detectHDRBySampling(ctx, filePath)
+}
+
+// detectHDRBySampling samples pixels from multiple regions to detect HDR content
+func detectHDRBySampling(ctx *Context, filePath string) bool {
+	// Sample multiple 10x10 regions from different parts of the image
+	// This gives us better coverage than a single sample
+	sampleRegions := []string{
+		"crop=10:10:iw/4:ih/4",     // Top-left quadrant
+		"crop=10:10:iw*3/4:ih/4",   // Top-right quadrant
+		"crop=10:10:iw/4:ih*3/4",   // Bottom-left quadrant
+		"crop=10:10:iw*3/4:ih*3/4", // Bottom-right quadrant
+		"crop=10:10:iw/2:ih/2",     // Center
+	}
+
+	for _, region := range sampleRegions {
+		cmd := exec.Command("ffmpeg",
+			"-v", "error",
+			"-i", filePath,
+			"-vf", fmt.Sprintf("%s,scale=1:1", region),
+			"-f", "rawvideo",
+			"-pix_fmt", "gbrpf32le",
+			"-",
+		)
+
+		output, err := cmd.Output()
+		if err != nil {
+			continue // Skip this region if sampling fails
+		}
+
+		// Parse the float32 values. gbrpf32le is planar, so for each 1x1 frame
+		// the bytes are the G, B, and R planes in that order (4 bytes each).
+		if len(output) >= 12 { // At least 3 floats (GBR) = 12 bytes
+			for i := 0; i < len(output)-11; i += 12 {
+				// Read channel values (little-endian float32, planar G, B, R order)
+				g := float32FromBytes(output[i : i+4])
+				b := float32FromBytes(output[i+4 : i+8])
+				r := float32FromBytes(output[i+8 : i+12])
+
+				// Check if any channel exceeds 1.0 (SDR range)
+				if r > 1.0 || g > 1.0 || b > 1.0 {
+					maxVal := max(r, max(g, b))
+					ctx.Info(fmt.Sprintf("Detected HDR content in EXR file: %s (max value: %.2f)", filepath.Base(filePath), maxVal))
+					return true
+				}
+			}
+		}
+	}
+
+	// If we sampled multiple regions and none exceed 1.0, it's likely SDR content
+	// But since it's 32-bit float format, user can still manually enable HDR if needed
+	return false
+}
+
+// float32FromBytes converts 4 bytes (little-endian) to float32
+func float32FromBytes(bytes []byte) float32 {
+	if len(bytes) < 4 {
+		return 0
+	}
+	bits := uint32(bytes[0]) | uint32(bytes[1])<<8 | uint32(bytes[2])<<16 | uint32(bytes[3])<<24
+	return math.Float32frombits(bits)
+}
+
+// max returns the maximum of two float32 values
+func max(a, b float32) float32 {
+	if a > b {
+		return a
+	}
+	return b
+}
+
+func extractFrameNumber(filename string) int {
+	parts := strings.Split(filepath.Base(filename), "_")
+	if len(parts) < 2 {
+		return 0
+	}
+	framePart := strings.Split(parts[1], ".")[0]
+	var frameNum int
+	fmt.Sscanf(framePart, "%d", &frameNum)
+	return frameNum
+}
+
+func checkFFmpegSizeError(output string) error {
+	outputLower := strings.ToLower(output)
+
+	if strings.Contains(outputLower, "hardware does not support encoding at size") {
+		constraintsMatch := regexp.MustCompile(`constraints:\s*width\s+(\d+)-(\d+)\s+height\s+(\d+)-(\d+)`).FindStringSubmatch(output)
+		if len(constraintsMatch) == 5 {
+			return fmt.Errorf("video frame size is outside hardware encoder limits. Hardware requires: width %s-%s, height %s-%s",
+				constraintsMatch[1], constraintsMatch[2], constraintsMatch[3], constraintsMatch[4])
+		}
+		return fmt.Errorf("video frame size is outside hardware encoder limits")
+	}
+
+	if strings.Contains(outputLower, "picture size") && strings.Contains(outputLower, "is invalid") {
+		sizeMatch := regexp.MustCompile(`picture size\s+(\d+)x(\d+)`).FindStringSubmatch(output)
+		if len(sizeMatch) == 3 {
+			return fmt.Errorf("invalid video frame size: %sx%s", sizeMatch[1], sizeMatch[2])
+		}
+		return fmt.Errorf("invalid video frame size")
+	}
+
+	if strings.Contains(outputLower, "error while opening encoder") &&
+		(strings.Contains(outputLower, "width") || strings.Contains(outputLower, "height") || strings.Contains(outputLower, "size")) {
+		sizeMatch := regexp.MustCompile(`at size\s+(\d+)x(\d+)`).FindStringSubmatch(output)
+		if len(sizeMatch) == 3 {
+			return fmt.Errorf("hardware encoder cannot encode frame size %sx%s", sizeMatch[1], sizeMatch[2])
+		}
+		return fmt.Errorf("hardware encoder error: frame size may be invalid")
+	}
+
+	if strings.Contains(outputLower, "invalid") &&
+		(strings.Contains(outputLower, "width") || strings.Contains(outputLower, "height") || strings.Contains(outputLower, "dimension")) {
+		return fmt.Errorf("invalid frame dimensions detected")
+	}
+
+	return nil
+}
diff --git a/internal/runner/tasks/processor.go b/internal/runner/tasks/processor.go
new file mode 100644
index 0000000..5b0e32f
--- /dev/null
+++ b/internal/runner/tasks/processor.go
@@ -0,0 +1,156 @@
+// Package tasks provides task processing implementations.
+package tasks
+
+import (
+	"jiggablend/internal/runner/api"
+	"jiggablend/internal/runner/blender"
+	"jiggablend/internal/runner/encoding"
+	"jiggablend/internal/runner/workspace"
+	"jiggablend/pkg/executils"
+	"jiggablend/pkg/types"
+)
+
+// Processor handles a specific task type.
+type Processor interface {
+	Process(ctx *Context) error
+}
+
+// Context provides task execution context.
+type Context struct { + TaskID int64 + JobID int64 + JobName string + Frame int + TaskType string + WorkDir string + JobToken string + Metadata *types.BlendMetadata + + Manager *api.ManagerClient + JobConn *api.JobConnection + Workspace *workspace.Manager + Blender *blender.Manager + Encoder *encoding.Selector + Processes *executils.ProcessTracker +} + +// NewContext creates a new task context. +func NewContext( + taskID, jobID int64, + jobName string, + frame int, + taskType string, + workDir string, + jobToken string, + metadata *types.BlendMetadata, + manager *api.ManagerClient, + jobConn *api.JobConnection, + ws *workspace.Manager, + blenderMgr *blender.Manager, + encoder *encoding.Selector, + processes *executils.ProcessTracker, +) *Context { + return &Context{ + TaskID: taskID, + JobID: jobID, + JobName: jobName, + Frame: frame, + TaskType: taskType, + WorkDir: workDir, + JobToken: jobToken, + Metadata: metadata, + Manager: manager, + JobConn: jobConn, + Workspace: ws, + Blender: blenderMgr, + Encoder: encoder, + Processes: processes, + } +} + +// Log sends a log entry to the manager. +func (c *Context) Log(level types.LogLevel, message string) { + if c.JobConn != nil { + c.JobConn.Log(c.TaskID, level, message) + } +} + +// Info logs an info message. +func (c *Context) Info(message string) { + c.Log(types.LogLevelInfo, message) +} + +// Warn logs a warning message. +func (c *Context) Warn(message string) { + c.Log(types.LogLevelWarn, message) +} + +// Error logs an error message. +func (c *Context) Error(message string) { + c.Log(types.LogLevelError, message) +} + +// Progress sends a progress update. +func (c *Context) Progress(progress float64) { + if c.JobConn != nil { + c.JobConn.Progress(c.TaskID, progress) + } +} + +// OutputUploaded notifies that an output file was uploaded. +func (c *Context) OutputUploaded(fileName string) { + if c.JobConn != nil { + c.JobConn.OutputUploaded(c.TaskID, fileName) + } +} + +// Complete sends task completion. +func (c *Context) Complete(success bool, errorMsg error) { + if c.JobConn != nil { + c.JobConn.Complete(c.TaskID, success, errorMsg) + } +} + +// GetOutputFormat returns the output format from metadata or default. +func (c *Context) GetOutputFormat() string { + if c.Metadata != nil && c.Metadata.RenderSettings.OutputFormat != "" { + return c.Metadata.RenderSettings.OutputFormat + } + return "PNG" +} + +// GetFrameRate returns the frame rate from metadata or default. +func (c *Context) GetFrameRate() float64 { + if c.Metadata != nil && c.Metadata.RenderSettings.FrameRate > 0 { + return c.Metadata.RenderSettings.FrameRate + } + return 24.0 +} + +// GetBlenderVersion returns the Blender version from metadata. +func (c *Context) GetBlenderVersion() string { + if c.Metadata != nil { + return c.Metadata.BlenderVersion + } + return "" +} + +// ShouldUnhideObjects returns whether to unhide objects. +func (c *Context) ShouldUnhideObjects() bool { + return c.Metadata != nil && c.Metadata.UnhideObjects != nil && *c.Metadata.UnhideObjects +} + +// ShouldEnableExecution returns whether to enable auto-execution. +func (c *Context) ShouldEnableExecution() bool { + return c.Metadata != nil && c.Metadata.EnableExecution != nil && *c.Metadata.EnableExecution +} + +// ShouldPreserveHDR returns whether to preserve HDR range for EXR encoding. 
+func (c *Context) ShouldPreserveHDR() bool { + return c.Metadata != nil && c.Metadata.PreserveHDR != nil && *c.Metadata.PreserveHDR +} + +// ShouldPreserveAlpha returns whether to preserve alpha channel for EXR encoding. +func (c *Context) ShouldPreserveAlpha() bool { + return c.Metadata != nil && c.Metadata.PreserveAlpha != nil && *c.Metadata.PreserveAlpha +} diff --git a/internal/runner/tasks/render.go b/internal/runner/tasks/render.go new file mode 100644 index 0000000..a7d5b48 --- /dev/null +++ b/internal/runner/tasks/render.go @@ -0,0 +1,301 @@ +package tasks + +import ( + "bufio" + "encoding/json" + "errors" + "fmt" + "log" + "os" + "os/exec" + "path/filepath" + "strings" + + "jiggablend/internal/runner/blender" + "jiggablend/internal/runner/workspace" + "jiggablend/pkg/scripts" + "jiggablend/pkg/types" +) + +// RenderProcessor handles render tasks. +type RenderProcessor struct{} + +// NewRenderProcessor creates a new render processor. +func NewRenderProcessor() *RenderProcessor { + return &RenderProcessor{} +} + +// Process executes a render task. +func (p *RenderProcessor) Process(ctx *Context) error { + ctx.Info(fmt.Sprintf("Starting task: job %d, frame %d, format: %s", + ctx.JobID, ctx.Frame, ctx.GetOutputFormat())) + log.Printf("Processing task %d: job %d, frame %d", ctx.TaskID, ctx.JobID, ctx.Frame) + + // Find .blend file + blendFile, err := workspace.FindFirstBlendFile(ctx.WorkDir) + if err != nil { + return fmt.Errorf("failed to find blend file: %w", err) + } + + // Get Blender binary + blenderBinary := "blender" + if version := ctx.GetBlenderVersion(); version != "" { + ctx.Info(fmt.Sprintf("Job requires Blender %s", version)) + binaryPath, err := ctx.Blender.GetBinaryPath(version) + if err != nil { + ctx.Warn(fmt.Sprintf("Could not get Blender %s, using system blender: %v", version, err)) + } else { + blenderBinary = binaryPath + ctx.Info(fmt.Sprintf("Using Blender binary: %s", blenderBinary)) + } + } else { + ctx.Info("No Blender version specified, using system blender") + } + + // Create output directory + outputDir := filepath.Join(ctx.WorkDir, "output") + if err := os.MkdirAll(outputDir, 0755); err != nil { + return fmt.Errorf("failed to create output directory: %w", err) + } + + // Create home directory for Blender inside workspace + blenderHome := filepath.Join(ctx.WorkDir, "home") + if err := os.MkdirAll(blenderHome, 0755); err != nil { + return fmt.Errorf("failed to create Blender home directory: %w", err) + } + + // Determine render format + outputFormat := ctx.GetOutputFormat() + renderFormat := outputFormat + if outputFormat == "EXR_264_MP4" || outputFormat == "EXR_AV1_MP4" || outputFormat == "EXR_VP9_WEBM" { + renderFormat = "EXR" // Use EXR for maximum quality + } + + // Create render script + if err := p.createRenderScript(ctx, renderFormat); err != nil { + return err + } + + // Render + ctx.Info(fmt.Sprintf("Starting Blender render for frame %d...", ctx.Frame)) + if err := p.runBlender(ctx, blenderBinary, blendFile, outputDir, renderFormat, blenderHome); err != nil { + ctx.Error(fmt.Sprintf("Blender render failed: %v", err)) + return err + } + + // Verify output + if _, err := p.findOutputFile(ctx, outputDir, renderFormat); err != nil { + ctx.Error(fmt.Sprintf("Output verification failed: %v", err)) + return err + } + ctx.Info(fmt.Sprintf("Blender render completed for frame %d", ctx.Frame)) + + return nil +} + +func (p *RenderProcessor) createRenderScript(ctx *Context, renderFormat string) error { + formatFilePath := filepath.Join(ctx.WorkDir, 
"output_format.txt") + renderSettingsFilePath := filepath.Join(ctx.WorkDir, "render_settings.json") + + // Build unhide code conditionally + unhideCode := "" + if ctx.ShouldUnhideObjects() { + unhideCode = scripts.UnhideObjects + } + + // Load template and replace placeholders + scriptContent := scripts.RenderBlenderTemplate + scriptContent = strings.ReplaceAll(scriptContent, "{{UNHIDE_CODE}}", unhideCode) + scriptContent = strings.ReplaceAll(scriptContent, "{{FORMAT_FILE_PATH}}", fmt.Sprintf("%q", formatFilePath)) + scriptContent = strings.ReplaceAll(scriptContent, "{{RENDER_SETTINGS_FILE}}", fmt.Sprintf("%q", renderSettingsFilePath)) + + scriptPath := filepath.Join(ctx.WorkDir, "enable_gpu.py") + if err := os.WriteFile(scriptPath, []byte(scriptContent), 0644); err != nil { + errMsg := fmt.Sprintf("failed to create GPU enable script: %v", err) + ctx.Error(errMsg) + return errors.New(errMsg) + } + + // Write output format + outputFormat := ctx.GetOutputFormat() + ctx.Info(fmt.Sprintf("Writing output format '%s' to format file", outputFormat)) + if err := os.WriteFile(formatFilePath, []byte(outputFormat), 0644); err != nil { + errMsg := fmt.Sprintf("failed to create format file: %v", err) + ctx.Error(errMsg) + return errors.New(errMsg) + } + + // Write render settings if available + if ctx.Metadata != nil && ctx.Metadata.RenderSettings.EngineSettings != nil { + settingsJSON, err := json.Marshal(ctx.Metadata.RenderSettings) + if err == nil { + if err := os.WriteFile(renderSettingsFilePath, settingsJSON, 0644); err != nil { + ctx.Warn(fmt.Sprintf("Failed to write render settings file: %v", err)) + } + } + } + + return nil +} + +func (p *RenderProcessor) runBlender(ctx *Context, blenderBinary, blendFile, outputDir, renderFormat, blenderHome string) error { + scriptPath := filepath.Join(ctx.WorkDir, "enable_gpu.py") + + args := []string{"-b", blendFile, "--python", scriptPath} + if ctx.ShouldEnableExecution() { + args = append(args, "--enable-autoexec") + } + + // Output pattern + outputPattern := filepath.Join(outputDir, fmt.Sprintf("frame_####.%s", strings.ToLower(renderFormat))) + outputAbsPattern, _ := filepath.Abs(outputPattern) + args = append(args, "-o", outputAbsPattern) + + args = append(args, "-f", fmt.Sprintf("%d", ctx.Frame)) + + // Wrap with xvfb-run + xvfbArgs := []string{"-a", "-s", "-screen 0 800x600x24", blenderBinary} + xvfbArgs = append(xvfbArgs, args...) + cmd := exec.Command("xvfb-run", xvfbArgs...) 
+ cmd.Dir = ctx.WorkDir + + // Set up environment with custom HOME directory + env := os.Environ() + // Remove existing HOME if present and add our custom one + newEnv := make([]string, 0, len(env)+1) + for _, e := range env { + if !strings.HasPrefix(e, "HOME=") { + newEnv = append(newEnv, e) + } + } + newEnv = append(newEnv, fmt.Sprintf("HOME=%s", blenderHome)) + cmd.Env = newEnv + + // Set up pipes + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + return fmt.Errorf("failed to create stdout pipe: %w", err) + } + + stderrPipe, err := cmd.StderrPipe() + if err != nil { + return fmt.Errorf("failed to create stderr pipe: %w", err) + } + + if err := cmd.Start(); err != nil { + return fmt.Errorf("failed to start blender: %w", err) + } + + // Track process + ctx.Processes.Track(ctx.TaskID, cmd) + defer ctx.Processes.Untrack(ctx.TaskID) + + // Stream stdout + stdoutDone := make(chan bool) + go func() { + defer close(stdoutDone) + scanner := bufio.NewScanner(stdoutPipe) + for scanner.Scan() { + line := scanner.Text() + if line != "" { + shouldFilter, logLevel := blender.FilterLog(line) + if !shouldFilter { + ctx.Log(logLevel, line) + } + } + } + }() + + // Stream stderr + stderrDone := make(chan bool) + go func() { + defer close(stderrDone) + scanner := bufio.NewScanner(stderrPipe) + for scanner.Scan() { + line := scanner.Text() + if line != "" { + shouldFilter, logLevel := blender.FilterLog(line) + if !shouldFilter { + if logLevel == types.LogLevelInfo { + logLevel = types.LogLevelWarn + } + ctx.Log(logLevel, line) + } + } + } + }() + + // Wait for completion + err = cmd.Wait() + <-stdoutDone + <-stderrDone + + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + if exitErr.ExitCode() == 137 { + return errors.New("Blender was killed due to excessive memory usage (OOM)") + } + } + return fmt.Errorf("blender failed: %w", err) + } + + return nil +} + +func (p *RenderProcessor) findOutputFile(ctx *Context, outputDir, renderFormat string) (string, error) { + entries, err := os.ReadDir(outputDir) + if err != nil { + return "", fmt.Errorf("failed to read output directory: %w", err) + } + + ctx.Info("Checking output directory for files...") + + // Try exact match first + expectedFile := filepath.Join(outputDir, fmt.Sprintf("frame_%04d.%s", ctx.Frame, strings.ToLower(renderFormat))) + if _, err := os.Stat(expectedFile); err == nil { + ctx.Info(fmt.Sprintf("Found output file: %s", filepath.Base(expectedFile))) + return expectedFile, nil + } + + // Try without zero padding + altFile := filepath.Join(outputDir, fmt.Sprintf("frame_%d.%s", ctx.Frame, strings.ToLower(renderFormat))) + if _, err := os.Stat(altFile); err == nil { + ctx.Info(fmt.Sprintf("Found output file: %s", filepath.Base(altFile))) + return altFile, nil + } + + // Try just frame number + altFile2 := filepath.Join(outputDir, fmt.Sprintf("%04d.%s", ctx.Frame, strings.ToLower(renderFormat))) + if _, err := os.Stat(altFile2); err == nil { + ctx.Info(fmt.Sprintf("Found output file: %s", filepath.Base(altFile2))) + return altFile2, nil + } + + // Search through all files + for _, entry := range entries { + if !entry.IsDir() { + fileName := entry.Name() + if strings.Contains(fileName, "%04d") || strings.Contains(fileName, "%d") { + ctx.Warn(fmt.Sprintf("Skipping file with literal pattern: %s", fileName)) + continue + } + frameStr := fmt.Sprintf("%d", ctx.Frame) + frameStrPadded := fmt.Sprintf("%04d", ctx.Frame) + if strings.Contains(fileName, frameStrPadded) || + (strings.Contains(fileName, frameStr) && 
strings.HasSuffix(strings.ToLower(fileName), strings.ToLower(renderFormat))) { + outputFile := filepath.Join(outputDir, fileName) + ctx.Info(fmt.Sprintf("Found output file: %s", fileName)) + return outputFile, nil + } + } + } + + // Not found + fileList := []string{} + for _, entry := range entries { + if !entry.IsDir() { + fileList = append(fileList, entry.Name()) + } + } + return "", fmt.Errorf("output file not found: %s\nFiles in output directory: %v", expectedFile, fileList) +} diff --git a/internal/runner/workspace/archive.go b/internal/runner/workspace/archive.go new file mode 100644 index 0000000..f587ef8 --- /dev/null +++ b/internal/runner/workspace/archive.go @@ -0,0 +1,146 @@ +package workspace + +import ( + "archive/tar" + "fmt" + "io" + "log" + "os" + "path/filepath" + "strings" +) + +// ExtractTar extracts a tar archive from a reader to a directory. +func ExtractTar(reader io.Reader, destDir string) error { + if err := os.MkdirAll(destDir, 0755); err != nil { + return fmt.Errorf("failed to create destination directory: %w", err) + } + + tarReader := tar.NewReader(reader) + + for { + header, err := tarReader.Next() + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("failed to read tar header: %w", err) + } + + // Sanitize path to prevent directory traversal + targetPath := filepath.Join(destDir, header.Name) + if !strings.HasPrefix(filepath.Clean(targetPath), filepath.Clean(destDir)+string(os.PathSeparator)) { + return fmt.Errorf("invalid file path in tar: %s", header.Name) + } + + switch header.Typeflag { + case tar.TypeDir: + if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + + case tar.TypeReg: + if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil { + return fmt.Errorf("failed to create parent directory: %w", err) + } + + outFile, err := os.Create(targetPath) + if err != nil { + return fmt.Errorf("failed to create file: %w", err) + } + + if _, err := io.Copy(outFile, tarReader); err != nil { + outFile.Close() + return fmt.Errorf("failed to write file: %w", err) + } + outFile.Close() + + if err := os.Chmod(targetPath, os.FileMode(header.Mode)); err != nil { + log.Printf("Warning: failed to set file permissions: %v", err) + } + } + } + + return nil +} + +// ExtractTarStripPrefix extracts a tar archive, stripping the top-level directory. +// Useful for Blender archives like "blender-4.2.3-linux-x64/". 
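+// The prefix is taken from the first entry, so the archive is assumed to
+// contain a single top-level directory. Unlike ExtractTar, entry paths are
+// not sanitized here, so only use this on trusted archives.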
+func ExtractTarStripPrefix(reader io.Reader, destDir string) error { + if err := os.MkdirAll(destDir, 0755); err != nil { + return err + } + + tarReader := tar.NewReader(reader) + stripPrefix := "" + + for { + header, err := tarReader.Next() + if err == io.EOF { + break + } + if err != nil { + return err + } + + // Determine strip prefix from first entry (e.g., "blender-4.2.3-linux-x64/") + if stripPrefix == "" { + parts := strings.SplitN(header.Name, "/", 2) + if len(parts) > 0 { + stripPrefix = parts[0] + "/" + } + } + + // Strip the top-level directory + name := strings.TrimPrefix(header.Name, stripPrefix) + if name == "" { + continue + } + + targetPath := filepath.Join(destDir, name) + + switch header.Typeflag { + case tar.TypeDir: + if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil { + return err + } + + case tar.TypeReg: + if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil { + return err + } + outFile, err := os.OpenFile(targetPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode)) + if err != nil { + return err + } + if _, err := io.Copy(outFile, tarReader); err != nil { + outFile.Close() + return err + } + outFile.Close() + + case tar.TypeSymlink: + if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil { + return err + } + os.Remove(targetPath) // Remove existing symlink if present + if err := os.Symlink(header.Linkname, targetPath); err != nil { + return err + } + } + } + + return nil +} + +// ExtractTarFile extracts a tar file to a directory. +func ExtractTarFile(tarPath, destDir string) error { + file, err := os.Open(tarPath) + if err != nil { + return fmt.Errorf("failed to open tar file: %w", err) + } + defer file.Close() + + return ExtractTar(file, destDir) +} + diff --git a/internal/runner/workspace/workspace.go b/internal/runner/workspace/workspace.go new file mode 100644 index 0000000..8f2fdf3 --- /dev/null +++ b/internal/runner/workspace/workspace.go @@ -0,0 +1,217 @@ +// Package workspace manages runner workspace directories. +package workspace + +import ( + "fmt" + "log" + "os" + "path/filepath" + "strings" +) + +// Manager handles workspace directory operations. +type Manager struct { + baseDir string + runnerName string +} + +// NewManager creates a new workspace manager. +func NewManager(runnerName string) *Manager { + m := &Manager{ + runnerName: sanitizeName(runnerName), + } + m.init() + return m +} + +func sanitizeName(name string) string { + name = strings.ReplaceAll(name, " ", "_") + name = strings.ReplaceAll(name, "/", "_") + name = strings.ReplaceAll(name, "\\", "_") + name = strings.ReplaceAll(name, ":", "_") + return name +} + +func (m *Manager) init() { + // Prefer current directory if writable, otherwise use temp + baseDir := os.TempDir() + if cwd, err := os.Getwd(); err == nil { + baseDir = cwd + } + + m.baseDir = filepath.Join(baseDir, "jiggablend-workspaces", m.runnerName) + if err := os.MkdirAll(m.baseDir, 0755); err != nil { + log.Printf("Warning: Failed to create workspace directory %s: %v", m.baseDir, err) + // Fallback to temp directory + m.baseDir = filepath.Join(os.TempDir(), "jiggablend-workspaces", m.runnerName) + if err := os.MkdirAll(m.baseDir, 0755); err != nil { + log.Printf("Error: Failed to create fallback workspace directory: %v", err) + // Last resort + m.baseDir = filepath.Join(os.TempDir(), "jiggablend-runner") + os.MkdirAll(m.baseDir, 0755) + } + } + log.Printf("Runner workspace initialized at: %s", m.baseDir) +} + +// BaseDir returns the base workspace directory. 
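+// Job, video, and Blender-version directories are all created beneath it.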
+func (m *Manager) BaseDir() string { + return m.baseDir +} + +// JobDir returns the directory for a specific job. +func (m *Manager) JobDir(jobID int64) string { + return filepath.Join(m.baseDir, fmt.Sprintf("job-%d", jobID)) +} + +// VideoDir returns the directory for encoding. +func (m *Manager) VideoDir(jobID int64) string { + return filepath.Join(m.baseDir, fmt.Sprintf("job-%d-video", jobID)) +} + +// BlenderDir returns the directory for Blender installations. +func (m *Manager) BlenderDir() string { + return filepath.Join(m.baseDir, "blender-versions") +} + +// CreateJobDir creates and returns the job directory. +func (m *Manager) CreateJobDir(jobID int64) (string, error) { + dir := m.JobDir(jobID) + if err := os.MkdirAll(dir, 0755); err != nil { + return "", fmt.Errorf("failed to create job directory: %w", err) + } + return dir, nil +} + +// CreateVideoDir creates and returns the encode directory. +func (m *Manager) CreateVideoDir(jobID int64) (string, error) { + dir := m.VideoDir(jobID) + if err := os.MkdirAll(dir, 0755); err != nil { + return "", fmt.Errorf("failed to create video directory: %w", err) + } + return dir, nil +} + +// CleanupJobDir removes a job directory. +func (m *Manager) CleanupJobDir(jobID int64) error { + dir := m.JobDir(jobID) + return os.RemoveAll(dir) +} + +// CleanupVideoDir removes an encode directory. +func (m *Manager) CleanupVideoDir(jobID int64) error { + dir := m.VideoDir(jobID) + return os.RemoveAll(dir) +} + +// Cleanup removes the entire workspace directory. +func (m *Manager) Cleanup() { + if m.baseDir != "" { + log.Printf("Cleaning up workspace directory: %s", m.baseDir) + if err := os.RemoveAll(m.baseDir); err != nil { + log.Printf("Warning: Failed to remove workspace directory %s: %v", m.baseDir, err) + } else { + log.Printf("Successfully removed workspace directory: %s", m.baseDir) + } + } + + // Also clean up any orphaned jiggablend directories + cleanupOrphanedWorkspaces() +} + +// cleanupOrphanedWorkspaces removes any jiggablend workspace directories +// that might be left behind from previous runs or crashes. +func cleanupOrphanedWorkspaces() { + log.Printf("Cleaning up orphaned jiggablend workspace directories...") + + dirsToCheck := []string{".", os.TempDir()} + for _, baseDir := range dirsToCheck { + workspaceDir := filepath.Join(baseDir, "jiggablend-workspaces") + if _, err := os.Stat(workspaceDir); err == nil { + log.Printf("Removing orphaned workspace directory: %s", workspaceDir) + if err := os.RemoveAll(workspaceDir); err != nil { + log.Printf("Warning: Failed to remove workspace directory %s: %v", workspaceDir, err) + } else { + log.Printf("Successfully removed workspace directory: %s", workspaceDir) + } + } + } +} + +// FindBlendFiles finds all .blend files in a directory. +func FindBlendFiles(dir string) ([]string, error) { + var blendFiles []string + + err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".blend") { + // Check it's not a Blender save file (.blend1, .blend2, etc.) 
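+		// A ".blend" suffix followed only by digits (".blend1", ".blend2") marks
+		// a numbered backup written by Blender's "Save Versions" setting.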
+ lower := strings.ToLower(info.Name()) + idx := strings.LastIndex(lower, ".blend") + if idx != -1 { + suffix := lower[idx+len(".blend"):] + isSaveFile := false + if len(suffix) > 0 { + isSaveFile = true + for _, r := range suffix { + if r < '0' || r > '9' { + isSaveFile = false + break + } + } + } + if !isSaveFile { + relPath, _ := filepath.Rel(dir, path) + blendFiles = append(blendFiles, relPath) + } + } + } + return nil + }) + + return blendFiles, err +} + +// FindFirstBlendFile finds the first .blend file in a directory. +func FindFirstBlendFile(dir string) (string, error) { + var blendFile string + + err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".blend") { + lower := strings.ToLower(info.Name()) + idx := strings.LastIndex(lower, ".blend") + if idx != -1 { + suffix := lower[idx+len(".blend"):] + isSaveFile := false + if len(suffix) > 0 { + isSaveFile = true + for _, r := range suffix { + if r < '0' || r > '9' { + isSaveFile = false + break + } + } + } + if !isSaveFile { + blendFile = path + return filepath.SkipAll + } + } + } + return nil + }) + + if err != nil { + return "", err + } + if blendFile == "" { + return "", fmt.Errorf("no .blend file found in %s", dir) + } + return blendFile, nil +} + diff --git a/jiggablend b/jiggablend index 1598af5..34275ae 100755 Binary files a/jiggablend and b/jiggablend differ diff --git a/pkg/scripts/scripts/extract_metadata.py b/pkg/scripts/scripts/extract_metadata.py index c25723d..3bb9819 100644 --- a/pkg/scripts/scripts/extract_metadata.py +++ b/pkg/scripts/scripts/extract_metadata.py @@ -339,12 +339,8 @@ object_count = len(scene.objects) material_count = len(bpy.data.materials) # Extract Blender version info -# bpy.app.version gives the current running Blender version -# For the file's saved version, we check bpy.data.version (version the file was saved with) -blender_version = { - "current": bpy.app.version_string, # Version of Blender running this script - "file_saved_with": ".".join(map(str, bpy.data.version)) if hasattr(bpy.data, 'version') else None, # Version file was saved with -} +# bpy.data.version gives the version the file was saved with +blender_version = ".".join(map(str, bpy.data.version)) if hasattr(bpy.data, 'version') else bpy.app.version_string # Build metadata dictionary metadata = { diff --git a/pkg/scripts/scripts/render_blender.py.template b/pkg/scripts/scripts/render_blender.py.template index ed13a3e..890f95d 100644 --- a/pkg/scripts/scripts/render_blender.py.template +++ b/pkg/scripts/scripts/render_blender.py.template @@ -12,6 +12,52 @@ try: except Exception as e: print(f"Warning: Could not make paths relative: {e}") +# Auto-enable addons from blender_addons folder in context +# Supports .zip files (installed via Blender API) and already-extracted addons +blend_dir = os.path.dirname(bpy.data.filepath) if bpy.data.filepath else os.getcwd() +addons_dir = os.path.join(blend_dir, "blender_addons") + +if os.path.isdir(addons_dir): + print(f"Found blender_addons folder: {addons_dir}") + + for item in os.listdir(addons_dir): + item_path = os.path.join(addons_dir, item) + + try: + if item.endswith('.zip'): + # Install and enable zip addon using Blender's API + bpy.ops.preferences.addon_install(filepath=item_path) + # Get module name from zip (usually the folder name inside) + import zipfile + with zipfile.ZipFile(item_path, 'r') as zf: + # Find the top-level module name + names = 
zf.namelist() + if names: + module_name = names[0].split('/')[0] + if module_name.endswith('.py'): + module_name = module_name[:-3] + bpy.ops.preferences.addon_enable(module=module_name) + print(f" Installed and enabled addon: {module_name}") + + elif item.endswith('.py') and not item.startswith('__'): + # Single-file addon + bpy.ops.preferences.addon_install(filepath=item_path) + module_name = item[:-3] + bpy.ops.preferences.addon_enable(module=module_name) + print(f" Installed and enabled addon: {module_name}") + + elif os.path.isdir(item_path) and os.path.exists(os.path.join(item_path, '__init__.py')): + # Multi-file addon directory - add to path and enable + if addons_dir not in sys.path: + sys.path.insert(0, addons_dir) + bpy.ops.preferences.addon_enable(module=item) + print(f" Enabled addon: {item}") + + except Exception as e: + print(f" Error with addon {item}: {e}") +else: + print(f"No blender_addons folder found at: {addons_dir}") + {{UNHIDE_CODE}} # Read output format from file (created by Go code) format_file_path = {{FORMAT_FILE_PATH}} @@ -53,10 +99,10 @@ print(f"Blend file output format: {current_output_format}") if output_format_override: print(f"Overriding output format from '{current_output_format}' to '{output_format_override}'") # Map common format names to Blender's format constants - # For video formats (EXR_264_MP4, EXR_AV1_MP4), we render as EXR frames first + # For video formats, we render as appropriate frame format first format_to_use = output_format_override.upper() - if format_to_use in ['EXR_264_MP4', 'EXR_AV1_MP4']: - format_to_use = 'EXR' # Render as EXR for video formats + if format_to_use in ['EXR_264_MP4', 'EXR_AV1_MP4', 'EXR_VP9_WEBM']: + format_to_use = 'EXR' # Render as EXR for EXR video formats format_map = { 'PNG': 'PNG', diff --git a/pkg/types/types.go b/pkg/types/types.go index 301dda5..b13a430 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -32,22 +32,20 @@ const ( // Job represents a render job type Job struct { - ID int64 `json:"id"` - UserID int64 `json:"user_id"` - JobType JobType `json:"job_type"` // "render" - Name string `json:"name"` - Status JobStatus `json:"status"` - Progress float64 `json:"progress"` // 0.0 to 100.0 - FrameStart *int `json:"frame_start,omitempty"` // Only for render jobs - FrameEnd *int `json:"frame_end,omitempty"` // Only for render jobs - OutputFormat *string `json:"output_format,omitempty"` // Only for render jobs - PNG, JPEG, EXR, etc. - AllowParallelRunners *bool `json:"allow_parallel_runners,omitempty"` // Only for render jobs - TimeoutSeconds int `json:"timeout_seconds"` // Job-level timeout (24 hours default) - BlendMetadata *BlendMetadata `json:"blend_metadata,omitempty"` // Extracted metadata from blend file - CreatedAt time.Time `json:"created_at"` - StartedAt *time.Time `json:"started_at,omitempty"` - CompletedAt *time.Time `json:"completed_at,omitempty"` - ErrorMessage string `json:"error_message,omitempty"` + ID int64 `json:"id"` + UserID int64 `json:"user_id"` + JobType JobType `json:"job_type"` // "render" + Name string `json:"name"` + Status JobStatus `json:"status"` + Progress float64 `json:"progress"` // 0.0 to 100.0 + FrameStart *int `json:"frame_start,omitempty"` // Only for render jobs + FrameEnd *int `json:"frame_end,omitempty"` // Only for render jobs + OutputFormat *string `json:"output_format,omitempty"` // Only for render jobs - PNG, JPEG, EXR, etc. 
+ BlendMetadata *BlendMetadata `json:"blend_metadata,omitempty"` // Extracted metadata from blend file + CreatedAt time.Time `json:"created_at"` + StartedAt *time.Time `json:"started_at,omitempty"` + CompletedAt *time.Time `json:"completed_at,omitempty"` + ErrorMessage string `json:"error_message,omitempty"` } // RunnerStatus represents the status of a runner @@ -86,9 +84,8 @@ const ( type TaskType string const ( - TaskTypeRender TaskType = "render" - TaskTypeMetadata TaskType = "metadata" - TaskTypeVideoGeneration TaskType = "video_generation" + TaskTypeRender TaskType = "render" + TaskTypeEncode TaskType = "encode" ) // Task represents a render task assigned to a runner @@ -96,8 +93,7 @@ type Task struct { ID int64 `json:"id"` JobID int64 `json:"job_id"` RunnerID *int64 `json:"runner_id,omitempty"` - FrameStart int `json:"frame_start"` - FrameEnd int `json:"frame_end"` + Frame int `json:"frame"` TaskType TaskType `json:"task_type"` Status TaskStatus `json:"status"` CurrentStep string `json:"current_step,omitempty"` @@ -132,16 +128,18 @@ type JobFile struct { // CreateJobRequest represents a request to create a new job type CreateJobRequest struct { - JobType JobType `json:"job_type"` // "render" - Name string `json:"name"` - FrameStart *int `json:"frame_start,omitempty"` // Required for render jobs - FrameEnd *int `json:"frame_end,omitempty"` // Required for render jobs - OutputFormat *string `json:"output_format,omitempty"` // Required for render jobs - AllowParallelRunners *bool `json:"allow_parallel_runners,omitempty"` // Optional for render jobs, defaults to true - RenderSettings *RenderSettings `json:"render_settings,omitempty"` // Optional: Override blend file render settings - UploadSessionID *string `json:"upload_session_id,omitempty"` // Optional: Session ID from file upload - UnhideObjects *bool `json:"unhide_objects,omitempty"` // Optional: Enable unhide tweaks for objects/collections - EnableExecution *bool `json:"enable_execution,omitempty"` // Optional: Enable auto-execution in Blender (adds --enable-autoexec flag, defaults to false) + JobType JobType `json:"job_type"` // "render" + Name string `json:"name"` + FrameStart *int `json:"frame_start,omitempty"` // Required for render jobs + FrameEnd *int `json:"frame_end,omitempty"` // Required for render jobs + OutputFormat *string `json:"output_format,omitempty"` // Required for render jobs + RenderSettings *RenderSettings `json:"render_settings,omitempty"` // Optional: Override blend file render settings + UploadSessionID *string `json:"upload_session_id,omitempty"` // Optional: Session ID from file upload + UnhideObjects *bool `json:"unhide_objects,omitempty"` // Optional: Enable unhide tweaks for objects/collections + EnableExecution *bool `json:"enable_execution,omitempty"` // Optional: Enable auto-execution in Blender (adds --enable-autoexec flag, defaults to false) + BlenderVersion *string `json:"blender_version,omitempty"` // Optional: Override Blender version (e.g., "4.2" or "4.2.3") + PreserveHDR *bool `json:"preserve_hdr,omitempty"` // Optional: Preserve HDR range for EXR encoding (uses HLG with bt709 primaries) + PreserveAlpha *bool `json:"preserve_alpha,omitempty"` // Optional: Preserve alpha channel for EXR encoding (requires AV1 or VP9 codec) } // UpdateJobProgressRequest represents a request to update job progress @@ -227,23 +225,26 @@ type TaskLogEntry struct { // BlendMetadata represents extracted metadata from a blend file type BlendMetadata struct { - FrameStart int `json:"frame_start"` - FrameEnd int 
`json:"frame_end"` - HasNegativeFrames bool `json:"has_negative_frames"` // True if blend file has negative frame numbers (not supported) - RenderSettings RenderSettings `json:"render_settings"` - SceneInfo SceneInfo `json:"scene_info"` + FrameStart int `json:"frame_start"` + FrameEnd int `json:"frame_end"` + HasNegativeFrames bool `json:"has_negative_frames"` // True if blend file has negative frame numbers (not supported) + RenderSettings RenderSettings `json:"render_settings"` + SceneInfo SceneInfo `json:"scene_info"` MissingFilesInfo *MissingFilesInfo `json:"missing_files_info,omitempty"` - UnhideObjects *bool `json:"unhide_objects,omitempty"` // Enable unhide tweaks for objects/collections - EnableExecution *bool `json:"enable_execution,omitempty"` // Enable auto-execution in Blender (adds --enable-autoexec flag, defaults to false) + UnhideObjects *bool `json:"unhide_objects,omitempty"` // Enable unhide tweaks for objects/collections + EnableExecution *bool `json:"enable_execution,omitempty"` // Enable auto-execution in Blender (adds --enable-autoexec flag, defaults to false) + BlenderVersion string `json:"blender_version,omitempty"` // Detected or overridden Blender version (e.g., "4.2" or "4.2.3") + PreserveHDR *bool `json:"preserve_hdr,omitempty"` // Preserve HDR range for EXR encoding (uses HLG with bt709 primaries) + PreserveAlpha *bool `json:"preserve_alpha,omitempty"` // Preserve alpha channel for EXR encoding (requires AV1 or VP9 codec) } // MissingFilesInfo represents information about missing files/addons type MissingFilesInfo struct { - Checked bool `json:"checked"` - HasMissing bool `json:"has_missing"` - MissingFiles []string `json:"missing_files,omitempty"` + Checked bool `json:"checked"` + HasMissing bool `json:"has_missing"` + MissingFiles []string `json:"missing_files,omitempty"` MissingAddons []string `json:"missing_addons,omitempty"` - Error string `json:"error,omitempty"` + Error string `json:"error,omitempty"` } // RenderSettings represents render settings from a blend file diff --git a/web/src/components/AdminPanel.jsx b/web/src/components/AdminPanel.jsx index 11d9ec0..150e73d 100644 --- a/web/src/components/AdminPanel.jsx +++ b/web/src/components/AdminPanel.jsx @@ -33,14 +33,16 @@ export default function AdminPanel() { } }, message: (data) => { - // Handle subscription responses + // Handle subscription responses - update both local refs and wsManager if (data.type === 'subscribed' && data.channel) { pendingSubscriptionsRef.current.delete(data.channel); subscribedChannelsRef.current.add(data.channel); + wsManager.confirmSubscription(data.channel); console.log('Successfully subscribed to channel:', data.channel); } else if (data.type === 'subscription_error' && data.channel) { pendingSubscriptionsRef.current.delete(data.channel); subscribedChannelsRef.current.delete(data.channel); + wsManager.failSubscription(data.channel); console.error('Subscription failed for channel:', data.channel, data.error); } @@ -83,27 +85,22 @@ export default function AdminPanel() { const subscribeToRunners = () => { const channel = 'runners'; - if (wsManager.getReadyState() !== WebSocket.OPEN) { - return; - } // Don't subscribe if already subscribed or pending if (subscribedChannelsRef.current.has(channel) || pendingSubscriptionsRef.current.has(channel)) { return; } - wsManager.send({ type: 'subscribe', channel }); + wsManager.subscribeToChannel(channel); + subscribedChannelsRef.current.add(channel); pendingSubscriptionsRef.current.add(channel); console.log('Subscribing to runners 
channel'); }; const unsubscribeFromRunners = () => { const channel = 'runners'; - if (wsManager.getReadyState() !== WebSocket.OPEN) { - return; - } if (!subscribedChannelsRef.current.has(channel)) { return; // Not subscribed } - wsManager.send({ type: 'unsubscribe', channel }); + wsManager.unsubscribeFromChannel(channel); subscribedChannelsRef.current.delete(channel); pendingSubscriptionsRef.current.delete(channel); console.log('Unsubscribed from runners channel'); diff --git a/web/src/components/FileExplorer.jsx b/web/src/components/FileExplorer.jsx index 99a0b74..ade6215 100644 --- a/web/src/components/FileExplorer.jsx +++ b/web/src/components/FileExplorer.jsx @@ -1,6 +1,6 @@ import { useState } from 'react'; -export default function FileExplorer({ files, onDownload, onPreview, isImageFile }) { +export default function FileExplorer({ files, onDownload, onPreview, onVideoPreview, isImageFile }) { const [expandedPaths, setExpandedPaths] = useState(new Set()); // Root folder collapsed by default // Build directory tree from file paths @@ -69,19 +69,29 @@ export default function FileExplorer({ files, onDownload, onPreview, isImageFile if (item.isFile) { const file = item.file; const isImage = isImageFile && isImageFile(file.file_name); + const isVideo = file.file_name.toLowerCase().endsWith('.mp4'); const sizeMB = (file.file_size / 1024 / 1024).toFixed(2); const isArchive = file.file_name.endsWith('.tar') || file.file_name.endsWith('.zip'); return (
- {isArchive ? '📦' : '📄'} + {isArchive ? '📦' : isVideo ? '🎬' : '📄'} {item.name} {sizeMB} MB
+ {isVideo && onVideoPreview && ( + + )} {isImage && onPreview && (
)} + {/* Video Preview Modal */} + {previewVideo && ( +
setPreviewVideo(null)} + > +
e.stopPropagation()} + > +
+

{previewVideo.fileName}

+ +
+
+ +
+
+
+ )} +
@@ -940,15 +1011,6 @@ export default function JobDetails({ job, onClose, onUpdate }) {
- {videoUrl && (jobDetails.output_format === 'EXR_264_MP4' || jobDetails.output_format === 'EXR_AV1_MP4') && ( -
-

- Video Preview -

- -
- )} - {contextFiles.length > 0 && (

@@ -976,9 +1038,15 @@ export default function JobDetails({ job, onClose, onUpdate }) { files={outputFiles} onDownload={handleDownload} onPreview={(file) => { - const imageUrl = jobs.downloadFile(job.id, file.id); + // Use EXR preview endpoint for EXR files, regular download for others + const imageUrl = isEXRFile(file.file_name) + ? jobs.previewEXR(job.id, file.id) + : jobs.downloadFile(job.id, file.id); setPreviewImage({ url: imageUrl, fileName: file.file_name }); }} + onVideoPreview={(file) => { + setPreviewVideo({ url: jobs.getVideoUrl(job.id), fileName: file.file_name }); + }} isImageFile={isImageFile} />

@@ -997,15 +1065,8 @@ export default function JobDetails({ job, onClose, onUpdate }) { const taskInfo = taskData[task.id] || { steps: [], logs: [] }; const { steps, logs } = taskInfo; - // Group logs by step_name - const logsByStep = {}; - logs.forEach(log => { - const stepName = log.step_name || 'general'; - if (!logsByStep[stepName]) { - logsByStep[stepName] = []; - } - logsByStep[stepName].push(log); - }); + // Sort all logs chronologically (no grouping by step_name) + const sortedLogs = [...logs].sort((a, b) => new Date(a.created_at) - new Date(b.created_at)); return (
@@ -1022,9 +1083,9 @@ export default function JobDetails({ job, onClose, onUpdate }) { {task.status} - {task.task_type === 'metadata' ? 'Metadata Extraction' : `Frame ${task.frame_start}${task.frame_end !== task.frame_start ? `-${task.frame_end}` : ''}`} + {task.task_type === 'encode' ? `Encode (${jobDetails.frame_start} - ${jobDetails.frame_end})` : `Frame ${task.frame}`} - {task.task_type && task.task_type !== 'render' && ( + {task.task_type && task.task_type !== 'render' && task.task_type !== 'encode' && ( ({task.task_type}) )}
@@ -1033,153 +1094,46 @@ export default function JobDetails({ job, onClose, onUpdate }) {
- {/* Task Content (Steps and Logs) */} + {/* Task Content (Continuous Log Stream) */} {isExpanded && (
- {/* General logs (logs without step_name) */} - {logsByStep['general'] && logsByStep['general'].length > 0 && (() => { - const generalKey = `${task.id}-general`; - const isGeneralExpanded = expandedSteps.has(generalKey); - const generalLogs = logsByStep['general']; - - return ( -
-
toggleStep(task.id, 'general')} - className="flex items-center justify-between p-2 cursor-pointer hover:bg-gray-750 transition-colors" - > -
- - {isGeneralExpanded ? '▼' : '▶'} - - General -
- - {generalLogs.length} log{generalLogs.length !== 1 ? 's' : ''} - -
- {isGeneralExpanded && ( -
-
- Logs - + {/* Header with auto-scroll */} +
+
-
{ - if (el) { - logContainerRefs.current[generalKey] = el; - // Initialize auto-scroll to true (follow logs) when ref is first set - if (shouldAutoScrollRefs.current[generalKey] === undefined) { - shouldAutoScrollRefs.current[generalKey] = true; - } - } - }} - onWheel={() => handleLogWheel(task.id, 'general')} - onMouseDown={(e) => handleLogClick(task.id, 'general', e)} - onContextMenu={(e) => handleLogClick(task.id, 'general', e)} - className="bg-black text-green-400 font-mono text-sm p-3 rounded max-h-64 overflow-y-auto" - > - {generalLogs.map((log) => ( -
- - [{new Date(log.created_at).toLocaleTimeString()}] - - {log.message} -
- ))} -
-
- )} -
- ); - })()} - - {/* Steps */} - {steps.length > 0 ? ( - steps.map((step) => { - const stepKey = `${task.id}-${step.step_name}`; - const isStepExpanded = expandedSteps.has(stepKey); - const stepLogs = logsByStep[step.step_name] || []; - - return ( -
- {/* Step Header */} -
toggleStep(task.id, step.step_name)} - className="flex items-center justify-between p-2 cursor-pointer hover:bg-gray-750 transition-colors" - > -
- - {isStepExpanded ? '▼' : '▶'} - - - {getStepStatusIcon(step.status)} - - {step.step_name} -
-
- {step.duration_ms && ( - - {(step.duration_ms / 1000).toFixed(2)}s - - )} - {stepLogs.length > 0 && ( - - {stepLogs.length} log{stepLogs.length !== 1 ? 's' : ''} - - )} -
-
- - {/* Step Logs */} - {isStepExpanded && ( -
-
- Logs
+ + {/* Logs */}
{ if (el) { - logContainerRefs.current[stepKey] = el; + logContainerRefs.current[`${task.id}-logs`] = el; // Initialize auto-scroll to true (follow logs) when ref is first set - if (shouldAutoScrollRefs.current[stepKey] === undefined) { - shouldAutoScrollRefs.current[stepKey] = true; + if (shouldAutoScrollRefs.current[`${task.id}-logs`] === undefined) { + shouldAutoScrollRefs.current[`${task.id}-logs`] = true; } } }} - onWheel={() => handleLogWheel(task.id, step.step_name)} - onMouseDown={(e) => handleLogClick(task.id, step.step_name, e)} - onContextMenu={(e) => handleLogClick(task.id, step.step_name, e)} - className="bg-black text-green-400 font-mono text-sm p-3 rounded max-h-64 overflow-y-auto" + onWheel={() => handleLogWheel(task.id, 'logs')} + onMouseDown={(e) => handleLogClick(task.id, 'logs', e)} + onContextMenu={(e) => handleLogClick(task.id, 'logs', e)} + className="bg-black text-green-400 font-mono text-sm p-3 rounded max-h-96 overflow-y-auto" > - {stepLogs.length === 0 ? ( + {sortedLogs.length === 0 ? (

No logs yet...

) : ( - stepLogs.map((log) => ( + sortedLogs.map((log) => (
-
- )} -
- ); - }) - ) : ( - logsByStep['general'] && logsByStep['general'].length > 0 ? null : ( -

No steps yet...

- ) - )}
)}
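The task labels above follow the new task model in pkg/types: each render task
now carries a single Frame rather than a frame range, and final video assembly
is a separate "encode" task. As a minimal sketch (not code from this diff; the
enqueue helper is hypothetical, and the frame-range pointers are assumed to be
non-nil), a manager could fan a job out like this:

    // expandJob turns one job into per-frame render tasks plus one encode
    // task. enqueue stands in for the manager's real persistence layer.
    func expandJob(job *types.Job, enqueue func(types.Task)) {
        for f := *job.FrameStart; f <= *job.FrameEnd; f++ {
            enqueue(types.Task{JobID: job.ID, Frame: f, TaskType: types.TaskTypeRender})
        }
        // The encode task spans the whole job, so Frame is left at zero.
        enqueue(types.Task{JobID: job.ID, TaskType: types.TaskTypeEncode})
    }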
diff --git a/web/src/components/JobSubmission.jsx b/web/src/components/JobSubmission.jsx index 44fcbda..607a7bb 100644 --- a/web/src/components/JobSubmission.jsx +++ b/web/src/components/JobSubmission.jsx @@ -12,10 +12,12 @@ export default function JobSubmission({ onSuccess }) { frame_start: 1, frame_end: 10, output_format: 'PNG', - allow_parallel_runners: true, render_settings: null, // Will contain engine settings unhide_objects: false, // Unhide objects/collections tweak enable_execution: false, // Enable auto-execution in Blender + blender_version: '', // Blender version override (empty = auto-detect) + preserve_hdr: false, // Preserve HDR range for EXR encoding + preserve_alpha: false, // Preserve alpha channel for EXR encoding }); const [showAdvancedSettings, setShowAdvancedSettings] = useState(false); const [file, setFile] = useState(null); @@ -32,6 +34,8 @@ export default function JobSubmission({ onSuccess }) { const [selectedMainBlend, setSelectedMainBlend] = useState(''); const [confirmedMissingFiles, setConfirmedMissingFiles] = useState(false); // Confirmation for missing files const [uploadTimeRemaining, setUploadTimeRemaining] = useState(null); // Estimated time remaining in seconds + const [blenderVersions, setBlenderVersions] = useState([]); // Available Blender versions from server + const [loadingBlenderVersions, setLoadingBlenderVersions] = useState(false); // Use refs to track cancellation state across re-renders const isCancelledRef = useRef(false); @@ -72,6 +76,25 @@ export default function JobSubmission({ onSuccess }) { } }; + // Fetch available Blender versions on mount + useEffect(() => { + const fetchBlenderVersions = async () => { + setLoadingBlenderVersions(true); + try { + const response = await fetch('/api/blender/versions'); + if (response.ok) { + const data = await response.json(); + setBlenderVersions(data.versions || []); + } + } catch (err) { + console.error('Failed to fetch Blender versions:', err); + } finally { + setLoadingBlenderVersions(false); + } + }; + fetchBlenderVersions(); + }, []); + // Connect to shared WebSocket on mount useEffect(() => { listenerIdRef.current = wsManager.subscribe('jobsubmission', { @@ -79,14 +102,16 @@ export default function JobSubmission({ onSuccess }) { console.log('JobSubmission: Shared WebSocket connected'); }, message: (data) => { - // Handle subscription responses + // Handle subscription responses - update both local refs and wsManager if (data.type === 'subscribed' && data.channel) { pendingSubscriptionsRef.current.delete(data.channel); subscribedChannelsRef.current.add(data.channel); + wsManager.confirmSubscription(data.channel); console.log('Successfully subscribed to channel:', data.channel); } else if (data.type === 'subscription_error' && data.channel) { pendingSubscriptionsRef.current.delete(data.channel); subscribedChannelsRef.current.delete(data.channel); + wsManager.failSubscription(data.channel); console.error('Subscription failed for channel:', data.channel, data.error); // If it's the upload channel we're trying to subscribe to, show error if (data.channel.startsWith('upload:')) { @@ -94,52 +119,7 @@ export default function JobSubmission({ onSuccess }) { } } - // Handle upload progress messages - if (data.channel && data.channel.startsWith('upload:') && subscribedChannelsRef.current.has(data.channel)) { - if (data.type === 'upload_progress' || data.type === 'processing_status') { - const progress = data.data?.progress || 0; - const status = data.data?.status || 'uploading'; - const message = 
data.data?.message || ''; - - setUploadProgress(progress); - - // Calculate time remaining for upload progress - if (status === 'uploading' && progress > 0 && progress < 100) { - if (!uploadStartTimeRef.current) { - uploadStartTimeRef.current = Date.now(); - } - const elapsed = (Date.now() - uploadStartTimeRef.current) / 1000; // seconds - const remaining = (elapsed / progress) * (100 - progress); - setUploadTimeRemaining(remaining); - } else if (status === 'completed' || status === 'error') { - setUploadTimeRemaining(null); - uploadStartTimeRef.current = null; - } - - if (status === 'uploading') { - setMetadataStatus('extracting'); - } else if (status === 'processing' || status === 'extracting_zip' || status === 'extracting_metadata' || status === 'creating_context') { - setMetadataStatus('processing'); - // Reset time remaining for processing phase - setUploadTimeRemaining(null); - } else if (status === 'completed') { - setMetadataStatus('completed'); - setIsUploading(false); - setUploadTimeRemaining(null); - uploadStartTimeRef.current = null; - // Unsubscribe from upload channel - unsubscribeFromUploadChannel(data.channel); - } else if (status === 'error') { - setMetadataStatus('error'); - setIsUploading(false); - setUploadTimeRemaining(null); - uploadStartTimeRef.current = null; - setError(message || 'Upload/processing failed'); - // Unsubscribe from upload channel - unsubscribeFromUploadChannel(data.channel); - } - } - } + // Upload progress is now handled via HTTP response - no WebSocket messages needed }, error: (error) => { console.error('JobSubmission: Shared WebSocket error:', error); @@ -166,13 +146,10 @@ export default function JobSubmission({ onSuccess }) { // Helper function to unsubscribe from upload channel const unsubscribeFromUploadChannel = (channel) => { - if (wsManager.getReadyState() !== WebSocket.OPEN) { - return; - } if (!subscribedChannelsRef.current.has(channel)) { return; // Not subscribed } - wsManager.send({ type: 'unsubscribe', channel }); + wsManager.unsubscribeFromChannel(channel); subscribedChannelsRef.current.delete(channel); pendingSubscriptionsRef.current.delete(channel); console.log('Unsubscribed from upload channel:', channel); @@ -180,11 +157,8 @@ export default function JobSubmission({ onSuccess }) { // Helper function to unsubscribe from all channels const unsubscribeFromAllChannels = () => { - if (wsManager.getReadyState() !== WebSocket.OPEN) { - return; - } subscribedChannelsRef.current.forEach(channel => { - wsManager.send({ type: 'unsubscribe', channel }); + wsManager.unsubscribeFromChannel(channel); }); subscribedChannelsRef.current.clear(); pendingSubscriptionsRef.current.clear(); @@ -223,40 +197,40 @@ export default function JobSubmission({ onSuccess }) { uploadStartTimeRef.current = Date.now(); setMetadataStatus('extracting'); - // Upload file to new endpoint (no job required) + // Upload file and get metadata in HTTP response const result = await jobs.uploadFileForJobCreation(selectedFile, (progress) => { - // XHR progress as fallback, but WebSocket is primary + // Show upload progress during upload setUploadProgress(progress); - // Calculate time remaining for XHR progress + // Calculate time remaining for upload progress if (progress > 0 && progress < 100 && uploadStartTimeRef.current) { const elapsed = (Date.now() - uploadStartTimeRef.current) / 1000; // seconds const remaining = (elapsed / progress) * (100 - progress); setUploadTimeRemaining(remaining); + } else if (progress >= 100) { + // Upload complete - switch to processing status + 
setUploadProgress(100); + setMetadataStatus('processing'); + setUploadTimeRemaining(null); } }, selectedMainBlend || undefined); // Store session ID for later use when creating the job if (result.session_id) { setUploadSessionId(result.session_id); - - // Subscribe to upload progress channel - if (wsManager.getReadyState() === WebSocket.OPEN) { - const channel = `upload:${result.session_id}`; - wsManager.send({ type: 'subscribe', channel }); - // Don't set subscribedUploadChannelRef yet - wait for confirmation - console.log('Subscribing to upload channel:', channel); - } } - // Check if ZIP extraction found multiple blend files - if (result.zip_extracted && result.blend_files && result.blend_files.length > 1) { - setBlendFiles(result.blend_files); + // Upload and processing complete - metadata is in the response + setIsUploading(false); + setUploadProgress(100); + setUploadTimeRemaining(null); + uploadStartTimeRef.current = null; + + // Handle ZIP extraction results - multiple blend files found + if (result.status === 'select_blend' || (result.zip_extracted && result.blend_files && result.blend_files.length > 1)) { + setBlendFiles(result.blend_files || []); setMetadataStatus('select_blend'); return; } - - // Upload and processing complete - setIsUploading(false); // If metadata was extracted, use it if (result.metadata_extracted && result.metadata) { @@ -286,6 +260,7 @@ export default function JobSubmission({ onSuccess }) { ...result.metadata.render_settings, engine_settings: result.metadata.render_settings.engine_settings || {}, } : null, + blender_version: result.metadata.blender_version || prev.blender_version, })); } else { setMetadataStatus('error'); @@ -323,36 +298,30 @@ export default function JobSubmission({ onSuccess }) { // Re-upload with selected main blend file const result = await jobs.uploadFileForJobCreation(file, (progress) => { - // XHR progress as fallback, but WebSocket is primary + // Show upload progress during upload setUploadProgress(progress); - // Calculate time remaining for XHR progress + // Calculate time remaining for upload progress if (progress > 0 && progress < 100 && uploadStartTimeRef.current) { const elapsed = (Date.now() - uploadStartTimeRef.current) / 1000; // seconds const remaining = (elapsed / progress) * (100 - progress); setUploadTimeRemaining(remaining); + } else if (progress >= 100) { + // Upload complete - switch to processing status + setUploadProgress(100); + setMetadataStatus('processing'); + setUploadTimeRemaining(null); } }, selectedMainBlend); setBlendFiles([]); - // Store session ID and subscribe to upload progress + // Store session ID if (result.session_id) { setUploadSessionId(result.session_id); - - // Subscribe to upload progress channel - if (wsManager.getReadyState() === WebSocket.OPEN) { - const channel = `upload:${result.session_id}`; - // Don't subscribe if already subscribed or pending - if (!subscribedChannelsRef.current.has(channel) && !pendingSubscriptionsRef.current.has(channel)) { - wsManager.send({ type: 'subscribe', channel }); - pendingSubscriptionsRef.current.add(channel); - console.log('Subscribing to upload channel:', channel); - } - } } - // Upload and processing complete - setIsUploading(false); + // Upload and processing complete - metadata is in the response + setIsUploading(false); // If metadata was extracted, use it if (result.metadata_extracted && result.metadata) { @@ -382,6 +351,7 @@ export default function JobSubmission({ onSuccess }) { ...result.metadata.render_settings, engine_settings: 
result.metadata.render_settings.engine_settings || {}, } : null, + blender_version: result.metadata.blender_version || prev.blender_version, })); } else { setMetadataStatus('error'); @@ -477,11 +447,13 @@ export default function JobSubmission({ onSuccess }) { frame_start: parseInt(formData.frame_start), frame_end: parseInt(formData.frame_end), output_format: formData.output_format, - allow_parallel_runners: formData.allow_parallel_runners, render_settings: renderSettings, upload_session_id: uploadSessionId || undefined, // Pass session ID to move context archive unhide_objects: formData.unhide_objects || undefined, // Pass unhide toggle enable_execution: formData.enable_execution || undefined, // Pass enable execution toggle + preserve_hdr: formData.preserve_hdr || undefined, // Pass preserve HDR toggle + preserve_alpha: formData.preserve_alpha || undefined, // Pass preserve alpha toggle + blender_version: formData.blender_version || undefined, // Pass Blender version override }); // Fetch the full job details @@ -508,10 +480,12 @@ export default function JobSubmission({ onSuccess }) { frame_start: 1, frame_end: 10, output_format: 'PNG', - allow_parallel_runners: true, render_settings: null, unhide_objects: false, enable_execution: false, + blender_version: '', + preserve_hdr: false, + preserve_alpha: false, }); setShowAdvancedSettings(false); formatManuallyChangedRef.current = false; @@ -534,6 +508,7 @@ export default function JobSubmission({ onSuccess }) { render_settings: null, unhide_objects: false, enable_execution: false, + blender_version: '', }); setShowAdvancedSettings(false); setFile(null); @@ -672,20 +647,9 @@ export default function JobSubmission({ onSuccess }) {
) : metadataStatus === 'processing' ? ( -
-
- Processing file and extracting metadata... - {Math.round(uploadProgress)}% -
-
-
-
-
- This may take a moment for large files... -
+
+
+ Processing file and extracting metadata...
) : (
@@ -868,20 +832,35 @@ export default function JobSubmission({ onSuccess }) { +
-
- setFormData({ ...formData, allow_parallel_runners: e.target.checked })} - className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-900 rounded" - /> -
+ {(formData.output_format === 'EXR_264_MP4' || formData.output_format === 'EXR_AV1_MP4' || formData.output_format === 'EXR_VP9_WEBM') && ( + <> +
+

+ Note: the preserve options below let you explicitly control HDR and alpha preservation. If autodetection finds HDR content or an alpha channel in your EXR frames, it is preserved automatically even when these options are unchecked. Important: alpha detection only inspects the first frame, so enable Preserve Alpha explicitly if transparency only appears later in the sequence. HDR detection is also imperfect, so enable Preserve HDR explicitly if you know your render contains HDR content.
+

+
+
+
+ setFormData({ ...formData, preserve_hdr: e.target.checked })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-900 rounded" + /> + +
+
+ + )} + + {(formData.output_format === 'EXR_AV1_MP4' || formData.output_format === 'EXR_VP9_WEBM') && ( +
+
+ setFormData({ ...formData, preserve_alpha: e.target.checked })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-900 rounded" + /> + +
+
+ )} + {metadata && metadataStatus === 'completed' && ( <>
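With progress reporting moved out of the WebSocket channel, the upload
endpoint's JSON body is the only contract the form logic above relies on. The
shape below is inferred from the fields JobSubmission reads; it is a sketch,
not a server-side type from this diff:

    // uploadResponse mirrors what the UI consumes from the upload endpoint.
    type uploadResponse struct {
        SessionID         string               `json:"session_id"`
        Status            string               `json:"status,omitempty"` // e.g. "select_blend"
        ZipExtracted      bool                 `json:"zip_extracted,omitempty"`
        BlendFiles        []string             `json:"blend_files,omitempty"`
        MetadataExtracted bool                 `json:"metadata_extracted,omitempty"`
        Metadata          *types.BlendMetadata `json:"metadata,omitempty"`
    }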
diff --git a/web/src/utils/api.js b/web/src/utils/api.js index ee050ad..c22665c 100644 --- a/web/src/utils/api.js +++ b/web/src/utils/api.js @@ -376,6 +376,10 @@ export const jobs = { return `${API_BASE}/jobs/${jobId}/files/${fileId}/download`; }, + previewEXR(jobId, fileId) { + return `${API_BASE}/jobs/${jobId}/files/${fileId}/preview-exr`; + }, + getVideoUrl(jobId) { return `${API_BASE}/jobs/${jobId}/video`; }, diff --git a/web/src/utils/websocket.js b/web/src/utils/websocket.js index 680c351..223b802 100644 --- a/web/src/utils/websocket.js +++ b/web/src/utils/websocket.js @@ -10,6 +10,11 @@ class WebSocketManager { this.isConnecting = false; this.listenerIdCounter = 0; this.verboseLogging = false; // Set to true to enable verbose WebSocket logging + + // Track server-side channel subscriptions for re-subscription on reconnect + this.serverSubscriptions = new Set(); // Channels we want to be subscribed to + this.confirmedSubscriptions = new Set(); // Channels confirmed by server + this.pendingSubscriptions = new Set(); // Channels waiting for confirmation } connect() { @@ -37,6 +42,10 @@ class WebSocketManager { console.log('Shared WebSocket connected'); } this.isConnecting = false; + + // Re-subscribe to all channels that were previously subscribed + this.resubscribeToChannels(); + this.notifyListeners('open', {}); }; @@ -68,17 +77,24 @@ class WebSocketManager { } this.ws = null; this.isConnecting = false; + + // Clear confirmed/pending but keep serverSubscriptions for re-subscription + this.confirmedSubscriptions.clear(); + this.pendingSubscriptions.clear(); + this.notifyListeners('close', event); - // Always retry connection - if (this.reconnectTimeout) { - clearTimeout(this.reconnectTimeout); - } - this.reconnectTimeout = setTimeout(() => { - if (!this.ws || this.ws.readyState === WebSocket.CLOSED) { - this.connect(); + // Always retry connection if we have listeners + if (this.listeners.size > 0) { + if (this.reconnectTimeout) { + clearTimeout(this.reconnectTimeout); } - }, this.reconnectDelay); + this.reconnectTimeout = setTimeout(() => { + if (!this.ws || this.ws.readyState === WebSocket.CLOSED) { + this.connect(); + } + }, this.reconnectDelay); + } }; } catch (error) { console.error('Failed to create WebSocket:', error); @@ -159,6 +175,81 @@ class WebSocketManager { return this.ws ? 
this.ws.readyState : WebSocket.CLOSED; } + // Subscribe to a server-side channel (will be re-subscribed on reconnect) + subscribeToChannel(channel) { + if (this.serverSubscriptions.has(channel)) { + // Already subscribed or pending + return; + } + + this.serverSubscriptions.add(channel); + + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + if (!this.confirmedSubscriptions.has(channel) && !this.pendingSubscriptions.has(channel)) { + this.pendingSubscriptions.add(channel); + this.send({ type: 'subscribe', channel }); + if (this.verboseLogging) { + console.log('WebSocketManager: Subscribing to channel:', channel); + } + } + } + } + + // Unsubscribe from a server-side channel (won't be re-subscribed on reconnect) + unsubscribeFromChannel(channel) { + this.serverSubscriptions.delete(channel); + this.confirmedSubscriptions.delete(channel); + this.pendingSubscriptions.delete(channel); + + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + this.send({ type: 'unsubscribe', channel }); + if (this.verboseLogging) { + console.log('WebSocketManager: Unsubscribing from channel:', channel); + } + } + } + + // Mark a channel subscription as confirmed (call this when server confirms) + confirmSubscription(channel) { + this.pendingSubscriptions.delete(channel); + this.confirmedSubscriptions.add(channel); + if (this.verboseLogging) { + console.log('WebSocketManager: Subscription confirmed for channel:', channel); + } + } + + // Mark a channel subscription as failed (call this when server rejects) + failSubscription(channel) { + this.pendingSubscriptions.delete(channel); + this.serverSubscriptions.delete(channel); + if (this.verboseLogging) { + console.log('WebSocketManager: Subscription failed for channel:', channel); + } + } + + // Check if subscribed to a channel + isSubscribedToChannel(channel) { + return this.confirmedSubscriptions.has(channel); + } + + // Re-subscribe to all channels after reconnect + resubscribeToChannels() { + if (this.serverSubscriptions.size === 0) { + return; + } + + if (this.verboseLogging) { + console.log('WebSocketManager: Re-subscribing to channels:', Array.from(this.serverSubscriptions)); + } + + for (const channel of this.serverSubscriptions) { + if (!this.pendingSubscriptions.has(channel)) { + this.pendingSubscriptions.add(channel); + this.send({ type: 'subscribe', channel }); + } + } + } + disconnect() { if (this.reconnectTimeout) { clearTimeout(this.reconnectTimeout); @@ -169,6 +260,9 @@ class WebSocketManager { this.ws = null; } this.listeners.clear(); + this.serverSubscriptions.clear(); + this.confirmedSubscriptions.clear(); + this.pendingSubscriptions.clear(); } }
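End to end, the refactor gives every task processor the same entry point: build
a tasks.Context from the runner's long-lived components, run Process, then
report completion. A sketch of the intended call pattern — the deps bundle,
jobName, jobToken, and metadata values are assumptions standing in for state
the runner already tracks, and the import paths for the api, encoding, and
executils packages are assumed:

    // runnerDeps bundles the long-lived components passed to tasks.NewContext.
    type runnerDeps struct {
        Manager   *api.ManagerClient
        JobConn   *api.JobConnection
        Workspace *workspace.Manager
        Blender   *blender.Manager
        Encoder   *encoding.Selector
        Processes *executils.ProcessTracker
    }

    // processRenderTask shows the call pattern: create the job workspace,
    // build a Context, run the processor, and report the result.
    func processRenderTask(task *types.Task, deps *runnerDeps,
        jobName, jobToken string, metadata *types.BlendMetadata) error {

        workDir, err := deps.Workspace.CreateJobDir(task.JobID)
        if err != nil {
            return err
        }
        defer deps.Workspace.CleanupJobDir(task.JobID)

        ctx := tasks.NewContext(
            task.ID, task.JobID, jobName, task.Frame, string(task.TaskType),
            workDir, jobToken, metadata,
            deps.Manager, deps.JobConn, deps.Workspace,
            deps.Blender, deps.Encoder, deps.Processes,
        )

        proc := tasks.NewRenderProcessor()
        if err := proc.Process(ctx); err != nil {
            ctx.Complete(false, err)
            return err
        }
        ctx.Complete(true, nil)
        return nil
    }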