diff --git a/.gitignore b/.gitignore index c9c579f..f21d121 100644 --- a/.gitignore +++ b/.gitignore @@ -69,6 +69,7 @@ lerna-debug.log* # Logs *.log +/logs/ # OS files Thumbs.db diff --git a/Makefile b/Makefile index 9a7a7e2..d6c9970 100644 --- a/Makefile +++ b/Makefile @@ -1,114 +1,66 @@ -.PHONY: build build-manager build-runner build-web run-manager run-runner run cleanup cleanup-manager cleanup-runner kill-all clean test help +.PHONY: build build-web run run-manager run-runner cleanup cleanup-manager cleanup-runner kill-all clean-bin clean-web test help install -# Build all components -build: clean-bin build-manager build-runner +# Build the jiggablend binary (includes embedded web UI) +build: clean-bin build-web + go build -o bin/jiggablend ./cmd/jiggablend -# Build manager -build-manager: clean-bin build-web - go build -o bin/manager ./cmd/manager - -# Build runner -build-runner: clean-bin - GOOS=linux GOARCH=amd64 go build -o bin/runner ./cmd/runner +# Build for Linux (cross-compile) +build-linux: clean-bin build-web + GOOS=linux GOARCH=amd64 go build -o bin/jiggablend ./cmd/jiggablend # Build web UI build-web: clean-web cd web && npm install && npm run build -# Cleanup manager (database and storage) +# Cleanup manager logs cleanup-manager: - @echo "Cleaning up manager database and storage..." - @rm -f jiggablend.db 2>/dev/null || true - @rm -f jiggablend.db-shm 2>/dev/null || true - @rm -f jiggablend.db-wal 2>/dev/null || true - @rm -rf jiggablend-storage 2>/dev/null || true + @echo "Cleaning up manager logs..." + @rm -rf logs/manager.log 2>/dev/null || true @echo "Manager cleanup complete" -# Cleanup runner workspaces +# Cleanup runner logs cleanup-runner: - @echo "Cleaning up runner workspaces..." - @rm -rf jiggablend-workspaces jiggablend-workspace* *workspace* 2>/dev/null || true + @echo "Cleaning up runner logs..." 
+ @rm -rf logs/runner*.log 2>/dev/null || true @echo "Runner cleanup complete" -# Cleanup both manager and runner +# Cleanup both manager and runner logs cleanup: cleanup-manager cleanup-runner -# Kill all manager and runner processes +# Kill all jiggablend processes kill-all: - @echo "Killing all manager and runner processes..." - @# Kill manager processes (compiled binaries in bin/, root, and go run) - @-pkill -f "bin/manager" 2>/dev/null || true - @-pkill -f "\./manager" 2>/dev/null || true - @-pkill -f "manager" 2>/dev/null || true - @-pkill -f "main.*cmd/manager" 2>/dev/null || true - @-pkill -f "go run.*cmd/manager" 2>/dev/null || true - @# Kill runner processes (compiled binaries in bin/, root, and go run) - @-pkill -f "bin/runner" 2>/dev/null || true - @-pkill -f "\./runner" 2>/dev/null || true - @-pkill -f "runner" 2>/dev/null || true - @-pkill -f "main.*cmd/runner" 2>/dev/null || true - @-pkill -f "go run.*cmd/runner" 2>/dev/null || true - @# Wait a moment for graceful shutdown - @echo "Waiting for 5 seconds for graceful shutdown..." 
- @sleep 1 - @echo "5" - @sleep 1 - @echo "4" - @sleep 1 - @echo "3" - @sleep 1 - @echo "2" - @sleep 1 - @echo "1" - @sleep 1 - @echo "0" + @echo "Not implemented" - @# Check if any manager or runner processes are still running - @MANAGER_COUNT=$$(pgrep -f "bin/manager\|\./manager\|manager\|main.*cmd/manager\|go run.*cmd/manager" | wc -l); \ - RUNNER_COUNT=$$(pgrep -f "bin/runner\|\./runner\|runner\|main.*cmd/runner\|go run.*cmd/runner" | wc -l); \ - if [ $$MANAGER_COUNT -eq 0 ] && [ $$RUNNER_COUNT -eq 0 ]; then \ - echo "All manager and runner processes have shut down gracefully"; \ - exit 0; \ - else \ - echo "Some processes still running ($$MANAGER_COUNT managers, $$RUNNER_COUNT runners), proceeding with force kill..."; \ - fi - @# Force kill any remaining processes - @-pkill -9 -f "bin/manager" 2>/dev/null || true - @-pkill -9 -f "\./manager" 2>/dev/null || true - @-pkill -9 -f "main.*cmd/manager" 2>/dev/null || true - @-pkill -9 -f "go run.*cmd/manager" 2>/dev/null || true - @-pkill -9 -f "bin/runner" 2>/dev/null || true - @-pkill -9 -f "\./runner" 2>/dev/null || true - @-pkill -9 -f "main.*cmd/runner" 2>/dev/null || true - @-pkill -9 -f "go run.*cmd/runner" 2>/dev/null || true - @echo "All manager and runner processes killed after 5 seconds" - -# Run all parallel -run: cleanup-manager cleanup-runner build-manager build-runner +# Run manager and runner in parallel (for testing) +run: cleanup build init-test @echo "Starting manager and runner in parallel..." @echo "Press Ctrl+C to stop both..." 
- @echo "Note: This will create a test API key for the runner to use" @trap 'kill $$MANAGER_PID $$RUNNER_PID 2>/dev/null; exit' INT TERM; \ - FIXED_API_KEY=jk_r0_test_key_123456789012345678901234567890 ENABLE_LOCAL_AUTH=true LOCAL_TEST_EMAIL=test@example.com LOCAL_TEST_PASSWORD=testpassword bin/manager & \ + bin/jiggablend manager & \ MANAGER_PID=$$!; \ sleep 2; \ - API_KEY=jk_r0_test_key_123456789012345678901234567890 bin/runner & \ + bin/jiggablend runner --api-key=jk_r0_test_key_123456789012345678901234567890 & \ RUNNER_PID=$$!; \ wait $$MANAGER_PID $$RUNNER_PID -# Run manager with test API key -# Note: ENABLE_LOCAL_AUTH enables local user registration/login -# LOCAL_TEST_EMAIL and LOCAL_TEST_PASSWORD create a test user on startup (if it doesn't exist) -# FIXED_API_KEY provides a pre-configured API key for testing (jk_r0_... format) -# The manager will accept this API key for runner registration -run-manager: cleanup-manager build-manager - FIXED_API_KEY=jk_r0_test_key_123456789012345678901234567890 ENABLE_LOCAL_AUTH=true LOCAL_TEST_EMAIL=test@example.com LOCAL_TEST_PASSWORD=testpassword bin/manager +# Run manager server +run-manager: cleanup-manager build init-test + bin/jiggablend manager -# Run runner with test API key -# Note: API_KEY must match what the manager accepts (see run-manager) -# The runner will use this API key for all authentication with the manager -run-runner: cleanup-runner build-runner - API_KEY=jk_r0_test_key_123456789012345678901234567890 bin/runner +# Run runner +run-runner: cleanup-runner build + bin/jiggablend runner --api-key=jk_r0_test_key_123456789012345678901234567890 + +# Initialize for testing (first run setup) +init-test: build + @echo "Initializing test configuration..." 
+ bin/jiggablend manager config enable localauth + bin/jiggablend manager config set fixed-apikey jk_r0_test_key_123456789012345678901234567890 -f -y + bin/jiggablend manager config add user test@example.com testpassword --admin -f -y + @echo "Test configuration complete!" + @echo "fixed api key: jk_r0_test_key_123456789012345678901234567890" + @echo "test user: test@example.com" + @echo "test password: testpassword" # Clean bin build artifacts clean-bin: @@ -122,39 +74,45 @@ clean-web: test: go test ./... -timeout 30s +# Install to /usr/local/bin +install: build + sudo cp bin/jiggablend /usr/local/bin/ + # Show help help: @echo "Jiggablend Build and Run Makefile" @echo "" @echo "Build targets:" - @echo " build - Build manager, runner, and web UI" - @echo " build-manager - Build manager with web UI" - @echo " build-runner - Build runner binary" - @echo " build-web - Build web UI" + @echo " build - Build jiggablend binary with embedded web UI" + @echo " build-linux - Cross-compile for Linux amd64" + @echo " build-web - Build web UI only" @echo "" @echo "Run targets:" - @echo " run - Run manager and runner in parallel with test API key" - @echo " run-manager - Run manager with test API key enabled" + @echo " run - Run manager and runner in parallel (for testing)" + @echo " run-manager - Run manager server" @echo " run-runner - Run runner with test API key" + @echo " init-test - Initialize test configuration (run once)" @echo "" @echo "Cleanup targets:" - @echo " cleanup - Clean manager and runner data" - @echo " cleanup-manager - Clean manager database and storage" - @echo " cleanup-runner - Clean runner workspaces and API keys" + @echo " cleanup - Clean all logs" + @echo " cleanup-manager - Clean manager logs" + @echo " cleanup-runner - Clean runner logs" + @echo " kill-all - Kill all running jiggablend processes" @echo "" @echo "Other targets:" - @echo " clean - Clean build artifacts" - @echo " kill-all - Kill all running manager and runner processes (binaries in 
bin/, root, or go run)" + @echo " clean-bin - Clean build artifacts" + @echo " clean-web - Clean web build artifacts" @echo " test - Run Go tests" + @echo " install - Install to /usr/local/bin" @echo " help - Show this help" @echo "" - @echo "API Key System:" - @echo " - FIXED_API_KEY: Pre-configured API key for manager (optional)" - @echo " - API_KEY: API key for runner authentication" - @echo " - Format: jk_r{N}_{32-char-hex}" - @echo " - Generate via admin UI or set FIXED_API_KEY for testing" - @echo "" - @echo "Timeouts:" - @echo " - Use 'timeout make run' to prevent hanging during testing" - @echo " - Example: timeout 30s make run" - + @echo "CLI Usage:" + @echo " jiggablend manager - Start the manager server" + @echo " jiggablend runner - Start a runner" + @echo " jiggablend manager config show - Show configuration" + @echo " jiggablend manager config enable localauth" + @echo " jiggablend manager config add user <email> <password>" + @echo " jiggablend manager config add apikey [name]" + @echo " jiggablend manager config set fixed-apikey <key>" + @echo " jiggablend manager config list users" + @echo " jiggablend manager config list apikeys" diff --git a/cmd/jiggablend/cmd/manager.go b/cmd/jiggablend/cmd/manager.go new file mode 100644 index 0000000..96ac197 --- /dev/null +++ b/cmd/jiggablend/cmd/manager.go @@ -0,0 +1,152 @@ +package cmd + +import ( + "fmt" + "net/http" + "os/exec" + "strings" + + "jiggablend/internal/api" + "jiggablend/internal/auth" + "jiggablend/internal/config" + "jiggablend/internal/database" + "jiggablend/internal/logger" + "jiggablend/internal/storage" + + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +var managerCmd = &cobra.Command{ + Use: "manager", + Short: "Start the Jiggablend manager server", + Long: `Start the Jiggablend manager server to coordinate render jobs.`, + Run: runManager, +} + +func init() { + rootCmd.AddCommand(managerCmd) + + // Flags with env binding via viper + 
managerCmd.Flags().StringP("port", "p", "8080", "Server port") + managerCmd.Flags().String("db", "jiggablend.db", "Database path") + managerCmd.Flags().String("storage", "./jiggablend-storage", "Storage path") + managerCmd.Flags().StringP("log-file", "l", "", "Log file path (truncated on start, if not set logs only to stdout)") + managerCmd.Flags().String("log-level", "info", "Log level (debug, info, warn, error)") + managerCmd.Flags().BoolP("verbose", "v", false, "Enable verbose logging (same as --log-level=debug)") + + // Bind flags to viper with JIGGABLEND_ prefix + viper.SetEnvPrefix("JIGGABLEND") + viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) + viper.AutomaticEnv() + + viper.BindPFlag("port", managerCmd.Flags().Lookup("port")) + viper.BindPFlag("db", managerCmd.Flags().Lookup("db")) + viper.BindPFlag("storage", managerCmd.Flags().Lookup("storage")) + viper.BindPFlag("log_file", managerCmd.Flags().Lookup("log-file")) + viper.BindPFlag("log_level", managerCmd.Flags().Lookup("log-level")) + viper.BindPFlag("verbose", managerCmd.Flags().Lookup("verbose")) +} + +func runManager(cmd *cobra.Command, args []string) { + // Get config values (flags take precedence over env vars) + port := viper.GetString("port") + dbPath := viper.GetString("db") + storagePath := viper.GetString("storage") + logFile := viper.GetString("log_file") + logLevel := viper.GetString("log_level") + verbose := viper.GetBool("verbose") + + // Initialize logger + if logFile != "" { + if err := logger.InitWithFile(logFile); err != nil { + logger.Fatalf("Failed to initialize logger: %v", err) + } + defer func() { + if l := logger.GetDefault(); l != nil { + l.Close() + } + }() + } else { + logger.InitStdout() + } + + // Set log level + if verbose { + logger.SetLevel(logger.LevelDebug) + } else { + logger.SetLevel(logger.ParseLevel(logLevel)) + } + + if logFile != "" { + logger.Infof("Logging to file: %s", logFile) + } + logger.Debugf("Log level: %s", logLevel) + + // Initialize database + 
db, err := database.NewDB(dbPath) + if err != nil { + logger.Fatalf("Failed to initialize database: %v", err) + } + defer db.Close() + + // Initialize config from database + cfg := config.NewConfig(db) + if err := cfg.InitializeFromEnv(); err != nil { + logger.Fatalf("Failed to initialize config: %v", err) + } + logger.Info("Configuration loaded from database") + + // Initialize auth + authHandler, err := auth.NewAuth(db, cfg) + if err != nil { + logger.Fatalf("Failed to initialize auth: %v", err) + } + + // Initialize storage + storageHandler, err := storage.NewStorage(storagePath) + if err != nil { + logger.Fatalf("Failed to initialize storage: %v", err) + } + + // Check if Blender is available + if err := checkBlenderAvailable(); err != nil { + logger.Fatalf("Blender is not available: %v\n"+ + "The manager requires Blender to be installed and in PATH for metadata extraction.\n"+ + "Please install Blender and ensure it's accessible via the 'blender' command.", err) + } + logger.Info("Blender is available") + + // Create API server + server, err := api.NewServer(db, cfg, authHandler, storageHandler) + if err != nil { + logger.Fatalf("Failed to create server: %v", err) + } + + // Start server + addr := fmt.Sprintf(":%s", port) + logger.Infof("Starting manager server on %s", addr) + logger.Infof("Database: %s", dbPath) + logger.Infof("Storage: %s", storagePath) + + httpServer := &http.Server{ + Addr: addr, + Handler: server, + MaxHeaderBytes: 1 << 20, + ReadTimeout: 0, + WriteTimeout: 0, + } + + if err := httpServer.ListenAndServe(); err != nil { + logger.Fatalf("Server failed: %v", err) + } +} + +func checkBlenderAvailable() error { + cmd := exec.Command("blender", "--version") + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("failed to run 'blender --version': %w (output: %s)", err, string(output)) + } + return nil +} diff --git a/cmd/jiggablend/cmd/managerconfig.go b/cmd/jiggablend/cmd/managerconfig.go new file mode 100644 index 
0000000..5bf3b4c --- /dev/null +++ b/cmd/jiggablend/cmd/managerconfig.go @@ -0,0 +1,621 @@ +package cmd + +import ( + "bufio" + "crypto/rand" + "crypto/sha256" + "database/sql" + "encoding/hex" + "fmt" + "os" + "strings" + + "jiggablend/internal/config" + "jiggablend/internal/database" + + "github.com/spf13/cobra" + "golang.org/x/crypto/bcrypt" +) + +var ( + configDBPath string + configYes bool // Auto-confirm prompts + configForce bool // Force override existing +) + +var configCmd = &cobra.Command{ + Use: "config", + Short: "Configure the manager", + Long: `Configure the Jiggablend manager settings stored in the database.`, +} + +// --- Enable/Disable commands --- + +var enableCmd = &cobra.Command{ + Use: "enable", + Short: "Enable a feature", +} + +var disableCmd = &cobra.Command{ + Use: "disable", + Short: "Disable a feature", +} + +var enableLocalAuthCmd = &cobra.Command{ + Use: "localauth", + Short: "Enable local authentication", + Run: func(cmd *cobra.Command, args []string) { + withConfig(func(cfg *config.Config, db *database.DB) { + if err := cfg.SetBool(config.KeyEnableLocalAuth, true); err != nil { + exitWithError("Failed to enable local auth: %v", err) + } + fmt.Println("Local authentication enabled") + }) + }, +} + +var disableLocalAuthCmd = &cobra.Command{ + Use: "localauth", + Short: "Disable local authentication", + Run: func(cmd *cobra.Command, args []string) { + withConfig(func(cfg *config.Config, db *database.DB) { + if err := cfg.SetBool(config.KeyEnableLocalAuth, false); err != nil { + exitWithError("Failed to disable local auth: %v", err) + } + fmt.Println("Local authentication disabled") + }) + }, +} + +var enableRegistrationCmd = &cobra.Command{ + Use: "registration", + Short: "Enable user registration", + Run: func(cmd *cobra.Command, args []string) { + withConfig(func(cfg *config.Config, db *database.DB) { + if err := cfg.SetBool(config.KeyRegistrationEnabled, true); err != nil { + exitWithError("Failed to enable registration: %v", err) + 
} + fmt.Println("User registration enabled") + }) + }, +} + +var disableRegistrationCmd = &cobra.Command{ + Use: "registration", + Short: "Disable user registration", + Run: func(cmd *cobra.Command, args []string) { + withConfig(func(cfg *config.Config, db *database.DB) { + if err := cfg.SetBool(config.KeyRegistrationEnabled, false); err != nil { + exitWithError("Failed to disable registration: %v", err) + } + fmt.Println("User registration disabled") + }) + }, +} + +var enableProductionCmd = &cobra.Command{ + Use: "production", + Short: "Enable production mode", + Run: func(cmd *cobra.Command, args []string) { + withConfig(func(cfg *config.Config, db *database.DB) { + if err := cfg.SetBool(config.KeyProductionMode, true); err != nil { + exitWithError("Failed to enable production mode: %v", err) + } + fmt.Println("Production mode enabled") + }) + }, +} + +var disableProductionCmd = &cobra.Command{ + Use: "production", + Short: "Disable production mode", + Run: func(cmd *cobra.Command, args []string) { + withConfig(func(cfg *config.Config, db *database.DB) { + if err := cfg.SetBool(config.KeyProductionMode, false); err != nil { + exitWithError("Failed to disable production mode: %v", err) + } + fmt.Println("Production mode disabled") + }) + }, +} + +// --- Add commands --- + +var addCmd = &cobra.Command{ + Use: "add", + Short: "Add a resource", +} + +var ( + addUserName string + addUserAdmin bool +) + +var addUserCmd = &cobra.Command{ + Use: "user <email> <password>", + Short: "Add a local user", + Long: `Add a new local user account to the database.`, + Args: cobra.ExactArgs(2), + Run: func(cmd *cobra.Command, args []string) { + email := args[0] + password := args[1] + + name := addUserName + if name == "" { + // Use email prefix as name + if atIndex := strings.Index(email, "@"); atIndex > 0 { + name = email[:atIndex] + } else { + name = email + } + } + if len(password) < 8 { + exitWithError("Password must be at least 8 characters") + } + + withConfig(func(cfg *config.Config, db 
*database.DB) { + // Check if user exists + var exists bool + err := db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ?)", email).Scan(&exists) + }) + if err != nil { + exitWithError("Failed to check user: %v", err) + } + isAdmin := addUserAdmin + if exists { + if !configForce { + exitWithError("User with email %s already exists (use -f to override)", email) + } + // Confirm override + if !configYes && !confirm(fmt.Sprintf("User %s already exists. Override?", email)) { + fmt.Println("Aborted") + return + } + // Update existing user + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + exitWithError("Failed to hash password: %v", err) + } + err = db.With(func(conn *sql.DB) error { + _, err := conn.Exec( + "UPDATE users SET name = ?, password_hash = ?, is_admin = ? WHERE email = ?", + name, string(hashedPassword), isAdmin, email, + ) + return err + }) + if err != nil { + exitWithError("Failed to update user: %v", err) + } + fmt.Printf("Updated user: %s (admin: %v)\n", email, isAdmin) + return + } + + // Hash password + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + exitWithError("Failed to hash password: %v", err) + } + + // Check if first user (make admin) + var userCount int + db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount) + }) + if userCount == 0 { + isAdmin = true + } + + // Confirm creation + if !configYes && !confirm(fmt.Sprintf("Create user %s (admin: %v)?", email, isAdmin)) { + fmt.Println("Aborted") + return + } + + // Create user + err = db.With(func(conn *sql.DB) error { + _, err := conn.Exec( + "INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?)", + email, name, email, string(hashedPassword), isAdmin, + ) + return err + }) + if err != nil { + 
exitWithError("Failed to create user: %v", err) + } + + fmt.Printf("Created user: %s (admin: %v)\n", email, isAdmin) + }) + }, +} + +var addAPIKeyScope string + +var addAPIKeyCmd = &cobra.Command{ + Use: "apikey [name]", + Short: "Add a runner API key", + Long: `Generate a new API key for runner authentication.`, + Args: cobra.MaximumNArgs(1), + Run: func(cmd *cobra.Command, args []string) { + name := "cli-generated" + if len(args) > 0 { + name = args[0] + } + + withConfig(func(cfg *config.Config, db *database.DB) { + // Check if API key with same name exists + var exists bool + err := db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM runner_api_keys WHERE name = ?)", name).Scan(&exists) + }) + if err != nil { + exitWithError("Failed to check API key: %v", err) + } + if exists { + if !configForce { + exitWithError("API key with name %s already exists (use -f to create another)", name) + } + if !configYes && !confirm(fmt.Sprintf("API key named '%s' already exists. 
Create another?", name)) { + fmt.Println("Aborted") + return + } + } + + // Confirm creation + if !configYes && !confirm(fmt.Sprintf("Generate new API key '%s' (scope: %s)?", name, addAPIKeyScope)) { + fmt.Println("Aborted") + return + } + + // Generate API key + key, keyPrefix, keyHash, err := generateAPIKey() + if err != nil { + exitWithError("Failed to generate API key: %v", err) + } + + // Get first user ID for created_by (or use 0 if no users) + var createdBy int64 + db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT id FROM users ORDER BY id ASC LIMIT 1").Scan(&createdBy) + }) + + // Store in database + err = db.With(func(conn *sql.DB) error { + _, err := conn.Exec( + `INSERT INTO runner_api_keys (key_prefix, key_hash, name, scope, is_active, created_by) + VALUES (?, ?, ?, ?, true, ?)`, + keyPrefix, keyHash, name, addAPIKeyScope, createdBy, + ) + return err + }) + if err != nil { + exitWithError("Failed to store API key: %v", err) + } + + fmt.Printf("Generated API key: %s\n", key) + fmt.Printf("Name: %s, Scope: %s\n", name, addAPIKeyScope) + fmt.Println("\nSave this key - it cannot be retrieved later!") + }) + }, +} + +// --- Set commands --- + +var setCmd = &cobra.Command{ + Use: "set", + Short: "Set a configuration value", +} + +var setFixedAPIKeyCmd = &cobra.Command{ + Use: "fixed-apikey [key]", + Short: "Set a fixed API key for testing", + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + withConfig(func(cfg *config.Config, db *database.DB) { + // Check if already set + existing := cfg.FixedAPIKey() + if existing != "" && !configForce { + exitWithError("Fixed API key already set (use -f to override)") + } + if existing != "" && !configYes && !confirm("Fixed API key already set. 
Override?") { + fmt.Println("Aborted") + return + } + if err := cfg.Set(config.KeyFixedAPIKey, args[0]); err != nil { + exitWithError("Failed to set fixed API key: %v", err) + } + fmt.Println("Fixed API key set") + }) + }, +} + +var setAllowedOriginsCmd = &cobra.Command{ + Use: "allowed-origins [origins]", + Short: "Set allowed CORS origins (comma-separated)", + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + withConfig(func(cfg *config.Config, db *database.DB) { + if err := cfg.Set(config.KeyAllowedOrigins, args[0]); err != nil { + exitWithError("Failed to set allowed origins: %v", err) + } + fmt.Printf("Allowed origins set to: %s\n", args[0]) + }) + }, +} + +var setGoogleOAuthRedirectURL string + +var setGoogleOAuthCmd = &cobra.Command{ + Use: "google-oauth <client-id> <client-secret>", + Short: "Set Google OAuth credentials", + Args: cobra.ExactArgs(2), + Run: func(cmd *cobra.Command, args []string) { + clientID := args[0] + clientSecret := args[1] + + withConfig(func(cfg *config.Config, db *database.DB) { + // Check if already configured + existing := cfg.GoogleClientID() + if existing != "" && !configForce { + exitWithError("Google OAuth already configured (use -f to override)") + } + if existing != "" && !configYes && !confirm("Google OAuth already configured. 
Override?") { + fmt.Println("Aborted") + return + } + if err := cfg.Set(config.KeyGoogleClientID, clientID); err != nil { + exitWithError("Failed to set Google client ID: %v", err) + } + if err := cfg.Set(config.KeyGoogleClientSecret, clientSecret); err != nil { + exitWithError("Failed to set Google client secret: %v", err) + } + if setGoogleOAuthRedirectURL != "" { + if err := cfg.Set(config.KeyGoogleRedirectURL, setGoogleOAuthRedirectURL); err != nil { + exitWithError("Failed to set Google redirect URL: %v", err) + } + } + fmt.Println("Google OAuth configured") + }) + }, +} + +var setDiscordOAuthRedirectURL string + +var setDiscordOAuthCmd = &cobra.Command{ + Use: "discord-oauth <client-id> <client-secret>", + Short: "Set Discord OAuth credentials", + Args: cobra.ExactArgs(2), + Run: func(cmd *cobra.Command, args []string) { + clientID := args[0] + clientSecret := args[1] + + withConfig(func(cfg *config.Config, db *database.DB) { + // Check if already configured + existing := cfg.DiscordClientID() + if existing != "" && !configForce { + exitWithError("Discord OAuth already configured (use -f to override)") + } + if existing != "" && !configYes && !confirm("Discord OAuth already configured. 
Override?") { + fmt.Println("Aborted") + return + } + if err := cfg.Set(config.KeyDiscordClientID, clientID); err != nil { + exitWithError("Failed to set Discord client ID: %v", err) + } + if err := cfg.Set(config.KeyDiscordClientSecret, clientSecret); err != nil { + exitWithError("Failed to set Discord client secret: %v", err) + } + if setDiscordOAuthRedirectURL != "" { + if err := cfg.Set(config.KeyDiscordRedirectURL, setDiscordOAuthRedirectURL); err != nil { + exitWithError("Failed to set Discord redirect URL: %v", err) + } + } + fmt.Println("Discord OAuth configured") + }) + }, +} + +// --- Show command --- + +var showCmd = &cobra.Command{ + Use: "show", + Short: "Show current configuration", + Run: func(cmd *cobra.Command, args []string) { + withConfig(func(cfg *config.Config, db *database.DB) { + all, err := cfg.GetAll() + if err != nil { + exitWithError("Failed to get config: %v", err) + } + + if len(all) == 0 { + fmt.Println("No configuration stored") + return + } + + fmt.Println("Current configuration:") + fmt.Println("----------------------") + for key, value := range all { + // Redact sensitive values + if strings.Contains(key, "secret") || strings.Contains(key, "api_key") || strings.Contains(key, "password") { + fmt.Printf(" %s: [REDACTED]\n", key) + } else { + fmt.Printf(" %s: %s\n", key, value) + } + } + }) + }, +} + +// --- List commands --- + +var listCmd = &cobra.Command{ + Use: "list", + Short: "List resources", +} + +var listUsersCmd = &cobra.Command{ + Use: "users", + Short: "List all users", + Run: func(cmd *cobra.Command, args []string) { + withConfig(func(cfg *config.Config, db *database.DB) { + var rows *sql.Rows + err := db.With(func(conn *sql.DB) error { + var err error + rows, err = conn.Query("SELECT id, email, name, oauth_provider, is_admin, created_at FROM users ORDER BY id") + return err + }) + if err != nil { + exitWithError("Failed to list users: %v", err) + } + defer rows.Close() + + fmt.Printf("%-6s %-30s %-20s %-10s %-6s %s\n", 
"ID", "Email", "Name", "Provider", "Admin", "Created") + fmt.Println(strings.Repeat("-", 100)) + + for rows.Next() { + var id int64 + var email, name, provider string + var isAdmin bool + var createdAt string + if err := rows.Scan(&id, &email, &name, &provider, &isAdmin, &createdAt); err != nil { + continue + } + adminStr := "no" + if isAdmin { + adminStr = "yes" + } + fmt.Printf("%-6d %-30s %-20s %-10s %-6s %s\n", id, email, name, provider, adminStr, createdAt[:19]) + } + }) + }, +} + +var listAPIKeysCmd = &cobra.Command{ + Use: "apikeys", + Short: "List all API keys", + Run: func(cmd *cobra.Command, args []string) { + withConfig(func(cfg *config.Config, db *database.DB) { + var rows *sql.Rows + err := db.With(func(conn *sql.DB) error { + var err error + rows, err = conn.Query("SELECT id, key_prefix, name, scope, is_active, created_at FROM runner_api_keys ORDER BY id") + return err + }) + if err != nil { + exitWithError("Failed to list API keys: %v", err) + } + defer rows.Close() + + fmt.Printf("%-6s %-12s %-20s %-10s %-8s %s\n", "ID", "Prefix", "Name", "Scope", "Active", "Created") + fmt.Println(strings.Repeat("-", 80)) + + for rows.Next() { + var id int64 + var prefix, name, scope string + var isActive bool + var createdAt string + if err := rows.Scan(&id, &prefix, &name, &scope, &isActive, &createdAt); err != nil { + continue + } + activeStr := "no" + if isActive { + activeStr = "yes" + } + fmt.Printf("%-6d %-12s %-20s %-10s %-8s %s\n", id, prefix, name, scope, activeStr, createdAt[:19]) + } + }) + }, +} + +func init() { + managerCmd.AddCommand(configCmd) + + // Global config flags + configCmd.PersistentFlags().StringVar(&configDBPath, "db", "jiggablend.db", "Database path") + configCmd.PersistentFlags().BoolVarP(&configYes, "yes", "y", false, "Auto-confirm prompts") + configCmd.PersistentFlags().BoolVarP(&configForce, "force", "f", false, "Force override existing") + + // Enable/Disable + configCmd.AddCommand(enableCmd) + configCmd.AddCommand(disableCmd) + 
enableCmd.AddCommand(enableLocalAuthCmd) + enableCmd.AddCommand(enableRegistrationCmd) + enableCmd.AddCommand(enableProductionCmd) + disableCmd.AddCommand(disableLocalAuthCmd) + disableCmd.AddCommand(disableRegistrationCmd) + disableCmd.AddCommand(disableProductionCmd) + + // Add + configCmd.AddCommand(addCmd) + addCmd.AddCommand(addUserCmd) + addUserCmd.Flags().StringVarP(&addUserName, "name", "n", "", "User display name") + addUserCmd.Flags().BoolVarP(&addUserAdmin, "admin", "a", false, "Make user an admin") + + addCmd.AddCommand(addAPIKeyCmd) + addAPIKeyCmd.Flags().StringVarP(&addAPIKeyScope, "scope", "s", "manager", "API key scope (manager or user)") + + // Set + configCmd.AddCommand(setCmd) + setCmd.AddCommand(setFixedAPIKeyCmd) + setCmd.AddCommand(setAllowedOriginsCmd) + setCmd.AddCommand(setGoogleOAuthCmd) + setCmd.AddCommand(setDiscordOAuthCmd) + + setGoogleOAuthCmd.Flags().StringVarP(&setGoogleOAuthRedirectURL, "redirect-url", "r", "", "Google OAuth redirect URL") + setDiscordOAuthCmd.Flags().StringVarP(&setDiscordOAuthRedirectURL, "redirect-url", "r", "", "Discord OAuth redirect URL") + + // Show + configCmd.AddCommand(showCmd) + + // List + configCmd.AddCommand(listCmd) + listCmd.AddCommand(listUsersCmd) + listCmd.AddCommand(listAPIKeysCmd) +} + +// withConfig opens the database and runs the callback with config access +func withConfig(fn func(cfg *config.Config, db *database.DB)) { + db, err := database.NewDB(configDBPath) + if err != nil { + exitWithError("Failed to open database: %v", err) + } + defer db.Close() + + cfg := config.NewConfig(db) + fn(cfg, db) +} + +// generateAPIKey generates a new API key +func generateAPIKey() (key, prefix, hash string, err error) { + randomBytes := make([]byte, 16) + if _, err := rand.Read(randomBytes); err != nil { + return "", "", "", err + } + randomStr := hex.EncodeToString(randomBytes) + + prefixDigit := make([]byte, 1) + if _, err := rand.Read(prefixDigit); err != nil { + return "", "", "", err + } + + prefix = 
fmt.Sprintf("jk_r%d", prefixDigit[0]%10) + key = fmt.Sprintf("%s_%s", prefix, randomStr) + + keyHash := sha256.Sum256([]byte(key)) + hash = hex.EncodeToString(keyHash[:]) + + return key, prefix, hash, nil +} + +// confirm prompts the user for confirmation +func confirm(prompt string) bool { + fmt.Printf("%s [y/N]: ", prompt) + reader := bufio.NewReader(os.Stdin) + response, err := reader.ReadString('\n') + if err != nil { + return false + } + response = strings.TrimSpace(strings.ToLower(response)) + return response == "y" || response == "yes" +} + diff --git a/cmd/jiggablend/cmd/root.go b/cmd/jiggablend/cmd/root.go new file mode 100644 index 0000000..21d14db --- /dev/null +++ b/cmd/jiggablend/cmd/root.go @@ -0,0 +1,35 @@ +package cmd + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" +) + +var rootCmd = &cobra.Command{ + Use: "jiggablend", + Short: "Jiggablend - Distributed Blender Render Farm", + Long: `Jiggablend is a distributed render farm for Blender. + +Run 'jiggablend manager' to start the manager server. +Run 'jiggablend runner' to start a render runner. +Run 'jiggablend manager config' to configure the manager.`, +} + +// Execute runs the root command +func Execute() error { + return rootCmd.Execute() +} + +func init() { + // Global flags can be added here if needed + rootCmd.CompletionOptions.DisableDefaultCmd = true +} + +// exitWithError prints an error and exits +func exitWithError(msg string, args ...interface{}) { + fmt.Fprintf(os.Stderr, "Error: "+msg+"\n", args...) 
+ os.Exit(1) +} + diff --git a/cmd/jiggablend/cmd/runner.go b/cmd/jiggablend/cmd/runner.go new file mode 100644 index 0000000..e9df31d --- /dev/null +++ b/cmd/jiggablend/cmd/runner.go @@ -0,0 +1,211 @@ +package cmd + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + "os" + "os/signal" + "strings" + "syscall" + "time" + + "jiggablend/internal/logger" + "jiggablend/internal/runner" + + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +var runnerViper = viper.New() + +var runnerCmd = &cobra.Command{ + Use: "runner", + Short: "Start the Jiggablend render runner", + Long: `Start the Jiggablend render runner that connects to a manager and processes render tasks.`, + Run: runRunner, +} + +func init() { + rootCmd.AddCommand(runnerCmd) + + runnerCmd.Flags().StringP("manager", "m", "http://localhost:8080", "Manager URL") + runnerCmd.Flags().StringP("name", "n", "", "Runner name") + runnerCmd.Flags().String("hostname", "", "Runner hostname") + runnerCmd.Flags().StringP("api-key", "k", "", "API key for authentication") + runnerCmd.Flags().StringP("log-file", "l", "", "Log file path (truncated on start, if not set logs only to stdout)") + runnerCmd.Flags().String("log-level", "info", "Log level (debug, info, warn, error)") + runnerCmd.Flags().BoolP("verbose", "v", false, "Enable verbose logging (same as --log-level=debug)") + + // Bind flags to viper with JIGGABLEND_ prefix + runnerViper.SetEnvPrefix("JIGGABLEND") + runnerViper.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) + runnerViper.AutomaticEnv() + + runnerViper.BindPFlag("manager", runnerCmd.Flags().Lookup("manager")) + runnerViper.BindPFlag("name", runnerCmd.Flags().Lookup("name")) + runnerViper.BindPFlag("hostname", runnerCmd.Flags().Lookup("hostname")) + runnerViper.BindPFlag("api_key", runnerCmd.Flags().Lookup("api-key")) + runnerViper.BindPFlag("log_file", runnerCmd.Flags().Lookup("log-file")) + runnerViper.BindPFlag("log_level", runnerCmd.Flags().Lookup("log-level")) + runnerViper.BindPFlag("verbose", 
runnerCmd.Flags().Lookup("verbose")) +} + +func runRunner(cmd *cobra.Command, args []string) { + // Get config values (flags take precedence over env vars) + managerURL := runnerViper.GetString("manager") + name := runnerViper.GetString("name") + hostname := runnerViper.GetString("hostname") + apiKey := runnerViper.GetString("api_key") + logFile := runnerViper.GetString("log_file") + logLevel := runnerViper.GetString("log_level") + verbose := runnerViper.GetBool("verbose") + + var client *runner.Client + + defer func() { + if r := recover(); r != nil { + logger.Errorf("Runner panicked: %v", r) + if client != nil { + client.CleanupWorkspace() + } + os.Exit(1) + } + }() + + if hostname == "" { + hostname, _ = os.Hostname() + } + + // Generate unique runner ID + runnerIDStr := generateShortID() + + // Generate runner name with ID if not provided + if name == "" { + name = fmt.Sprintf("runner-%s-%s", hostname, runnerIDStr) + } else { + name = fmt.Sprintf("%s-%s", name, runnerIDStr) + } + + // Initialize logger + if logFile != "" { + if err := logger.InitWithFile(logFile); err != nil { + logger.Fatalf("Failed to initialize logger: %v", err) + } + defer func() { + if l := logger.GetDefault(); l != nil { + l.Close() + } + }() + } else { + logger.InitStdout() + } + + // Set log level + if verbose { + logger.SetLevel(logger.LevelDebug) + } else { + logger.SetLevel(logger.ParseLevel(logLevel)) + } + + logger.Info("Runner starting up...") + logger.Debugf("Generated runner ID suffix: %s", runnerIDStr) + if logFile != "" { + logger.Infof("Logging to file: %s", logFile) + } + + client = runner.NewClient(managerURL, name, hostname) + + // Clean up orphaned workspace directories + client.CleanupWorkspace() + + // Probe capabilities + logger.Debug("Probing runner capabilities...") + client.ProbeCapabilities() + capabilities := client.GetCapabilities() + capList := []string{} + for cap, value := range capabilities { + if enabled, ok := value.(bool); ok && enabled { + capList = 
append(capList, cap) + } else if count, ok := value.(int); ok && count > 0 { + capList = append(capList, fmt.Sprintf("%s=%d", cap, count)) + } else if count, ok := value.(float64); ok && count > 0 { + capList = append(capList, fmt.Sprintf("%s=%.0f", cap, count)) + } + } + if len(capList) > 0 { + logger.Infof("Detected capabilities: %s", strings.Join(capList, ", ")) + } else { + logger.Warn("No capabilities detected") + } + + // Register with API key + if apiKey == "" { + logger.Fatal("API key required (use --api-key or set JIGGABLEND_API_KEY env var)") + } + + // Retry registration with exponential backoff + backoff := 1 * time.Second + maxBackoff := 30 * time.Second + maxRetries := 10 + retryCount := 0 + + var runnerID int64 + + for { + var err error + runnerID, _, _, err = client.Register(apiKey) + if err == nil { + logger.Infof("Registered runner with ID: %d", runnerID) + break + } + + errMsg := err.Error() + if strings.Contains(errMsg, "token error:") { + logger.Fatalf("Registration failed (token error): %v", err) + } + + retryCount++ + if retryCount >= maxRetries { + logger.Fatalf("Failed to register runner after %d attempts: %v", maxRetries, err) + } + + logger.Warnf("Registration failed (attempt %d/%d): %v, retrying in %v", retryCount, maxRetries, err, backoff) + time.Sleep(backoff) + backoff *= 2 + if backoff > maxBackoff { + backoff = maxBackoff + } + } + + // Start WebSocket connection + go client.ConnectWebSocketWithReconnect() + + // Start heartbeat loop + go client.HeartbeatLoop() + + logger.Info("Runner started, connecting to manager via WebSocket...") + + // Signal handlers + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + + go func() { + sig := <-sigChan + logger.Infof("Received signal: %v, killing all processes and cleaning up...", sig) + client.KillAllProcesses() + client.CleanupWorkspace() + os.Exit(0) + }() + + // Block forever + select {} +} + +func generateShortID() string { + bytes := make([]byte, 
4) + if _, err := rand.Read(bytes); err != nil { + return fmt.Sprintf("%x", os.Getpid()^int(time.Now().Unix())) + } + return hex.EncodeToString(bytes) +} diff --git a/cmd/jiggablend/main.go b/cmd/jiggablend/main.go new file mode 100644 index 0000000..f5e7d13 --- /dev/null +++ b/cmd/jiggablend/main.go @@ -0,0 +1,14 @@ +package main + +import ( + "os" + + "jiggablend/cmd/jiggablend/cmd" +) + +func main() { + if err := cmd.Execute(); err != nil { + os.Exit(1) + } +} + diff --git a/cmd/manager/main.go b/cmd/manager/main.go deleted file mode 100644 index e966f52..0000000 --- a/cmd/manager/main.go +++ /dev/null @@ -1,125 +0,0 @@ -package main - -import ( - "flag" - "fmt" - "log" - "net/http" - "os" - "os/exec" - - "jiggablend/internal/api" - "jiggablend/internal/auth" - "jiggablend/internal/database" - "jiggablend/internal/logger" - "jiggablend/internal/storage" -) - -func main() { - var ( - port = flag.String("port", getEnv("PORT", "8080"), "Server port") - dbPath = flag.String("db", getEnv("DB_PATH", "jiggablend.db"), "Database path") - storagePath = flag.String("storage", getEnv("STORAGE_PATH", "./jiggablend-storage"), "Storage path") - logDir = flag.String("log-dir", getEnv("LOG_DIR", "./logs"), "Log directory") - logMaxSize = flag.Int("log-max-size", getEnvInt("LOG_MAX_SIZE", 100), "Maximum log file size in MB before rotation") - logMaxBackups = flag.Int("log-max-backups", getEnvInt("LOG_MAX_BACKUPS", 5), "Maximum number of rotated log files to keep") - logMaxAge = flag.Int("log-max-age", getEnvInt("LOG_MAX_AGE", 30), "Maximum age in days for rotated log files") - ) - flag.Parse() - - // Initialize logger (writes to both stdout and log file with rotation) - logDirPath := *logDir - if err := logger.Init(logDirPath, "manager.log", *logMaxSize, *logMaxBackups, *logMaxAge); err != nil { - log.Fatalf("Failed to initialize logger: %v", err) - } - defer func() { - if l := logger.GetDefault(); l != nil { - l.Close() - } - }() - log.Printf("Log rotation configured: 
max_size=%dMB, max_backups=%d, max_age=%d days", *logMaxSize, *logMaxBackups, *logMaxAge) - - // Initialize database - db, err := database.NewDB(*dbPath) - if err != nil { - log.Fatalf("Failed to initialize database: %v", err) - } - defer db.Close() - - // Initialize auth - authHandler, err := auth.NewAuth(db.DB) - if err != nil { - log.Fatalf("Failed to initialize auth: %v", err) - } - - // Initialize storage - storageHandler, err := storage.NewStorage(*storagePath) - if err != nil { - log.Fatalf("Failed to initialize storage: %v", err) - } - - // Check if Blender is available (required for metadata extraction) - if err := checkBlenderAvailable(); err != nil { - log.Fatalf("Blender is not available: %v\n"+ - "The manager requires Blender to be installed and in PATH for metadata extraction.\n"+ - "Please install Blender and ensure it's accessible via the 'blender' command.", err) - } - log.Printf("Blender is available") - - // Create API server - server, err := api.NewServer(db, authHandler, storageHandler) - if err != nil { - log.Fatalf("Failed to create server: %v", err) - } - - // Start server with increased request body size limit for large file uploads - addr := fmt.Sprintf(":%s", *port) - log.Printf("Starting manager server on %s", addr) - log.Printf("Database: %s", *dbPath) - log.Printf("Storage: %s", *storagePath) - - httpServer := &http.Server{ - Addr: addr, - Handler: server, - MaxHeaderBytes: 1 << 20, // 1 MB for headers - ReadTimeout: 0, // No read timeout (for large uploads) - WriteTimeout: 0, // No write timeout (for large uploads) - } - - // Note: MaxRequestBodySize is not directly configurable in http.Server - // It's handled by ParseMultipartForm in handlers, which we've already configured - // But we need to ensure the server can handle large requests - // The default limit is 10MB, but we bypass it by using ParseMultipartForm with larger limit - - if err := httpServer.ListenAndServe(); err != nil { - log.Fatalf("Server failed: %v", err) - } -} - 
-func getEnv(key, defaultValue string) string { - if value := os.Getenv(key); value != "" { - return value - } - return defaultValue -} - -func getEnvInt(key string, defaultValue int) int { - if value := os.Getenv(key); value != "" { - var result int - if _, err := fmt.Sscanf(value, "%d", &result); err == nil { - return result - } - } - return defaultValue -} - -// checkBlenderAvailable checks if Blender is available by running `blender --version` -func checkBlenderAvailable() error { - cmd := exec.Command("blender", "--version") - output, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("failed to run 'blender --version': %w (output: %s)", err, string(output)) - } - // If we got here, Blender is available - return nil -} diff --git a/cmd/runner/main.go b/cmd/runner/main.go deleted file mode 100644 index 387e6d6..0000000 --- a/cmd/runner/main.go +++ /dev/null @@ -1,207 +0,0 @@ -package main - -import ( - "crypto/rand" - "encoding/hex" - "flag" - "fmt" - "log" - "os" - "os/signal" - "strings" - "syscall" - "time" - - "jiggablend/internal/logger" - "jiggablend/internal/runner" -) - -// Removed SecretsFile - runners now generate ephemeral instance IDs - -func main() { - log.Printf("Runner starting up...") - - // Create client early so we can clean it up on panic - var client *runner.Client - - defer func() { - if r := recover(); r != nil { - log.Printf("Runner panicked: %v", r) - // Clean up workspace even on panic - if client != nil { - client.CleanupWorkspace() - } - os.Exit(1) - } - }() - - var ( - managerURL = flag.String("manager", getEnv("MANAGER_URL", "http://localhost:8080"), "Manager URL") - name = flag.String("name", getEnv("RUNNER_NAME", ""), "Runner name") - hostname = flag.String("hostname", getEnv("RUNNER_HOSTNAME", ""), "Runner hostname") - apiKeyFlag = flag.String("api-key", getEnv("API_KEY", ""), "API key for authentication") - logDir = flag.String("log-dir", getEnv("LOG_DIR", "./logs"), "Log directory") - logMaxSize = 
flag.Int("log-max-size", getEnvInt("LOG_MAX_SIZE", 100), "Maximum log file size in MB before rotation") - logMaxBackups = flag.Int("log-max-backups", getEnvInt("LOG_MAX_BACKUPS", 5), "Maximum number of rotated log files to keep") - logMaxAge = flag.Int("log-max-age", getEnvInt("LOG_MAX_AGE", 30), "Maximum age in days for rotated log files") - ) - flag.Parse() - log.Printf("Flags parsed, hostname: %s", *hostname) - - if *hostname == "" { - *hostname, _ = os.Hostname() - } - - // Always generate a random runner ID suffix on startup - // This ensures every runner has a unique local identifier - runnerIDStr := generateShortID() - log.Printf("Generated runner ID suffix: %s", runnerIDStr) - - // Generate runner name with ID if not provided - if *name == "" { - *name = fmt.Sprintf("runner-%s-%s", *hostname, runnerIDStr) - } else { - // Append ID to provided name to ensure uniqueness - *name = fmt.Sprintf("%s-%s", *name, runnerIDStr) - } - - // Initialize logger (writes to both stdout and log file with rotation) - // Use runner-specific log file name based on the final name - sanitizedName := strings.ReplaceAll(*name, "/", "_") - sanitizedName = strings.ReplaceAll(sanitizedName, "\\", "_") - logFileName := fmt.Sprintf("runner-%s.log", sanitizedName) - - if err := logger.Init(*logDir, logFileName, *logMaxSize, *logMaxBackups, *logMaxAge); err != nil { - log.Fatalf("Failed to initialize logger: %v", err) - } - defer func() { - if l := logger.GetDefault(); l != nil { - l.Close() - } - }() - log.Printf("Logger initialized, continuing with startup...") - log.Printf("Log rotation configured: max_size=%dMB, max_backups=%d, max_age=%d days", *logMaxSize, *logMaxBackups, *logMaxAge) - - log.Printf("About to create client...") - client = runner.NewClient(*managerURL, *name, *hostname) - log.Printf("Client created successfully") - - // Clean up any orphaned workspace directories from previous runs - client.CleanupWorkspace() - - // Probe capabilities once at startup (before any 
registration attempts) - log.Printf("Probing runner capabilities...") - client.ProbeCapabilities() - capabilities := client.GetCapabilities() - capList := []string{} - for cap, value := range capabilities { - // Only show boolean true capabilities and numeric GPU counts - if enabled, ok := value.(bool); ok && enabled { - capList = append(capList, cap) - } else if count, ok := value.(int); ok && count > 0 { - capList = append(capList, fmt.Sprintf("%s=%d", cap, count)) - } else if count, ok := value.(float64); ok && count > 0 { - capList = append(capList, fmt.Sprintf("%s=%.0f", cap, count)) - } - } - if len(capList) > 0 { - log.Printf("Detected capabilities: %s", strings.Join(capList, ", ")) - } else { - log.Printf("Warning: No capabilities detected") - } - - // Register with API key (with retry logic) - if *apiKeyFlag == "" { - log.Fatalf("API key required (use --api-key or set API_KEY env var)") - } - - // Retry registration with exponential backoff - backoff := 1 * time.Second - maxBackoff := 30 * time.Second - maxRetries := 10 - retryCount := 0 - - var runnerID int64 - - for { - var err error - runnerID, _, _, err = client.Register(*apiKeyFlag) - if err == nil { - log.Printf("Registered runner with ID: %d", runnerID) - break - } - - // Check if it's a token error (invalid/expired/used token) - shutdown immediately - errMsg := err.Error() - if strings.Contains(errMsg, "token error:") { - log.Fatalf("Registration failed (token error): %v", err) - } - - // Only retry on connection errors or other retryable errors - retryCount++ - if retryCount >= maxRetries { - log.Fatalf("Failed to register runner after %d attempts: %v", maxRetries, err) - } - - log.Printf("Registration failed (attempt %d/%d): %v, retrying in %v", retryCount, maxRetries, err, backoff) - time.Sleep(backoff) - backoff *= 2 - if backoff > maxBackoff { - backoff = maxBackoff - } - } - - // Start WebSocket connection with reconnection - go client.ConnectWebSocketWithReconnect() - - // Start heartbeat 
loop (for WebSocket ping/pong and HTTP fallback) - go client.HeartbeatLoop() - - // ProcessTasks is now handled via WebSocket, but kept for HTTP fallback - // WebSocket will handle task assignment automatically - log.Printf("Runner started, connecting to manager via WebSocket...") - - // Set up signal handlers to kill processes on shutdown - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) - - go func() { - sig := <-sigChan - log.Printf("Received signal: %v, killing all processes and cleaning up...", sig) - client.KillAllProcesses() - // Cleanup happens in defer, but also do it here for good measure - client.CleanupWorkspace() - os.Exit(0) - }() - - // Block forever - select {} -} - - -func getEnv(key, defaultValue string) string { - if value := os.Getenv(key); value != "" { - return value - } - return defaultValue -} - -func getEnvInt(key string, defaultValue int) int { - if value := os.Getenv(key); value != "" { - var result int - if _, err := fmt.Sscanf(value, "%d", &result); err == nil { - return result - } - } - return defaultValue -} - -// generateShortID generates a short random ID (8 hex characters) -func generateShortID() string { - bytes := make([]byte, 4) - if _, err := rand.Read(bytes); err != nil { - // Fallback to timestamp-based ID if crypto/rand fails - return fmt.Sprintf("%x", os.Getpid()^int(time.Now().Unix())) - } - return hex.EncodeToString(bytes) -} diff --git a/go.mod b/go.mod index 5340e12..1cd729e 100644 --- a/go.mod +++ b/go.mod @@ -7,34 +7,28 @@ require ( github.com/go-chi/cors v1.2.2 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 - github.com/marcboeker/go-duckdb/v2 v2.4.3 + github.com/mattn/go-sqlite3 v1.14.22 + github.com/spf13/cobra v1.10.1 + github.com/spf13/viper v1.21.0 golang.org/x/crypto v0.45.0 golang.org/x/oauth2 v0.33.0 - gopkg.in/natefinch/lumberjack.v2 v2.2.1 ) require ( cloud.google.com/go/compute/metadata v0.3.0 // indirect - github.com/apache/arrow-go/v18 v18.4.1 
// indirect - github.com/duckdb/duckdb-go-bindings v0.1.21 // indirect - github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.21 // indirect - github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.21 // indirect - github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.21 // indirect - github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.21 // indirect - github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.21 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect - github.com/goccy/go-json v0.10.5 // indirect - github.com/google/flatbuffers v25.2.10+incompatible // indirect - github.com/klauspost/compress v1.18.0 // indirect - github.com/klauspost/cpuid/v2 v2.3.0 // indirect - github.com/marcboeker/go-duckdb/arrowmapping v0.0.21 // indirect - github.com/marcboeker/go-duckdb/mapping v0.0.21 // indirect - github.com/pierrec/lz4/v4 v4.1.22 // indirect - github.com/zeebo/xxh3 v1.0.2 // indirect - golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 // indirect - golang.org/x/mod v0.27.0 // indirect - golang.org/x/sync v0.16.0 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/sagikazarmark/locafero v0.11.0 // indirect + github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect + github.com/spf13/afero v1.15.0 // indirect + github.com/spf13/cast v1.10.0 // indirect + github.com/spf13/pflag v1.0.10 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/sys v0.38.0 // indirect - golang.org/x/tools v0.36.0 // indirect - golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect + golang.org/x/text v0.31.0 // indirect ) diff --git a/go.sum b/go.sum index 6dc6df1..4438838 100644 --- 
a/go.sum +++ b/go.sum @@ -1,88 +1,70 @@ cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= -github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= -github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= -github.com/apache/arrow-go/v18 v18.4.1 h1:q/jVkBWCJOB9reDgaIZIdruLQUb1kbkvOnOFezVH1C4= -github.com/apache/arrow-go/v18 v18.4.1/go.mod h1:tLyFubsAl17bvFdUAy24bsSvA/6ww95Iqi67fTpGu3E= -github.com/apache/thrift v0.22.0 h1:r7mTJdj51TMDe6RtcmNdQxgn9XcyfGDOzegMDRg47uc= -github.com/apache/thrift v0.22.0/go.mod h1:1e7J/O1Ae6ZQMTYdy9xa3w9k+XHWPfRvdPyJeynQ+/g= +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/duckdb/duckdb-go-bindings v0.1.21 h1:bOb/MXNT4PN5JBZ7wpNg6hrj9+cuDjWDa4ee9UdbVyI= -github.com/duckdb/duckdb-go-bindings v0.1.21/go.mod h1:pBnfviMzANT/9hi4bg+zW4ykRZZPCXlVuvBWEcZofkc= -github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.21 h1:Sjjhf2F/zCjPF53c2VXOSKk0PzieMriSoyr5wfvr9d8= -github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.21/go.mod h1:Ezo7IbAfB8NP7CqPIN8XEHKUg5xdRRQhcPPlCXImXYA= -github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.21 h1:IUk0FFUB6dpWLhlN9hY1mmdPX7Hkn3QpyrAmn8pmS8g= -github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.21/go.mod h1:eS7m/mLnPQgVF4za1+xTyorKRBuK0/BA44Oy6DgrGXI= -github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.21 h1:Qpc7ZE3n6Nwz30KTvaAwI6nGkXjXmMxBTdFpC8zDEYI= -github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.21/go.mod h1:1GOuk1PixiESxLaCGFhag+oFi7aP+9W8byymRAvunBk= -github.com/duckdb/duckdb-go-bindings/linux-arm64 
v0.1.21 h1:eX2DhobAZOgjXkh8lPnKAyrxj8gXd2nm+K71f6KV/mo= -github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.21/go.mod h1:o7crKMpT2eOIi5/FY6HPqaXcvieeLSqdXXaXbruGX7w= -github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.21 h1:hhziFnGV7mpA+v5J5G2JnYQ+UWCCP3NQ+OTvxFX10D8= -github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.21/go.mod h1:IlOhJdVKUJCAPj3QsDszUo8DVdvp1nBFp4TUJVdw99s= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE= github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= github.com/go-chi/cors v1.2.2 h1:Jmey33TE+b+rB7fT8MUy1u0I4L+NARQlK6LhzKPSyQE= github.com/go-chi/cors v1.2.2/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= -github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= -github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= -github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs= -github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/google/flatbuffers v25.2.10+incompatible h1:F3vclr7C3HpB1k9mxCGRMXq6FdUalZ6H/pNX4FP1v0Q= -github.com/google/flatbuffers v25.2.10+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/klauspost/asmfmt v1.3.2 h1:4Ri7ox3EwapiOjCki+hw14RyKk201CN4rzyCJRFLpK4= -github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j0HLHbNSE= -github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= -github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= -github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= -github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= -github.com/marcboeker/go-duckdb/arrowmapping v0.0.21 h1:geHnVjlsAJGczSWEqYigy/7ARuD+eBtjd0kLN80SPJQ= -github.com/marcboeker/go-duckdb/arrowmapping v0.0.21/go.mod h1:flFTc9MSqQCh2Xm62RYvG3Kyj29h7OtsTb6zUx1CdK8= -github.com/marcboeker/go-duckdb/mapping v0.0.21 h1:6woNXZn8EfYdc9Vbv0qR6acnt0TM1s1eFqnrJZVrqEs= -github.com/marcboeker/go-duckdb/mapping v0.0.21/go.mod h1:q3smhpLyv2yfgkQd7gGHMd+H/Z905y+WYIUjrl29vT4= -github.com/marcboeker/go-duckdb/v2 v2.4.3 h1:bHUkphPsAp2Bh/VFEdiprGpUekxBNZiWWtK+Bv/ljRk= -github.com/marcboeker/go-duckdb/v2 v2.4.3/go.mod h1:taim9Hktg2igHdNBmg5vgTfHAlV26z3gBI0QXQOcuyI= -github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs= -github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcsrzbxHt8iiaC+zU4b1ylILSosueou12R++wfY= -github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8DFdX7uMikMLXX4oubIzJF4kv/wI= -github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3/go.mod h1:RagcQ7I8IeTMnF8JTXieKnO4Z6JCsikNEzj0DwauVzE= -github.com/pierrec/lz4/v4 v4.1.22 
h1:cKFw6uJDK+/gfw5BcDL0JL5aBsAFdsIT18eRtLj7VIU= -github.com/pierrec/lz4/v4 v4.1.22/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/testify v1.11.0 h1:ib4sjIrwZKxE5u/Japgo/7SJV3PvgjGiRNAvTVGqQl8= -github.com/stretchr/testify v1.11.0/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= -github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= -github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= -github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/russross/blackfriday/v2 v2.1.0/go.mod 
h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc= +github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U= +github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= +github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s= +github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= +github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto 
v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= -golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM= -golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8= -golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ= -golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc= golang.org/x/oauth2 v0.33.0 h1:4Q+qn+E5z8gPRJfmRy7C2gGG3T4jIprK6aSYgTXGRpo= golang.org/x/oauth2 v0.33.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= -golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg= -golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s= -golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da h1:noIWHXmPHxILtqtCOPIhSt0ABwskkZKjD3bXGnZGpNY= -golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= -gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= -gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= -gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= -gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= +golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= +golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= 
+gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/api/admin.go b/internal/api/admin.go index 2106d38..5b36bd6 100644 --- a/internal/api/admin.go +++ b/internal/api/admin.go @@ -131,14 +131,19 @@ func (s *Server) handleVerifyRunner(w http.ResponseWriter, r *http.Request) { // Check if runner exists var exists bool - err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM runners WHERE id = ?)", runnerID).Scan(&exists) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM runners WHERE id = ?)", runnerID).Scan(&exists) + }) if err != nil || !exists { s.respondError(w, http.StatusNotFound, "Runner not found") return } // Mark runner as verified - _, err = s.db.Exec("UPDATE runners SET verified = 1 WHERE id = ?", runnerID) + err = s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec("UPDATE runners SET verified = 1 WHERE id = ?", runnerID) + return err + }) if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify runner: %v", err)) return @@ -157,14 +162,19 @@ func (s *Server) handleDeleteRunner(w http.ResponseWriter, r *http.Request) { // Check if runner exists var exists bool - err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM runners WHERE id = ?)", runnerID).Scan(&exists) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM runners WHERE id = ?)", runnerID).Scan(&exists) + }) if err != nil || !exists { s.respondError(w, http.StatusNotFound, "Runner not found") return } // Delete runner - _, err = s.db.Exec("DELETE FROM runners WHERE id = ?", runnerID) + err = s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec("DELETE FROM runners WHERE id = ?", runnerID) + return err + }) if err != nil { s.respondError(w, 
http.StatusInternalServerError, fmt.Sprintf("Failed to delete runner: %v", err)) return @@ -175,17 +185,31 @@ func (s *Server) handleDeleteRunner(w http.ResponseWriter, r *http.Request) { // handleListRunnersAdmin lists all runners with admin details func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request) { - rows, err := s.db.Query( + var rows *sql.Rows + err := s.db.With(func(conn *sql.DB) error { + var err error + rows, err = conn.Query( `SELECT id, name, hostname, status, last_heartbeat, capabilities, api_key_id, api_key_scope, priority, created_at FROM runners ORDER BY created_at DESC`, ) + return err + }) if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query runners: %v", err)) return } defer rows.Close() + // Get the set of currently connected runners via WebSocket + // This is the source of truth for online status + s.runnerConnsMu.RLock() + connectedRunners := make(map[int64]bool) + for runnerID := range s.runnerConns { + connectedRunners[runnerID] = true + } + s.runnerConnsMu.RUnlock() + runners := []map[string]interface{}{} for rows.Next() { var runner types.Runner @@ -202,11 +226,21 @@ func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request) return } + // Override status based on actual WebSocket connection state + // The WebSocket connection is the source of truth for runner status + actualStatus := runner.Status + if connectedRunners[runner.ID] { + actualStatus = types.RunnerStatusOnline + } else if runner.Status == types.RunnerStatusOnline { + // Database says online but not connected via WebSocket - mark as offline + actualStatus = types.RunnerStatusOffline + } + runners = append(runners, map[string]interface{}{ "id": runner.ID, "name": runner.Name, "hostname": runner.Hostname, - "status": runner.Status, + "status": actualStatus, "last_heartbeat": runner.LastHeartbeat, "capabilities": runner.Capabilities, "api_key_id": apiKeyID.Int64, @@ -228,10 +262,15 @@ func 
(s *Server) handleListUsers(w http.ResponseWriter, r *http.Request) { firstUserID = 0 } - rows, err := s.db.Query( + var rows *sql.Rows + err = s.db.With(func(conn *sql.DB) error { + var err error + rows, err = conn.Query( `SELECT id, email, name, oauth_provider, is_admin, created_at FROM users ORDER BY created_at DESC`, ) + return err + }) if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query users: %v", err)) return @@ -253,7 +292,9 @@ func (s *Server) handleListUsers(w http.ResponseWriter, r *http.Request) { // Get job count for this user var jobCount int - err = s.db.QueryRow("SELECT COUNT(*) FROM jobs WHERE user_id = ?", userID).Scan(&jobCount) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT COUNT(*) FROM jobs WHERE user_id = ?", userID).Scan(&jobCount) + }) if err != nil { jobCount = 0 // Default to 0 if query fails } @@ -283,18 +324,25 @@ func (s *Server) handleGetUserJobs(w http.ResponseWriter, r *http.Request) { // Verify user exists var exists bool - err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", userID).Scan(&exists) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", userID).Scan(&exists) + }) if err != nil || !exists { s.respondError(w, http.StatusNotFound, "User not found") return } - rows, err := s.db.Query( + var rows *sql.Rows + err = s.db.With(func(conn *sql.DB) error { + var err error + rows, err = conn.Query( `SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format, allow_parallel_runners, timeout_seconds, blend_metadata, created_at, started_at, completed_at, error_message FROM jobs WHERE user_id = ? 
ORDER BY created_at DESC`, userID, ) + return err + }) if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query jobs: %v", err)) return diff --git a/internal/api/jobs.go b/internal/api/jobs.go index 43d8fb4..5b08eda 100644 --- a/internal/api/jobs.go +++ b/internal/api/jobs.go @@ -2,8 +2,6 @@ package api import ( "archive/tar" - "bufio" - "bytes" "crypto/md5" "database/sql" _ "embed" @@ -15,7 +13,6 @@ import ( "log" "net/http" "os" - "os/exec" "path/filepath" "strconv" "strings" @@ -23,11 +20,11 @@ import ( "time" authpkg "jiggablend/internal/auth" + "jiggablend/pkg/executils" + "jiggablend/pkg/scripts" "jiggablend/pkg/types" "github.com/gorilla/websocket" - - "jiggablend/pkg/scripts" ) // generateETag generates an ETag from data hash @@ -82,6 +79,14 @@ func (s *Server) handleCreateJob(w http.ResponseWriter, r *http.Request) { s.respondError(w, http.StatusBadRequest, "frame_start and frame_end are required for render jobs") return } + if *req.FrameStart < 0 { + s.respondError(w, http.StatusBadRequest, "frame_start must be 0 or greater. Negative starting frames are not supported.") + return + } + if *req.FrameEnd < 0 { + s.respondError(w, http.StatusBadRequest, "frame_end must be 0 or greater. 
Negative frame numbers are not supported.") + return + } if *req.FrameEnd < *req.FrameStart { s.respondError(w, http.StatusBadRequest, "Invalid frame range") return @@ -129,30 +134,22 @@ func (s *Server) handleCreateJob(w http.ResponseWriter, r *http.Request) { metadataStr := string(metadataBytes) blendMetadataJSON = &metadataStr } - } else if req.UnhideObjects != nil || req.EnableExecution != nil { - // Even if no render settings, store unhide_objects and enable_execution flags - metadata := types.BlendMetadata{ - FrameStart: *req.FrameStart, - FrameEnd: *req.FrameEnd, - RenderSettings: types.RenderSettings{}, - UnhideObjects: req.UnhideObjects, - EnableExecution: req.EnableExecution, - } - metadataBytes, err := json.Marshal(metadata) - if err == nil { - metadataStr := string(metadataBytes) - blendMetadataJSON = &metadataStr - } } log.Printf("Creating render job with output_format: '%s' (from user selection)", *req.OutputFormat) var jobID int64 - err = s.db.QueryRow( - `INSERT INTO jobs (user_id, job_type, name, status, progress, frame_start, frame_end, output_format, allow_parallel_runners, timeout_seconds, blend_metadata) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- RETURNING id`, - userID, req.JobType, req.Name, types.JobStatusPending, 0.0, *req.FrameStart, *req.FrameEnd, *req.OutputFormat, allowParallelRunners, jobTimeout, blendMetadataJSON, - ).Scan(&jobID) + err = s.db.With(func(conn *sql.DB) error { + result, err := conn.Exec( + `INSERT INTO jobs (user_id, job_type, name, status, progress, frame_start, frame_end, output_format, allow_parallel_runners, timeout_seconds, blend_metadata) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + userID, req.JobType, req.Name, types.JobStatusPending, 0.0, *req.FrameStart, *req.FrameEnd, *req.OutputFormat, allowParallelRunners, jobTimeout, blendMetadataJSON, + ) + if err != nil { + return err + } + jobID, err = result.LastInsertId() + return err + }) if err == nil { log.Printf("Created render job %d with output_format: '%s'", jobID, *req.OutputFormat) } @@ -231,12 +228,18 @@ func (s *Server) handleCreateJob(w http.ResponseWriter, r *http.Request) { } var fileID int64 - err = s.db.QueryRow( - `INSERT INTO job_files (job_id, file_type, file_path, file_name, file_size) - VALUES (?, ?, ?, ?, ?) 
- RETURNING id`, - jobID, types.JobFileTypeInput, jobContextPath, filepath.Base(jobContextPath), contextInfo.Size(), - ).Scan(&fileID) + err = s.db.With(func(conn *sql.DB) error { + result, err := conn.Exec( + `INSERT INTO job_files (job_id, file_type, file_path, file_name, file_size) + VALUES (?, ?, ?, ?, ?)`, + jobID, types.JobFileTypeInput, jobContextPath, filepath.Base(jobContextPath), contextInfo.Size(), + ) + if err != nil { + return err + } + fileID, err = result.LastInsertId() + return err + }) if err != nil { log.Printf("ERROR: Failed to record context archive in database for job %d: %v", jobID, err) s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to record context archive: %v", err)) @@ -283,12 +286,18 @@ func (s *Server) handleCreateJob(w http.ResponseWriter, r *http.Request) { if allowParallelRunners != nil && !*allowParallelRunners { // Single task for entire frame range var taskID int64 - err = s.db.QueryRow( - `INSERT INTO tasks (job_id, frame_start, frame_end, task_type, status, timeout_seconds, max_retries) - VALUES (?, ?, ?, ?, ?, ?, ?) 
- RETURNING id`, - jobID, *req.FrameStart, *req.FrameEnd, types.TaskTypeRender, types.TaskStatusPending, taskTimeout, 3, - ).Scan(&taskID) + err = s.db.With(func(conn *sql.DB) error { + result, err := conn.Exec( + `INSERT INTO tasks (job_id, frame_start, frame_end, task_type, status, timeout_seconds, max_retries) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + jobID, *req.FrameStart, *req.FrameEnd, types.TaskTypeRender, types.TaskStatusPending, taskTimeout, 3, + ) + if err != nil { + return err + } + taskID, err = result.LastInsertId() + return err + }) if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create task: %v", err)) return @@ -299,12 +308,18 @@ func (s *Server) handleCreateJob(w http.ResponseWriter, r *http.Request) { // One task per frame for parallel processing for frame := *req.FrameStart; frame <= *req.FrameEnd; frame++ { var taskID int64 - err = s.db.QueryRow( - `INSERT INTO tasks (job_id, frame_start, frame_end, task_type, status, timeout_seconds, max_retries) - VALUES (?, ?, ?, ?, ?, ?, ?) 
- RETURNING id`, - jobID, frame, frame, types.TaskTypeRender, types.TaskStatusPending, taskTimeout, 3, - ).Scan(&taskID) + err = s.db.With(func(conn *sql.DB) error { + result, err := conn.Exec( + `INSERT INTO tasks (job_id, frame_start, frame_end, task_type, status, timeout_seconds, max_retries) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + jobID, frame, frame, types.TaskTypeRender, types.TaskStatusPending, taskTimeout, 3, + ) + if err != nil { + return err + } + taskID, err = result.LastInsertId() + return err + }) if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create tasks: %v", err)) return @@ -344,6 +359,23 @@ func (s *Server) handleCreateJob(w http.ResponseWriter, r *http.Request) { job.AllowParallelRunners = allowParallelRunners } + // Broadcast job_created to all clients via jobs channel + s.broadcastToAllClients("jobs", map[string]interface{}{ + "type": "job_created", + "job_id": jobID, + "data": map[string]interface{}{ + "id": jobID, + "name": req.Name, + "status": types.JobStatusPending, + "progress": 0.0, + "frame_start": *req.FrameStart, + "frame_end": *req.FrameEnd, + "output_format": *req.OutputFormat, + "created_at": time.Now(), + }, + "timestamp": time.Now().Unix(), + }) + // Immediately try to distribute tasks to connected runners s.triggerTaskDistribution() @@ -419,32 +451,40 @@ func (s *Server) handleListJobs(w http.ResponseWriter, r *http.Request) { query += fmt.Sprintf(" ORDER BY %s %s LIMIT ? OFFSET ?", sortField, sortDir) args = append(args, limit, offset) - rows, err := s.db.Query(query, args...) + var rows *sql.Rows + var total int + err = s.db.With(func(conn *sql.DB) error { + var err error + rows, err = conn.Query(query, args...) 
+ if err != nil { + return err + } + + // Get total count for pagination metadata + countQuery := `SELECT COUNT(*) FROM jobs WHERE user_id = ?` + countArgs := []interface{}{userID} + if statusFilter != "" { + statuses := strings.Split(statusFilter, ",") + placeholders := make([]string, len(statuses)) + for i, status := range statuses { + placeholders[i] = "?" + countArgs = append(countArgs, strings.TrimSpace(status)) + } + countQuery += fmt.Sprintf(" AND status IN (%s)", strings.Join(placeholders, ",")) + } + err = conn.QueryRow(countQuery, countArgs...).Scan(&total) + if err != nil { + // If count fails, continue without it + total = -1 + } + return nil + }) if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query jobs: %v", err)) return } defer rows.Close() - // Get total count for pagination metadata - var total int - countQuery := `SELECT COUNT(*) FROM jobs WHERE user_id = ?` - countArgs := []interface{}{userID} - if statusFilter != "" { - statuses := strings.Split(statusFilter, ",") - placeholders := make([]string, len(statuses)) - for i, status := range statuses { - placeholders[i] = "?" - countArgs = append(countArgs, strings.TrimSpace(status)) - } - countQuery += fmt.Sprintf(" AND status IN (%s)", strings.Join(placeholders, ",")) - } - err = s.db.QueryRow(countQuery, countArgs...).Scan(&total) - if err != nil { - // If count fails, continue without it - total = -1 - } - jobs := []types.Job{} for rows.Next() { var job types.Job @@ -583,27 +623,30 @@ func (s *Server) handleListJobsSummary(w http.ResponseWriter, r *http.Request) { query += fmt.Sprintf(" ORDER BY %s %s LIMIT ? OFFSET ?", sortField, sortDir) args = append(args, limit, offset) - rows, err := s.db.Query(query, args...) 
- if err != nil { - s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query jobs: %v", err)) - return - } - defer rows.Close() - - // Get total count + var rows *sql.Rows var total int - countQuery := `SELECT COUNT(*) FROM jobs WHERE user_id = ?` - countArgs := []interface{}{userID} - if statusFilter != "" { - statuses := strings.Split(statusFilter, ",") - placeholders := make([]string, len(statuses)) - for i, status := range statuses { - placeholders[i] = "?" - countArgs = append(countArgs, strings.TrimSpace(status)) + err = s.db.With(func(conn *sql.DB) error { + var err error + rows, err = conn.Query(query, args...) + if err != nil { + return err } - countQuery += fmt.Sprintf(" AND status IN (%s)", strings.Join(placeholders, ",")) - } - err = s.db.QueryRow(countQuery, countArgs...).Scan(&total) + + // Get total count + countQuery := `SELECT COUNT(*) FROM jobs WHERE user_id = ?` + countArgs := []interface{}{userID} + if statusFilter != "" { + statuses := strings.Split(statusFilter, ",") + placeholders := make([]string, len(statuses)) + for i, status := range statuses { + placeholders[i] = "?" + countArgs = append(countArgs, strings.TrimSpace(status)) + } + countQuery += fmt.Sprintf(" AND status IN (%s)", strings.Join(placeholders, ",")) + } + err = conn.QueryRow(countQuery, countArgs...).Scan(&total) + return err + }) if err != nil { total = -1 } @@ -697,7 +740,12 @@ func (s *Server) handleBatchGetJobs(w http.ResponseWriter, r *http.Request) { allow_parallel_runners, timeout_seconds, blend_metadata, created_at, started_at, completed_at, error_message FROM jobs WHERE user_id = ? AND id IN (%s) ORDER BY created_at DESC`, strings.Join(placeholders, ",")) - rows, err := s.db.Query(query, args...) + var rows *sql.Rows + err = s.db.With(func(conn *sql.DB) error { + var err error + rows, err = conn.Query(query, args...) 
+ return err + }) if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query jobs: %v", err)) return @@ -788,29 +836,31 @@ func (s *Server) handleGetJob(w http.ResponseWriter, r *http.Request) { // Allow admins to view any job, regular users can only view their own isAdmin := isAdminUser(r) var err2 error - if isAdmin { - err2 = s.db.QueryRow( - `SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format, + err2 = s.db.With(func(conn *sql.DB) error { + if isAdmin { + return conn.QueryRow( + `SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format, allow_parallel_runners, timeout_seconds, blend_metadata, created_at, started_at, completed_at, error_message FROM jobs WHERE id = ?`, - jobID, - ).Scan( - &job.ID, &job.UserID, &jobType, &job.Name, &job.Status, &job.Progress, - &frameStart, &frameEnd, &outputFormat, &allowParallelRunners, &job.TimeoutSeconds, - &blendMetadataJSON, &job.CreatedAt, &startedAt, &completedAt, &errorMessage, - ) - } else { - err2 = s.db.QueryRow( - `SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format, + jobID, + ).Scan( + &job.ID, &job.UserID, &jobType, &job.Name, &job.Status, &job.Progress, + &frameStart, &frameEnd, &outputFormat, &allowParallelRunners, &job.TimeoutSeconds, + &blendMetadataJSON, &job.CreatedAt, &startedAt, &completedAt, &errorMessage, + ) + } else { + return conn.QueryRow( + `SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format, allow_parallel_runners, timeout_seconds, blend_metadata, created_at, started_at, completed_at, error_message FROM jobs WHERE id = ? 
AND user_id = ?`, - jobID, userID, - ).Scan( - &job.ID, &job.UserID, &jobType, &job.Name, &job.Status, &job.Progress, - &frameStart, &frameEnd, &outputFormat, &allowParallelRunners, &job.TimeoutSeconds, - &blendMetadataJSON, &job.CreatedAt, &startedAt, &completedAt, &errorMessage, - ) - } + jobID, userID, + ).Scan( + &job.ID, &job.UserID, &jobType, &job.Name, &job.Status, &job.Progress, + &frameStart, &frameEnd, &outputFormat, &allowParallelRunners, &job.TimeoutSeconds, + &blendMetadataJSON, &job.CreatedAt, &startedAt, &completedAt, &errorMessage, + ) + } + }) if err2 == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Job not found") @@ -881,7 +931,9 @@ func (s *Server) handleCancelJob(w http.ResponseWriter, r *http.Request) { // Check if this is a metadata extraction job - if so, don't cancel running metadata tasks var jobType string var jobStatus string - err = s.db.QueryRow("SELECT job_type, status FROM jobs WHERE id = ? AND user_id = ?", jobID, userID).Scan(&jobType, &jobStatus) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT job_type, status FROM jobs WHERE id = ? AND user_id = ?", jobID, userID).Scan(&jobType, &jobStatus) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -897,34 +949,38 @@ func (s *Server) handleCancelJob(w http.ResponseWriter, r *http.Request) { return } - result, err := s.db.Exec( - `UPDATE jobs SET status = ? WHERE id = ? AND user_id = ?`, - types.JobStatusCancelled, jobID, userID, - ) + var rowsAffected int64 + err = s.db.With(func(conn *sql.DB) error { + result, err := conn.Exec( + `UPDATE jobs SET status = ? WHERE id = ? AND user_id = ?`, + types.JobStatusCancelled, jobID, userID, + ) + if err != nil { + return err + } + rowsAffected, _ = result.RowsAffected() + if rowsAffected == 0 { + return sql.ErrNoRows + } + + // Cancel all pending tasks + _, err = conn.Exec( + `UPDATE tasks SET status = ? WHERE job_id = ? 
AND status = ?`, + types.TaskStatusFailed, jobID, types.TaskStatusPending, + ) + return err + }) + if err == sql.ErrNoRows { + s.respondError(w, http.StatusNotFound, "Job not found") + return + } if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to cancel job: %v", err)) return } - rowsAffected, _ := result.RowsAffected() - if rowsAffected == 0 { - s.respondError(w, http.StatusNotFound, "Job not found") - return - } - log.Printf("Cancelling job %d (type: %s)", jobID, jobType) - // Cancel all pending tasks - _, err = s.db.Exec( - `UPDATE tasks SET status = ? WHERE job_id = ? AND status = ?`, - types.TaskStatusFailed, jobID, types.TaskStatusPending, - ) - - if err != nil { - s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to cancel tasks: %v", err)) - return - } - s.respondJSON(w, http.StatusOK, map[string]string{"message": "Job cancelled"}) } @@ -946,12 +1002,14 @@ func (s *Server) handleDeleteJob(w http.ResponseWriter, r *http.Request) { isAdmin := isAdminUser(r) var jobUserID int64 var jobStatus string - if isAdmin { - err = s.db.QueryRow("SELECT user_id, status FROM jobs WHERE id = ?", jobID).Scan(&jobUserID, &jobStatus) - } else { - // Non-admin users can only delete their own jobs - err = s.db.QueryRow("SELECT user_id, status FROM jobs WHERE id = ? AND user_id = ?", jobID, userID).Scan(&jobUserID, &jobStatus) - } + err = s.db.With(func(conn *sql.DB) error { + if isAdmin { + return conn.QueryRow("SELECT user_id, status FROM jobs WHERE id = ?", jobID).Scan(&jobUserID, &jobStatus) + } else { + // Non-admin users can only delete their own jobs + return conn.QueryRow("SELECT user_id, status FROM jobs WHERE id = ? 
AND user_id = ?", jobID, userID).Scan(&jobUserID, &jobStatus) + } + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -972,56 +1030,41 @@ func (s *Server) handleDeleteJob(w http.ResponseWriter, r *http.Request) { } // Delete in transaction to ensure consistency - tx, err := s.db.Begin() - if err != nil { - s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to start transaction: %v", err)) - return - } - defer tx.Rollback() + err = s.db.WithTx(func(tx *sql.Tx) error { + // Delete task logs + _, err := tx.Exec(`DELETE FROM task_logs WHERE task_id IN (SELECT id FROM tasks WHERE job_id = ?)`, jobID) + if err != nil { + return fmt.Errorf("failed to delete task logs: %w", err) + } - // Delete task logs - _, err = tx.Exec(`DELETE FROM task_logs WHERE task_id IN (SELECT id FROM tasks WHERE job_id = ?)`, jobID) - if err != nil { - tx.Rollback() - s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to delete task logs: %v", err)) - return - } + // Delete task steps + _, err = tx.Exec(`DELETE FROM task_steps WHERE task_id IN (SELECT id FROM tasks WHERE job_id = ?)`, jobID) + if err != nil { + return fmt.Errorf("failed to delete task steps: %w", err) + } - // Delete task steps - _, err = tx.Exec(`DELETE FROM task_steps WHERE task_id IN (SELECT id FROM tasks WHERE job_id = ?)`, jobID) - if err != nil { - tx.Rollback() - s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to delete task steps: %v", err)) - return - } + // Delete tasks + _, err = tx.Exec("DELETE FROM tasks WHERE job_id = ?", jobID) + if err != nil { + return fmt.Errorf("failed to delete tasks: %w", err) + } - // Delete tasks - _, err = tx.Exec("DELETE FROM tasks WHERE job_id = ?", jobID) - if err != nil { - tx.Rollback() - s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to delete tasks: %v", err)) - return - } + // Delete job files + _, err = tx.Exec("DELETE FROM job_files WHERE job_id = ?", 
jobID) + if err != nil { + return fmt.Errorf("failed to delete job files: %w", err) + } - // Delete job files - _, err = tx.Exec("DELETE FROM job_files WHERE job_id = ?", jobID) - if err != nil { - tx.Rollback() - s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to delete job files: %v", err)) - return - } + // Delete the job + _, err = tx.Exec("DELETE FROM jobs WHERE id = ?", jobID) + if err != nil { + return fmt.Errorf("failed to delete job: %w", err) + } - // Delete the job - _, err = tx.Exec("DELETE FROM jobs WHERE id = ?", jobID) + return nil + }) if err != nil { - tx.Rollback() - s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to delete job: %v", err)) - return - } - - // Commit transaction - if err = tx.Commit(); err != nil { - s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to commit transaction: %v", err)) + s.respondError(w, http.StatusInternalServerError, err.Error()) return } @@ -1059,16 +1102,21 @@ func (s *Server) cleanupOldRenderJobsOnce() { // Find render jobs older than 1 month that are in a final state (completed, failed, or cancelled) // Don't delete running or pending jobs - rows, err := s.db.Query( - `SELECT id FROM jobs + var rows *sql.Rows + err := s.db.With(func(conn *sql.DB) error { + var err error + rows, err = conn.Query( + `SELECT id FROM jobs WHERE job_type = ? AND status IN (?, ?, ?) 
- AND created_at < CURRENT_TIMESTAMP - INTERVAL '1 month'`, - types.JobTypeRender, - types.JobStatusCompleted, - types.JobStatusFailed, - types.JobStatusCancelled, - ) + AND created_at < datetime('now', '-1 month')`, + types.JobTypeRender, + types.JobStatusCompleted, + types.JobStatusFailed, + types.JobStatusCancelled, + ) + return err + }) if err != nil { log.Printf("Failed to query old render jobs: %v", err) return @@ -1095,58 +1143,44 @@ func (s *Server) cleanupOldRenderJobsOnce() { // Delete each job for _, jobID := range jobIDs { // Delete in transaction to ensure consistency - tx, err := s.db.Begin() - if err != nil { - log.Printf("Failed to start transaction for job %d: %v", jobID, err) - continue - } + err := s.db.WithTx(func(tx *sql.Tx) error { + // Delete task logs + _, err := tx.Exec(`DELETE FROM task_logs WHERE task_id IN (SELECT id FROM tasks WHERE job_id = ?)`, jobID) + if err != nil { + return fmt.Errorf("failed to delete task logs: %w", err) + } - // Delete task logs - _, err = tx.Exec(`DELETE FROM task_logs WHERE task_id IN (SELECT id FROM tasks WHERE job_id = ?)`, jobID) - if err != nil { - tx.Rollback() - log.Printf("Failed to delete task logs for job %d: %v", jobID, err) - continue - } + // Delete task steps + _, err = tx.Exec(`DELETE FROM task_steps WHERE task_id IN (SELECT id FROM tasks WHERE job_id = ?)`, jobID) + if err != nil { + return fmt.Errorf("failed to delete task steps: %w", err) + } - // Delete task steps - _, err = tx.Exec(`DELETE FROM task_steps WHERE task_id IN (SELECT id FROM tasks WHERE job_id = ?)`, jobID) - if err != nil { - tx.Rollback() - log.Printf("Failed to delete task steps for job %d: %v", jobID, err) - continue - } + // Delete tasks + _, err = tx.Exec("DELETE FROM tasks WHERE job_id = ?", jobID) + if err != nil { + return fmt.Errorf("failed to delete tasks: %w", err) + } - // Delete tasks - _, err = tx.Exec("DELETE FROM tasks WHERE job_id = ?", jobID) - if err != nil { - tx.Rollback() - log.Printf("Failed to delete 
tasks for job %d: %v", jobID, err) - continue - } + // Delete job files + _, err = tx.Exec("DELETE FROM job_files WHERE job_id = ?", jobID) + if err != nil { + return fmt.Errorf("failed to delete job files: %w", err) + } - // Delete job files - _, err = tx.Exec("DELETE FROM job_files WHERE job_id = ?", jobID) - if err != nil { - tx.Rollback() - log.Printf("Failed to delete job files for job %d: %v", jobID, err) - continue - } + // Delete the job + _, err = tx.Exec("DELETE FROM jobs WHERE id = ?", jobID) + if err != nil { + return fmt.Errorf("failed to delete job: %w", err) + } - // Delete the job - _, err = tx.Exec("DELETE FROM jobs WHERE id = ?", jobID) + return nil + }) if err != nil { - tx.Rollback() log.Printf("Failed to delete job %d: %v", jobID, err) continue } - // Commit transaction - if err = tx.Commit(); err != nil { - log.Printf("Failed to commit transaction for job %d: %v", jobID, err) - continue - } - // Delete physical files (best effort, don't fail if this errors) if err := s.storage.DeleteJobFiles(jobID); err != nil { log.Printf("Warning: Failed to delete files for render job %d: %v", jobID, err) @@ -1172,7 +1206,9 @@ func (s *Server) handleUploadJobFile(w http.ResponseWriter, r *http.Request) { // Verify job belongs to user var jobUserID int64 - err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -1355,12 +1391,18 @@ func (s *Server) handleUploadJobFile(w http.ResponseWriter, r *http.Request) { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to stat context archive: %v", err)) return } - err = s.db.QueryRow( - `INSERT INTO job_files (job_id, file_type, file_path, file_name, file_size) - VALUES (?, ?, ?, ?, ?) 
- RETURNING id`, - jobID, types.JobFileTypeInput, contextPath, filepath.Base(contextPath), contextInfo.Size(), - ).Scan(&fileID) + err = s.db.With(func(conn *sql.DB) error { + result, err := conn.Exec( + `INSERT INTO job_files (job_id, file_type, file_path, file_name, file_size) + VALUES (?, ?, ?, ?, ?)`, + jobID, types.JobFileTypeInput, contextPath, filepath.Base(contextPath), contextInfo.Size(), + ) + if err != nil { + return err + } + fileID, err = result.LastInsertId() + return err + }) if err != nil { log.Printf("ERROR: Failed to record context archive in database for job %d: %v", jobID, err) s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to record context archive: %v", err)) @@ -1386,15 +1428,18 @@ func (s *Server) handleUploadJobFile(w http.ResponseWriter, r *http.Request) { // Update job with metadata metadataJSON, err := json.Marshal(metadata) if err == nil { - _, err = s.db.Exec( - `UPDATE jobs SET blend_metadata = ? WHERE id = ?`, - string(metadataJSON), jobID, - ) - if err != nil { - log.Printf("Warning: Failed to update job metadata in database: %v", err) - } else { - log.Printf("Successfully extracted and stored metadata for job %d", jobID) - } + err = s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( + `UPDATE jobs SET blend_metadata = ? 
WHERE id = ?`, + string(metadataJSON), jobID, + ) + if err != nil { + log.Printf("Warning: Failed to update job metadata in database: %v", err) + } else { + log.Printf("Successfully extracted and stored metadata for job %d", jobID) + } + return err + }) } else { log.Printf("Warning: Failed to marshal metadata: %v", err) } @@ -1456,6 +1501,25 @@ func (s *Server) handleUploadFileForJobCreation(w http.ResponseWriter, r *http.R s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create temporary directory: %v", err)) return } + + // Generate session ID (use temp directory path as session ID) + sessionID := tmpDir + + // Create upload session + s.uploadSessionsMu.Lock() + s.uploadSessions[sessionID] = &UploadSession{ + SessionID: sessionID, + UserID: userID, + Progress: 0.0, + Status: "uploading", + Message: "Uploading file...", + CreatedAt: time.Now(), + } + s.uploadSessionsMu.Unlock() + + // Broadcast initial upload status + s.broadcastUploadProgress(sessionID, 0.0, "uploading", "Uploading file...") + // Note: We'll clean this up after job creation or after timeout // For now, we rely on the session cleanup mechanism, but also add defer for safety defer func() { @@ -1486,7 +1550,11 @@ func (s *Server) handleUploadFileForJobCreation(w http.ResponseWriter, r *http.R } log.Printf("Successfully copied %d bytes to ZIP file", copied) + // Broadcast upload complete, processing starts + s.broadcastUploadProgress(sessionID, 100.0, "processing", "Upload complete, processing file...") + // Extract ZIP file to temporary directory + s.broadcastUploadProgress(sessionID, 25.0, "extracting_zip", "Extracting ZIP file...") extractedFiles, err = s.storage.ExtractZip(zipPath, tmpDir) if err != nil { os.RemoveAll(tmpDir) @@ -1494,6 +1562,7 @@ func (s *Server) handleUploadFileForJobCreation(w http.ResponseWriter, r *http.R return } log.Printf("Successfully extracted %d files from ZIP", len(extractedFiles)) + s.broadcastUploadProgress(sessionID, 50.0, "extracting_zip", 
"ZIP extraction complete") // Find main blend file mainBlendParam := r.FormValue("main_blend_file") @@ -1565,6 +1634,9 @@ func (s *Server) handleUploadFileForJobCreation(w http.ResponseWriter, r *http.R fileReader.Close() outFile.Close() + // Broadcast upload complete for non-ZIP files + s.broadcastUploadProgress(sessionID, 100.0, "processing", "Upload complete, processing file...") + if strings.HasSuffix(strings.ToLower(header.Filename), ".blend") { mainBlendFile = filePath } @@ -1577,16 +1649,19 @@ func (s *Server) handleUploadFileForJobCreation(w http.ResponseWriter, r *http.R } // Create context in temp directory (we'll move it to job directory later) + s.broadcastUploadProgress(sessionID, 75.0, "creating_context", "Creating context archive...") contextPath := filepath.Join(tmpDir, "context.tar") contextPath, err = s.createContextFromDir(tmpDir, contextPath, excludeFiles...) if err != nil { os.RemoveAll(tmpDir) log.Printf("ERROR: Failed to create context archive: %v", err) + s.broadcastUploadProgress(sessionID, 0.0, "error", fmt.Sprintf("Failed to create context archive: %v", err)) s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create context archive: %v", err)) return } // Extract metadata from context archive + s.broadcastUploadProgress(sessionID, 85.0, "extracting_metadata", "Extracting metadata from blend file...") metadata, err := s.extractMetadataFromTempContext(contextPath) if err != nil { log.Printf("Warning: Failed to extract metadata: %v", err) @@ -1594,10 +1669,6 @@ func (s *Server) handleUploadFileForJobCreation(w http.ResponseWriter, r *http.R metadata = nil } - // Generate a session ID to track this upload - // Store the full temp directory path as session ID for easy lookup - sessionID := tmpDir - response := map[string]interface{}{ "session_id": sessionID, // Full temp directory path "file_name": header.Filename, @@ -1624,6 +1695,17 @@ func (s *Server) handleUploadFileForJobCreation(w http.ResponseWriter, r *http.R 
response["metadata_extracted"] = false } + // Broadcast processing complete + s.broadcastUploadProgress(sessionID, 100.0, "completed", "Processing complete") + + // Clean up upload session after a delay (client may still be subscribed) + go func() { + time.Sleep(5 * time.Minute) + s.uploadSessionsMu.Lock() + delete(s.uploadSessions, sessionID) + s.uploadSessionsMu.Unlock() + }() + s.respondJSON(w, http.StatusOK, response) } @@ -1681,11 +1763,17 @@ func (s *Server) extractMetadataFromTempContext(contextPath string) (*types.Blen // Use the same extraction script and process as extractMetadataFromContext // (Copy the logic from extractMetadataFromContext but use tmpDir and blendFile) - return s.runBlenderMetadataExtraction(blendFile, tmpDir) + // Log stderr for debugging (not shown to user) + stderrCallback := func(line string) { + log.Printf("Blender stderr during metadata extraction: %s", line) + } + + return s.runBlenderMetadataExtraction(blendFile, tmpDir, stderrCallback) } // runBlenderMetadataExtraction runs Blender to extract metadata from a blend file -func (s *Server) runBlenderMetadataExtraction(blendFile, workDir string) (*types.BlendMetadata, error) { +// stderrCallback is optional and will be called for each stderr line (note: with RunCommand, this is called after completion) +func (s *Server) runBlenderMetadataExtraction(blendFile, workDir string, stderrCallback func(string)) (*types.BlendMetadata, error) { // Use embedded Python script scriptPath := filepath.Join(workDir, "extract_metadata.py") if err := os.WriteFile(scriptPath, []byte(scripts.ExtractMetadata), 0644); err != nil { @@ -1698,57 +1786,32 @@ func (s *Server) runBlenderMetadataExtraction(blendFile, workDir string) (*types return nil, fmt.Errorf("failed to get relative path for blend file: %w", err) } - // Execute Blender - cmd := exec.Command("blender", "-b", blendFileRel, "--python", "extract_metadata.py") - cmd.Dir = workDir + // Execute Blender using executils + result, err := 
executils.RunCommand( + "blender", + []string{"-b", blendFileRel, "--python", "extract_metadata.py"}, + workDir, + nil, // inherit environment + 0, // no task ID for metadata extraction + nil, // no process tracker needed + ) - stdoutPipe, err := cmd.StdoutPipe() - if err != nil { - return nil, fmt.Errorf("failed to create stdout pipe: %w", err) - } - - stderrPipe, err := cmd.StderrPipe() - if err != nil { - return nil, fmt.Errorf("failed to create stderr pipe: %w", err) - } - - var stdoutBuffer bytes.Buffer - - if err := cmd.Start(); err != nil { - return nil, fmt.Errorf("failed to start blender: %w", err) - } - - stdoutDone := make(chan bool) - go func() { - defer close(stdoutDone) - scanner := bufio.NewScanner(stdoutPipe) - for scanner.Scan() { - line := scanner.Text() - stdoutBuffer.WriteString(line) - stdoutBuffer.WriteString("\n") + // Forward stderr via callback if provided + if result != nil && stderrCallback != nil && result.Stderr != "" { + for _, line := range strings.Split(result.Stderr, "\n") { + if line != "" { + stderrCallback(line) + } } - }() - - // Capture stderr for error reporting - var stderrBuffer bytes.Buffer - stderrDone := make(chan bool) - go func() { - defer close(stderrDone) - scanner := bufio.NewScanner(stderrPipe) - for scanner.Scan() { - line := scanner.Text() - stderrBuffer.WriteString(line) - stderrBuffer.WriteString("\n") - } - }() - - err = cmd.Wait() - <-stdoutDone - <-stderrDone + } if err != nil { - stderrOutput := strings.TrimSpace(stderrBuffer.String()) - stdoutOutput := strings.TrimSpace(stdoutBuffer.String()) + stderrOutput := "" + stdoutOutput := "" + if result != nil { + stderrOutput = strings.TrimSpace(result.Stderr) + stdoutOutput = strings.TrimSpace(result.Stdout) + } log.Printf("Blender metadata extraction failed:") if stderrOutput != "" { log.Printf("Blender stderr: %s", stderrOutput) @@ -1762,7 +1825,7 @@ func (s *Server) runBlenderMetadataExtraction(blendFile, workDir string) (*types return nil, fmt.Errorf("blender 
metadata extraction failed: %w", err) } - metadataJSON := strings.TrimSpace(stdoutBuffer.String()) + metadataJSON := strings.TrimSpace(result.Stdout) jsonStart := strings.Index(metadataJSON, "{") jsonEnd := strings.LastIndex(metadataJSON, "}") if jsonStart == -1 || jsonEnd == -1 || jsonEnd <= jsonStart { @@ -1976,7 +2039,9 @@ func (s *Server) handleListJobFiles(w http.ResponseWriter, r *http.Request) { isAdmin := isAdminUser(r) if !isAdmin { var jobUserID int64 - err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -1992,7 +2057,9 @@ func (s *Server) handleListJobFiles(w http.ResponseWriter, r *http.Request) { } else { // Admin: verify job exists var exists bool - err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM jobs WHERE id = ?)", jobID).Scan(&exists) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM jobs WHERE id = ?)", jobID).Scan(&exists) + }) if err != nil || !exists { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -2035,30 +2102,38 @@ func (s *Server) handleListJobFiles(w http.ResponseWriter, r *http.Request) { query += " ORDER BY created_at DESC LIMIT ? OFFSET ?" args = append(args, limit, offset) - rows, err := s.db.Query(query, args...) + var rows *sql.Rows + var total int + err = s.db.With(func(conn *sql.DB) error { + var err error + rows, err = conn.Query(query, args...) + if err != nil { + return err + } + + // Get total count + countQuery := `SELECT COUNT(*) FROM job_files WHERE job_id = ?` + countArgs := []interface{}{jobID} + if fileTypeFilter != "" { + countQuery += " AND file_type = ?" + countArgs = append(countArgs, fileTypeFilter) + } + if extensionFilter != "" { + countQuery += " AND file_name LIKE ?" 
+ countArgs = append(countArgs, "%."+extensionFilter) + } + err = conn.QueryRow(countQuery, countArgs...).Scan(&total) + if err != nil { + total = -1 + } + return nil + }) if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query files: %v", err)) return } defer rows.Close() - // Get total count - var total int - countQuery := `SELECT COUNT(*) FROM job_files WHERE job_id = ?` - countArgs := []interface{}{jobID} - if fileTypeFilter != "" { - countQuery += " AND file_type = ?" - countArgs = append(countArgs, fileTypeFilter) - } - if extensionFilter != "" { - countQuery += " AND file_name LIKE ?" - countArgs = append(countArgs, "%."+extensionFilter) - } - err = s.db.QueryRow(countQuery, countArgs...).Scan(&total) - if err != nil { - total = -1 - } - files := []types.JobFile{} for rows.Next() { var file types.JobFile @@ -2100,7 +2175,9 @@ func (s *Server) handleGetJobFilesCount(w http.ResponseWriter, r *http.Request) isAdmin := isAdminUser(r) if !isAdmin { var jobUserID int64 - err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -2115,7 +2192,9 @@ func (s *Server) handleGetJobFilesCount(w http.ResponseWriter, r *http.Request) } } else { var exists bool - err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM jobs WHERE id = ?)", jobID).Scan(&exists) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM jobs WHERE id = ?)", jobID).Scan(&exists) + }) if err != nil || !exists { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -2133,7 +2212,9 @@ func (s *Server) handleGetJobFilesCount(w http.ResponseWriter, r *http.Request) args = append(args, fileTypeFilter) } - err = s.db.QueryRow(query, args...).Scan(&count) + err = 
s.db.With(func(conn *sql.DB) error { + return conn.QueryRow(query, args...).Scan(&count) + }) if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to count files: %v", err)) return @@ -2161,7 +2242,9 @@ func (s *Server) handleListContextArchive(w http.ResponseWriter, r *http.Request isAdmin := isAdminUser(r) if !isAdmin { var jobUserID int64 - err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -2352,7 +2435,9 @@ func (s *Server) handleDownloadJobFile(w http.ResponseWriter, r *http.Request) { isAdmin := isAdminUser(r) if !isAdmin { var jobUserID int64 - err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -2368,7 +2453,9 @@ func (s *Server) handleDownloadJobFile(w http.ResponseWriter, r *http.Request) { } else { // Admin: verify job exists var exists bool - err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM jobs WHERE id = ?)", jobID).Scan(&exists) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM jobs WHERE id = ?)", jobID).Scan(&exists) + }) if err != nil || !exists { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -2377,10 +2464,12 @@ func (s *Server) handleDownloadJobFile(w http.ResponseWriter, r *http.Request) { // Get file info var filePath, fileName string - err = s.db.QueryRow( - `SELECT file_path, file_name FROM job_files WHERE id = ? 
AND job_id = ?`, - fileID, jobID, - ).Scan(&filePath, &fileName) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( + `SELECT file_path, file_name FROM job_files WHERE id = ? AND job_id = ?`, + fileID, jobID, + ).Scan(&filePath, &fileName) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "File not found") return @@ -2450,11 +2539,13 @@ func (s *Server) handleStreamVideo(w http.ResponseWriter, r *http.Request) { isAdmin := isAdminUser(r) var jobUserID int64 var outputFormat string - if isAdmin { - err = s.db.QueryRow("SELECT user_id, output_format FROM jobs WHERE id = ?", jobID).Scan(&jobUserID, &outputFormat) - } else { - err = s.db.QueryRow("SELECT user_id, output_format FROM jobs WHERE id = ? AND user_id = ?", jobID, userID).Scan(&jobUserID, &outputFormat) - } + err = s.db.With(func(conn *sql.DB) error { + if isAdmin { + return conn.QueryRow("SELECT user_id, output_format FROM jobs WHERE id = ?", jobID).Scan(&jobUserID, &outputFormat) + } else { + return conn.QueryRow("SELECT user_id, output_format FROM jobs WHERE id = ? AND user_id = ?", jobID, userID).Scan(&jobUserID, &outputFormat) + } + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -2470,12 +2561,14 @@ func (s *Server) handleStreamVideo(w http.ResponseWriter, r *http.Request) { // Find MP4 file var filePath, fileName string - err = s.db.QueryRow( - `SELECT file_path, file_name FROM job_files + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( + `SELECT file_path, file_name FROM job_files WHERE job_id = ? AND file_type = ? 
AND file_name LIKE '%.mp4' ORDER BY created_at DESC LIMIT 1`, - jobID, types.JobFileTypeOutput, - ).Scan(&filePath, &fileName) + jobID, types.JobFileTypeOutput, + ).Scan(&filePath, &fileName) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Video file not found") return @@ -2551,7 +2644,9 @@ func (s *Server) handleListJobTasks(w http.ResponseWriter, r *http.Request) { isAdmin := isAdminUser(r) if !isAdmin { var jobUserID int64 - err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -2567,7 +2662,9 @@ func (s *Server) handleListJobTasks(w http.ResponseWriter, r *http.Request) { } else { // Admin: verify job exists var exists bool - err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM jobs WHERE id = ?)", jobID).Scan(&exists) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM jobs WHERE id = ?)", jobID).Scan(&exists) + }) if err != nil || !exists { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -2650,39 +2747,47 @@ func (s *Server) handleListJobTasks(w http.ResponseWriter, r *http.Request) { query += fmt.Sprintf(" ORDER BY %s %s LIMIT ? OFFSET ?", sortField, sortDir) args = append(args, limit, offset) - rows, err := s.db.Query(query, args...) + var rows *sql.Rows + var total int + err = s.db.With(func(conn *sql.DB) error { + var err error + rows, err = conn.Query(query, args...) + if err != nil { + return err + } + + // Get total count + countQuery := `SELECT COUNT(*) FROM tasks WHERE job_id = ?` + countArgs := []interface{}{jobID} + if statusFilter != "" { + statuses := strings.Split(statusFilter, ",") + placeholders := make([]string, len(statuses)) + for i, status := range statuses { + placeholders[i] = "?" 
+ countArgs = append(countArgs, strings.TrimSpace(status)) + } + countQuery += fmt.Sprintf(" AND status IN (%s)", strings.Join(placeholders, ",")) + } + if frameStartFilter != "" { + if fs, err := strconv.Atoi(frameStartFilter); err == nil { + countQuery += " AND frame_start >= ?" + countArgs = append(countArgs, fs) + } + } + if frameEndFilter != "" { + if fe, err := strconv.Atoi(frameEndFilter); err == nil { + countQuery += " AND frame_end <= ?" + countArgs = append(countArgs, fe) + } + } + err = conn.QueryRow(countQuery, countArgs...).Scan(&total) + return err + }) if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query tasks: %v", err)) return } defer rows.Close() - - // Get total count - var total int - countQuery := `SELECT COUNT(*) FROM tasks WHERE job_id = ?` - countArgs := []interface{}{jobID} - if statusFilter != "" { - statuses := strings.Split(statusFilter, ",") - placeholders := make([]string, len(statuses)) - for i, status := range statuses { - placeholders[i] = "?" - countArgs = append(countArgs, strings.TrimSpace(status)) - } - countQuery += fmt.Sprintf(" AND status IN (%s)", strings.Join(placeholders, ",")) - } - if frameStartFilter != "" { - if fs, err := strconv.Atoi(frameStartFilter); err == nil { - countQuery += " AND frame_start >= ?" - countArgs = append(countArgs, fs) - } - } - if frameEndFilter != "" { - if fe, err := strconv.Atoi(frameEndFilter); err == nil { - countQuery += " AND frame_end <= ?" 
- countArgs = append(countArgs, fe) - } - } - err = s.db.QueryRow(countQuery, countArgs...).Scan(&total) if err != nil { total = -1 } @@ -2771,7 +2876,9 @@ func (s *Server) handleListJobTasksSummary(w http.ResponseWriter, r *http.Reques isAdmin := isAdminUser(r) if !isAdmin { var jobUserID int64 - err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -2786,7 +2893,9 @@ func (s *Server) handleListJobTasksSummary(w http.ResponseWriter, r *http.Reques } } else { var exists bool - err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM jobs WHERE id = ?)", jobID).Scan(&exists) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM jobs WHERE id = ?)", jobID).Scan(&exists) + }) if err != nil || !exists { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -2794,9 +2903,9 @@ func (s *Server) handleListJobTasksSummary(w http.ResponseWriter, r *http.Reques } // Parse query parameters - limit := 100 + limit := 0 // 0 means unlimited if limitStr := r.URL.Query().Get("limit"); limitStr != "" { - if l, err := strconv.Atoi(limitStr); err == nil && l > 0 && l <= 5000 { + if l, err := strconv.Atoi(limitStr); err == nil && l > 0 { limit = l } } @@ -2846,34 +2955,51 @@ func (s *Server) handleListJobTasksSummary(w http.ResponseWriter, r *http.Reques query += fmt.Sprintf(" AND status IN (%s)", strings.Join(placeholders, ",")) } - query += fmt.Sprintf(" ORDER BY %s %s LIMIT ? OFFSET ?", sortField, sortDir) - args = append(args, limit, offset) + query += fmt.Sprintf(" ORDER BY %s %s", sortField, sortDir) + if limit > 0 { + query += " LIMIT ? OFFSET ?" 
+ args = append(args, limit, offset) + } else { + // Unlimited - only apply offset if specified + if offset > 0 { + query += " OFFSET ?" + args = append(args, offset) + } + } - rows, err := s.db.Query(query, args...) + var rows *sql.Rows + var total int + err = s.db.With(func(conn *sql.DB) error { + var err error + rows, err = conn.Query(query, args...) + if err != nil { + return err + } + + // Get total count + countQuery := `SELECT COUNT(*) FROM tasks WHERE job_id = ?` + countArgs := []interface{}{jobID} + if statusFilter != "" { + statuses := strings.Split(statusFilter, ",") + placeholders := make([]string, len(statuses)) + for i, status := range statuses { + placeholders[i] = "?" + countArgs = append(countArgs, strings.TrimSpace(status)) + } + countQuery += fmt.Sprintf(" AND status IN (%s)", strings.Join(placeholders, ",")) + } + err = conn.QueryRow(countQuery, countArgs...).Scan(&total) + if err != nil { + total = -1 + } + return nil + }) if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query tasks: %v", err)) return } defer rows.Close() - // Get total count - var total int - countQuery := `SELECT COUNT(*) FROM tasks WHERE job_id = ?` - countArgs := []interface{}{jobID} - if statusFilter != "" { - statuses := strings.Split(statusFilter, ",") - placeholders := make([]string, len(statuses)) - for i, status := range statuses { - placeholders[i] = "?" 
- countArgs = append(countArgs, strings.TrimSpace(status)) - } - countQuery += fmt.Sprintf(" AND status IN (%s)", strings.Join(placeholders, ",")) - } - err = s.db.QueryRow(countQuery, countArgs...).Scan(&total) - if err != nil { - total = -1 - } - type TaskSummary struct { ID int64 `json:"id"` FrameStart int `json:"frame_start"` @@ -2931,7 +3057,9 @@ func (s *Server) handleBatchGetTasks(w http.ResponseWriter, r *http.Request) { isAdmin := isAdminUser(r) if !isAdmin { var jobUserID int64 - err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -2946,7 +3074,9 @@ func (s *Server) handleBatchGetTasks(w http.ResponseWriter, r *http.Request) { } } else { var exists bool - err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM jobs WHERE id = ?)", jobID).Scan(&exists) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM jobs WHERE id = ?)", jobID).Scan(&exists) + }) if err != nil || !exists { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -2985,7 +3115,12 @@ func (s *Server) handleBatchGetTasks(w http.ResponseWriter, r *http.Request) { completed_at, error_message, timeout_seconds FROM tasks WHERE job_id = ? AND id IN (%s) ORDER BY frame_start ASC`, strings.Join(placeholders, ",")) - rows, err := s.db.Query(query, args...) + var rows *sql.Rows + err = s.db.With(func(conn *sql.DB) error { + var err error + rows, err = conn.Query(query, args...) 
+ return err + }) if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query tasks: %v", err)) return @@ -3066,7 +3201,9 @@ func (s *Server) handleGetTaskLogs(w http.ResponseWriter, r *http.Request) { isAdmin := isAdminUser(r) if !isAdmin { var jobUserID int64 - err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -3082,7 +3219,9 @@ func (s *Server) handleGetTaskLogs(w http.ResponseWriter, r *http.Request) { } else { // Admin: verify job exists var exists bool - err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM jobs WHERE id = ?)", jobID).Scan(&exists) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM jobs WHERE id = ?)", jobID).Scan(&exists) + }) if err != nil || !exists { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -3091,7 +3230,9 @@ func (s *Server) handleGetTaskLogs(w http.ResponseWriter, r *http.Request) { // Verify task belongs to job var taskJobID int64 - err = s.db.QueryRow("SELECT job_id FROM tasks WHERE id = ?", taskID).Scan(&taskJobID) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT job_id FROM tasks WHERE id = ?", taskID).Scan(&taskJobID) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Task not found") return @@ -3141,7 +3282,12 @@ func (s *Server) handleGetTaskLogs(w http.ResponseWriter, r *http.Request) { query += " ORDER BY id ASC LIMIT ?" args = append(args, limit) - rows, err := s.db.Query(query, args...) + var rows *sql.Rows + err = s.db.With(func(conn *sql.DB) error { + var err error + rows, err = conn.Query(query, args...) 
+ return err + }) if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query logs: %v", err)) return @@ -3204,7 +3350,9 @@ func (s *Server) handleGetTaskSteps(w http.ResponseWriter, r *http.Request) { isAdmin := isAdminUser(r) if !isAdmin { var jobUserID int64 - err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -3220,7 +3368,9 @@ func (s *Server) handleGetTaskSteps(w http.ResponseWriter, r *http.Request) { } else { // Admin: verify job exists var exists bool - err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM jobs WHERE id = ?)", jobID).Scan(&exists) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM jobs WHERE id = ?)", jobID).Scan(&exists) + }) if err != nil || !exists { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -3229,7 +3379,9 @@ func (s *Server) handleGetTaskSteps(w http.ResponseWriter, r *http.Request) { // Verify task belongs to job var taskJobID int64 - err = s.db.QueryRow("SELECT job_id FROM tasks WHERE id = ?", taskID).Scan(&taskJobID) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT job_id FROM tasks WHERE id = ?", taskID).Scan(&taskJobID) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Task not found") return @@ -3243,11 +3395,16 @@ func (s *Server) handleGetTaskSteps(w http.ResponseWriter, r *http.Request) { return } - rows, err := s.db.Query( - `SELECT id, task_id, step_name, status, started_at, completed_at, duration_ms, error_message + var rows *sql.Rows + err = s.db.With(func(conn *sql.DB) error { + var err error + rows, err = conn.Query( + `SELECT id, task_id, step_name, status, started_at, completed_at, duration_ms, error_message FROM 
task_steps WHERE task_id = ? ORDER BY started_at ASC`, - taskID, - ) + taskID, + ) + return err + }) if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query steps: %v", err)) return @@ -3309,7 +3466,9 @@ func (s *Server) handleRetryTask(w http.ResponseWriter, r *http.Request) { // Verify job belongs to user var jobUserID int64 - err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -3327,10 +3486,12 @@ func (s *Server) handleRetryTask(w http.ResponseWriter, r *http.Request) { var taskJobID int64 var taskStatus string var retryCount, maxRetries int - err = s.db.QueryRow( - "SELECT job_id, status, retry_count, max_retries FROM tasks WHERE id = ?", - taskID, - ).Scan(&taskJobID, &taskStatus, &retryCount, &maxRetries) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( + "SELECT job_id, status, retry_count, max_retries FROM tasks WHERE id = ?", + taskID, + ).Scan(&taskJobID, &taskStatus, &retryCount, &maxRetries) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Task not found") return @@ -3355,12 +3516,15 @@ func (s *Server) handleRetryTask(w http.ResponseWriter, r *http.Request) { } // Reset task to pending - _, err = s.db.Exec( - `UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL, + err = s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( + `UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL, error_message = NULL, started_at = NULL, completed_at = NULL WHERE id = ?`, - types.TaskStatusPending, taskID, - ) + types.TaskStatusPending, taskID, + ) + return err + }) if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to retry task: %v", err)) return @@ -3400,7 
+3564,9 @@ func (s *Server) handleStreamTaskLogsWebSocket(w http.ResponseWriter, r *http.Re // Verify job belongs to user var jobUserID int64 - err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -3416,7 +3582,9 @@ func (s *Server) handleStreamTaskLogsWebSocket(w http.ResponseWriter, r *http.Re // Verify task belongs to job var taskJobID int64 - err = s.db.QueryRow("SELECT job_id FROM tasks WHERE id = ?", taskID).Scan(&taskJobID) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT job_id FROM tasks WHERE id = ?", taskID).Scan(&taskJobID) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Task not found") return @@ -3481,11 +3649,16 @@ func (s *Server) handleStreamTaskLogsWebSocket(w http.ResponseWriter, r *http.Re // Send existing logs // Order by id ASC to ensure consistent ordering and avoid race conditions - rows, err := s.db.Query( - `SELECT id, task_id, runner_id, log_level, message, step_name, created_at + var rows *sql.Rows + err = s.db.With(func(conn *sql.DB) error { + var err error + rows, err = conn.Query( + `SELECT id, task_id, runner_id, log_level, message, step_name, created_at FROM task_logs WHERE task_id = ? AND id > ? 
ORDER BY id ASC LIMIT 100`, - taskID, lastID, - ) + taskID, lastID, + ) + return err + }) if err == nil { defer rows.Close() for rows.Next() { @@ -3532,47 +3705,693 @@ func (s *Server) handleStreamTaskLogsWebSocket(w http.ResponseWriter, r *http.Re case <-ctx.Done(): return case <-ticker.C: - rows, err := s.db.Query( - `SELECT id, task_id, runner_id, log_level, message, step_name, created_at + var logs []types.TaskLog + err := s.db.With(func(dbConn *sql.DB) error { + rows, err := dbConn.Query( + `SELECT id, task_id, runner_id, log_level, message, step_name, created_at FROM task_logs WHERE task_id = ? AND id > ? ORDER BY id ASC LIMIT 100`, - taskID, lastID, - ) + taskID, lastID, + ) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var log types.TaskLog + var runnerID sql.NullInt64 + err := rows.Scan( + &log.ID, &log.TaskID, &runnerID, &log.LogLevel, &log.Message, + &log.StepName, &log.CreatedAt, + ) + if err != nil { + continue + } + if runnerID.Valid { + log.RunnerID = &runnerID.Int64 + } + lastID = log.ID + logs = append(logs, log) + } + return nil + }) if err != nil { continue } + // Send logs to client (outside With callback to access websocket conn) + for _, log := range logs { + msg := map[string]interface{}{ + "type": "log", + "task_id": taskID, + "data": log, + "timestamp": time.Now().Unix(), + } + writeMu.Lock() + writeErr := conn.WriteJSON(msg) + writeMu.Unlock() + if writeErr != nil { + return + } + } + } + } +} +// handleClientWebSocket handles the unified client WebSocket connection with subscription protocol +func (s *Server) handleClientWebSocket(w http.ResponseWriter, r *http.Request) { + userID, err := getUserID(r) + if err != nil { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + // Check if user is admin + isAdmin := isAdminUser(r) + + // Upgrade to WebSocket + conn, err := s.wsUpgrader.Upgrade(w, r, nil) + if err != nil { + log.Printf("Failed to upgrade WebSocket: %v", err) + return + } + 
defer conn.Close() + + // Create client connection + clientConn := &ClientConnection{ + Conn: conn, + UserID: userID, + IsAdmin: isAdmin, + Subscriptions: make(map[string]bool), + WriteMu: &sync.Mutex{}, + } + + // Register connection + // Fix race condition: Close old connection BEFORE registering new one + var oldConn *ClientConnection + s.clientConnsMu.Lock() + if existingConn, exists := s.clientConns[userID]; exists && existingConn != nil { + oldConn = existingConn + } + s.clientConnsMu.Unlock() + + // Close old connection BEFORE registering new one to prevent race conditions + if oldConn != nil { + log.Printf("handleClientWebSocket: Closing existing connection for user %d", userID) + oldConn.Conn.Close() + } + + // Now register the new connection + s.clientConnsMu.Lock() + s.clientConns[userID] = clientConn + s.clientConnsMu.Unlock() + log.Printf("handleClientWebSocket: Registered client connection for user %d", userID) + + defer func() { + s.clientConnsMu.Lock() + // Only remove if this is still the current connection (not replaced by a newer one) + if existingConn, exists := s.clientConns[userID]; exists && existingConn == clientConn { + delete(s.clientConns, userID) + log.Printf("handleClientWebSocket: Removed client connection for user %d", userID) + } else { + log.Printf("handleClientWebSocket: Skipping removal for user %d (connection was replaced)", userID) + } + s.clientConnsMu.Unlock() + }() + + // Send initial connection message + clientConn.WriteMu.Lock() + err = conn.WriteJSON(map[string]interface{}{ + "type": "connected", + "timestamp": time.Now().Unix(), + }) + clientConn.WriteMu.Unlock() + if err != nil { + log.Printf("Failed to send initial connection message: %v", err) + return + } + + // Set up ping/pong + conn.SetReadDeadline(time.Now().Add(90 * time.Second)) // Increased timeout + conn.SetPongHandler(func(string) error { + conn.SetReadDeadline(time.Now().Add(90 * time.Second)) // Reset deadline on pong + return nil + }) + + // Start ping 
ticker (send ping every 30 seconds) + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + // Message handling channel - increased buffer size to prevent blocking + messageChan := make(chan map[string]interface{}, 100) + + // Read messages in background + readDone := make(chan struct{}) + go func() { + defer close(readDone) + for { + conn.SetReadDeadline(time.Now().Add(90 * time.Second)) // Increased timeout + messageType, message, err := conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + log.Printf("WebSocket read error for client %d: %v", userID, err) + } else { + log.Printf("WebSocket read error for client %d (expected close): %v", userID, err) + } + return + } + + // Handle control frames (pong, ping, close) + if messageType == websocket.PongMessage { + // Pong received - connection is alive, reset deadline + conn.SetReadDeadline(time.Now().Add(90 * time.Second)) + continue + } + if messageType == websocket.PingMessage { + // Respond to ping with pong + conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + conn.WriteMessage(websocket.PongMessage, message) + conn.SetReadDeadline(time.Now().Add(90 * time.Second)) + continue + } + if messageType != websocket.TextMessage { + // Skip non-text messages + continue + } + + // Parse JSON message + var msg map[string]interface{} + if err := json.Unmarshal(message, &msg); err != nil { + log.Printf("Failed to parse JSON message from client %d: %v", userID, err) + continue + } + messageChan <- msg + conn.SetReadDeadline(time.Now().Add(90 * time.Second)) + } + }() + + ctx := r.Context() + for { + select { + case <-ctx.Done(): + log.Printf("handleClientWebSocket: Context cancelled for user %d", userID) + return + case <-readDone: + log.Printf("handleClientWebSocket: Read done for user %d", userID) + return + case msg := <-messageChan: + s.handleClientMessage(clientConn, msg) + case <-ticker.C: + // Reset read 
deadline before sending ping to ensure we can receive pong + conn.SetReadDeadline(time.Now().Add(90 * time.Second)) + clientConn.WriteMu.Lock() + // Use WriteControl for ping frames (control frames) + if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil { + log.Printf("handleClientWebSocket: Ping failed for user %d: %v", userID, err) + clientConn.WriteMu.Unlock() + return + } + clientConn.WriteMu.Unlock() + } + } +} + +// handleClientMessage processes messages from client WebSocket +func (s *Server) handleClientMessage(clientConn *ClientConnection, msg map[string]interface{}) { + msgType, ok := msg["type"].(string) + if !ok { + return + } + + switch msgType { + case "subscribe": + channel, ok := msg["channel"].(string) + if !ok { + // Send error for invalid channel format + clientConn.WriteMu.Lock() + if err := clientConn.Conn.WriteJSON(map[string]interface{}{ + "type": "subscription_error", + "channel": channel, + "error": "Invalid channel format", + }); err != nil { + log.Printf("Failed to send subscription_error to client %d: %v", clientConn.UserID, err) + } + clientConn.WriteMu.Unlock() + return + } + // Check if already subscribed + clientConn.SubsMu.Lock() + alreadySubscribed := clientConn.Subscriptions[channel] + clientConn.SubsMu.Unlock() + + if alreadySubscribed { + // Already subscribed - just send confirmation, don't send initial state again + if s.verboseWSLogging { + log.Printf("Client %d already subscribed to channel: %s (skipping initial state)", clientConn.UserID, channel) + } + clientConn.WriteMu.Lock() + if err := clientConn.Conn.WriteJSON(map[string]interface{}{ + "type": "subscribed", + "channel": channel, + }); err != nil { + log.Printf("Failed to send subscribed confirmation to client %d: %v", clientConn.UserID, err) + } + clientConn.WriteMu.Unlock() + return + } + + // Validate channel access + if s.canSubscribe(clientConn, channel) { + clientConn.SubsMu.Lock() + 
clientConn.Subscriptions[channel] = true + clientConn.SubsMu.Unlock() + if s.verboseWSLogging { + log.Printf("Client %d subscribed to channel: %s", clientConn.UserID, channel) + } + // Send success confirmation + clientConn.WriteMu.Lock() + if err := clientConn.Conn.WriteJSON(map[string]interface{}{ + "type": "subscribed", + "channel": channel, + }); err != nil { + log.Printf("Failed to send subscribed confirmation to client %d: %v", clientConn.UserID, err) + clientConn.WriteMu.Unlock() + return + } + clientConn.WriteMu.Unlock() + // Send initial state for the subscribed channel (only on first subscription) + go s.sendInitialState(clientConn, channel) + } else { + // Subscription failed - send error to client + log.Printf("Client %d failed to subscribe to channel: %s (job may not exist or access denied)", clientConn.UserID, channel) + clientConn.WriteMu.Lock() + if err := clientConn.Conn.WriteJSON(map[string]interface{}{ + "type": "subscription_error", + "channel": channel, + "error": "Channel not found or access denied", + }); err != nil { + log.Printf("Failed to send subscription_error to client %d: %v", clientConn.UserID, err) + } + clientConn.WriteMu.Unlock() + } + case "unsubscribe": + channel, ok := msg["channel"].(string) + if !ok { + return + } + clientConn.SubsMu.Lock() + delete(clientConn.Subscriptions, channel) + clientConn.SubsMu.Unlock() + if s.verboseWSLogging { + log.Printf("Client %d unsubscribed from channel: %s", clientConn.UserID, channel) + } + } +} + +// canSubscribe checks if a client can subscribe to a channel +func (s *Server) canSubscribe(clientConn *ClientConnection, channel string) bool { + // Always allow jobs channel (always broadcasted, but subscription doesn't hurt) + if channel == "jobs" { + return true + } + + // Check channel format + if strings.HasPrefix(channel, "job:") { + // Extract job ID + jobIDStr := strings.TrimPrefix(channel, "job:") + jobID, err := strconv.ParseInt(jobIDStr, 10, 64) + if err != nil { + return false + } + 
// Verify job belongs to user (unless admin) + if clientConn.IsAdmin { + var exists bool + err := s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM jobs WHERE id = ?)", jobID).Scan(&exists) + }) + return err == nil && exists + } + var jobUserID int64 + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + }) + return err == nil && jobUserID == clientConn.UserID + } + + if strings.HasPrefix(channel, "logs:") { + // Format: logs:jobId:taskId + parts := strings.Split(channel, ":") + if len(parts) != 3 { + return false + } + jobID, err := strconv.ParseInt(parts[1], 10, 64) + if err != nil { + return false + } + // Verify job belongs to user (unless admin) + if clientConn.IsAdmin { + var exists bool + err := s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM jobs WHERE id = ?)", jobID).Scan(&exists) + }) + return err == nil && exists + } + var jobUserID int64 + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + }) + return err == nil && jobUserID == clientConn.UserID + } + + if strings.HasPrefix(channel, "upload:") { + // Format: upload:sessionId + sessionID := strings.TrimPrefix(channel, "upload:") + s.uploadSessionsMu.RLock() + session, exists := s.uploadSessions[sessionID] + s.uploadSessionsMu.RUnlock() + // Verify session belongs to user + return exists && session.UserID == clientConn.UserID + } + + if channel == "runners" { + // Only admins can subscribe to runners + return clientConn.IsAdmin + } + + return false +} + +// sendInitialState sends the current state when a client subscribes to a channel +func (s *Server) sendInitialState(clientConn *ClientConnection, channel string) { + // Use a shorter write deadline for initial state to avoid blocking too long + // If the connection is slow/dead, we want to fail fast + writeTimeout 
:= 5 * time.Second + + // Check if connection is still valid before starting + clientConn.WriteMu.Lock() + // Set a reasonable write deadline + clientConn.Conn.SetWriteDeadline(time.Now().Add(writeTimeout)) + clientConn.WriteMu.Unlock() + + if strings.HasPrefix(channel, "job:") { + // Send initial job state + jobIDStr := strings.TrimPrefix(channel, "job:") + jobID, err := strconv.ParseInt(jobIDStr, 10, 64) + if err != nil { + return + } + + // Get job from database + var job types.Job + var jobType string + var startedAt, completedAt sql.NullTime + var blendMetadataJSON sql.NullString + var errorMessage sql.NullString + var frameStart, frameEnd sql.NullInt64 + var outputFormat sql.NullString + var allowParallelRunners sql.NullBool + + query := "SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format, allow_parallel_runners, blend_metadata, created_at, started_at, completed_at, error_message FROM jobs WHERE id = ?" + if !clientConn.IsAdmin { + query += " AND user_id = ?" 
+ } + + var err2 error + err2 = s.db.With(func(conn *sql.DB) error { + if clientConn.IsAdmin { + return conn.QueryRow(query, jobID).Scan( + &job.ID, &job.UserID, &jobType, &job.Name, &job.Status, &job.Progress, + &frameStart, &frameEnd, &outputFormat, &allowParallelRunners, + &blendMetadataJSON, &job.CreatedAt, &startedAt, &completedAt, &errorMessage, + ) + } else { + return conn.QueryRow(query, jobID, clientConn.UserID).Scan( + &job.ID, &job.UserID, &jobType, &job.Name, &job.Status, &job.Progress, + &frameStart, &frameEnd, &outputFormat, &allowParallelRunners, + &blendMetadataJSON, &job.CreatedAt, &startedAt, &completedAt, &errorMessage, + ) + } + }) + + if err2 != nil { + return + } + + if frameStart.Valid { + fs := int(frameStart.Int64) + job.FrameStart = &fs + } + if frameEnd.Valid { + fe := int(frameEnd.Int64) + job.FrameEnd = &fe + } + if outputFormat.Valid { + of := outputFormat.String + job.OutputFormat = &of + } + if allowParallelRunners.Valid { + apr := allowParallelRunners.Bool + job.AllowParallelRunners = &apr + } + if startedAt.Valid { + job.StartedAt = &startedAt.Time + } + if completedAt.Valid { + job.CompletedAt = &completedAt.Time + } + if errorMessage.Valid { + job.ErrorMessage = errorMessage.String + } + + // Send job_update with full job data + clientConn.WriteMu.Lock() + clientConn.Conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + writeErr := clientConn.Conn.WriteJSON(map[string]interface{}{ + "type": "job_update", + "channel": channel, + "job_id": jobID, + "data": job, + "timestamp": time.Now().Unix(), + }) + clientConn.WriteMu.Unlock() + if writeErr != nil { + log.Printf("Failed to send initial job_update to client %d: %v", clientConn.UserID, writeErr) + return + } + + // Get and send tasks (no limit - send all) + err = s.db.With(func(conn *sql.DB) error { + rows, err2 := conn.Query( + `SELECT id, job_id, runner_id, frame_start, frame_end, status, task_type, + current_step, retry_count, max_retries, output_path, created_at, started_at, 
+ completed_at, error_message, timeout_seconds + FROM tasks WHERE job_id = ? ORDER BY frame_start ASC`, + jobID, + ) + if err2 != nil { + return err2 + } + defer rows.Close() for rows.Next() { - var log types.TaskLog + var task types.Task + var runnerID sql.NullInt64 + var startedAt, completedAt sql.NullTime + var timeoutSeconds sql.NullInt64 + var errorMessage sql.NullString + var currentStep sql.NullString + var outputPath sql.NullString + + err := rows.Scan( + &task.ID, &task.JobID, &runnerID, &task.FrameStart, &task.FrameEnd, + &task.Status, &task.TaskType, ¤tStep, &task.RetryCount, + &task.MaxRetries, &outputPath, &task.CreatedAt, &startedAt, + &completedAt, &errorMessage, &timeoutSeconds, + ) + if err != nil { + continue + } + + if runnerID.Valid { + task.RunnerID = &runnerID.Int64 + } + if startedAt.Valid { + task.StartedAt = &startedAt.Time + } + if completedAt.Valid { + task.CompletedAt = &completedAt.Time + } + if timeoutSeconds.Valid { + timeout := int(timeoutSeconds.Int64) + task.TimeoutSeconds = &timeout + } + if errorMessage.Valid { + task.ErrorMessage = errorMessage.String + } + if currentStep.Valid { + task.CurrentStep = currentStep.String + } + if outputPath.Valid { + task.OutputPath = outputPath.String + } + + // Send task_update + clientConn.WriteMu.Lock() + clientConn.Conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + writeErr := clientConn.Conn.WriteJSON(map[string]interface{}{ + "type": "task_update", + "channel": channel, + "job_id": jobID, + "task_id": task.ID, + "data": task, + "timestamp": time.Now().Unix(), + }) + clientConn.WriteMu.Unlock() + if writeErr != nil { + log.Printf("Failed to send initial task_update to client %d: %v", clientConn.UserID, writeErr) + // Connection is likely closed, stop sending more messages + break + } + } + return nil + }) + + } else if strings.HasPrefix(channel, "logs:") { + // Send initial logs for the task + parts := strings.Split(channel, ":") + if len(parts) != 3 { + return + } + jobID, err := 
strconv.ParseInt(parts[1], 10, 64) + if err != nil { + return + } + taskID, err := strconv.ParseInt(parts[2], 10, 64) + if err != nil { + return + } + + // Get existing logs (no limit - send all) + err = s.db.With(func(conn *sql.DB) error { + rows, err2 := conn.Query( + `SELECT id, task_id, runner_id, log_level, message, step_name, created_at + FROM task_logs WHERE task_id = ? ORDER BY id ASC`, + taskID, + ) + if err2 != nil { + return err2 + } + defer rows.Close() + for rows.Next() { + var taskLog types.TaskLog var runnerID sql.NullInt64 err := rows.Scan( - &log.ID, &log.TaskID, &runnerID, &log.LogLevel, &log.Message, - &log.StepName, &log.CreatedAt, + &taskLog.ID, &taskLog.TaskID, &runnerID, &taskLog.LogLevel, &taskLog.Message, + &taskLog.StepName, &taskLog.CreatedAt, ) if err != nil { continue } if runnerID.Valid { - log.RunnerID = &runnerID.Int64 - } - // Always update lastID to the highest ID we've seen - if log.ID > lastID { - lastID = log.ID + taskLog.RunnerID = &runnerID.Int64 } - // Serialize writes to prevent concurrent write panics - writeMu.Lock() - err = conn.WriteJSON(map[string]interface{}{ + // Send log + clientConn.WriteMu.Lock() + clientConn.Conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + writeErr := clientConn.Conn.WriteJSON(map[string]interface{}{ "type": "log", - "data": log, + "channel": channel, + "task_id": taskID, + "job_id": jobID, + "data": taskLog, "timestamp": time.Now().Unix(), }) - writeMu.Unlock() - if err != nil { - // Connection closed, exit the loop - return + clientConn.WriteMu.Unlock() + if writeErr != nil { + log.Printf("Failed to send initial log to client %d: %v", clientConn.UserID, writeErr) + // Connection is likely closed, stop sending more messages + break } } - rows.Close() + return nil + }) + + } else if channel == "runners" { + // Send initial runner list (only for admins) + if !clientConn.IsAdmin { + return + } + + s.db.With(func(conn *sql.DB) error { + rows, err2 := conn.Query( + `SELECT id, name, hostname, 
ip_address, status, last_heartbeat, capabilities, priority, created_at + FROM runners ORDER BY id ASC`, + ) + if err2 != nil { + return err2 + } + defer rows.Close() + for rows.Next() { + var runner types.Runner + err := rows.Scan( + &runner.ID, &runner.Name, &runner.Hostname, &runner.IPAddress, + &runner.Status, &runner.LastHeartbeat, &runner.Capabilities, + &runner.Priority, &runner.CreatedAt, + ) + if err != nil { + continue + } + + // Send runner_status + clientConn.WriteMu.Lock() + clientConn.Conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + writeErr := clientConn.Conn.WriteJSON(map[string]interface{}{ + "type": "runner_status", + "channel": channel, + "runner_id": runner.ID, + "data": runner, + "timestamp": time.Now().Unix(), + }) + clientConn.WriteMu.Unlock() + if writeErr != nil { + log.Printf("Failed to send initial runner_status to client %d: %v", clientConn.UserID, writeErr) + // Connection is likely closed, stop sending more messages + break + } + } + return nil + }) + + } else if strings.HasPrefix(channel, "upload:") { + // Send initial upload session state + sessionID := strings.TrimPrefix(channel, "upload:") + s.uploadSessionsMu.RLock() + session, exists := s.uploadSessions[sessionID] + s.uploadSessionsMu.RUnlock() + + if exists && session.UserID == clientConn.UserID { + msgType := "upload_progress" + if session.Status != "uploading" { + msgType = "processing_status" + } + + clientConn.WriteMu.Lock() + clientConn.Conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + writeErr := clientConn.Conn.WriteJSON(map[string]interface{}{ + "type": msgType, + "channel": channel, + "session_id": sessionID, + "data": map[string]interface{}{ + "progress": session.Progress, + "status": session.Status, + "message": session.Message, + }, + "timestamp": time.Now().Unix(), + }) + clientConn.WriteMu.Unlock() + if writeErr != nil { + log.Printf("Failed to send initial upload state to client %d: %v", clientConn.UserID, writeErr) + return + } } } } @@ -3657,8 
+4476,10 @@ func (s *Server) handleJobsWebSocket(w http.ResponseWriter, r *http.Request) { // Read loop exited, close connection return case <-ticker.C: - conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) - if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { + // Reset read deadline before sending ping to ensure we can receive pong + conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + // Use WriteControl for ping frames (control frames) + if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil { return } } @@ -3683,7 +4504,9 @@ func (s *Server) handleJobWebSocket(w http.ResponseWriter, r *http.Request) { isAdmin := isAdminUser(r) if !isAdmin { var jobUserID int64 - err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -3698,7 +4521,9 @@ func (s *Server) handleJobWebSocket(w http.ResponseWriter, r *http.Request) { } } else { var exists bool - err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM jobs WHERE id = ?)", jobID).Scan(&exists) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM jobs WHERE id = ?)", jobID).Scan(&exists) + }) if err != nil || !exists { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -3788,8 +4613,10 @@ func (s *Server) handleJobWebSocket(w http.ResponseWriter, r *http.Request) { // Read loop exited, close connection return case <-ticker.C: - conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) - if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { + // Reset read deadline before sending ping to ensure we can receive pong + conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + // Use WriteControl for ping frames (control frames) + 
if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil { return } } @@ -3800,7 +4627,9 @@ func (s *Server) handleJobWebSocket(w http.ResponseWriter, r *http.Request) { func (s *Server) broadcastJobUpdate(jobID int64, updateType string, data interface{}) { // Get user_id from job var userID int64 - err := s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&userID) + err := s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&userID) + }) if err != nil { return } @@ -3812,7 +4641,38 @@ func (s *Server) broadcastJobUpdate(jobID int64, updateType string, data interfa "timestamp": time.Now().Unix(), } - // Broadcast to job list connection + // Always broadcast to jobs channel (all clients receive this) + if updateType == "job_update" || updateType == "job_created" { + // For job_update, only send status and progress to jobs channel + if updateType == "job_update" { + if dataMap, ok := data.(map[string]interface{}); ok { + // Only include status and progress for jobs channel + jobsData := map[string]interface{}{} + if status, ok := dataMap["status"]; ok { + jobsData["status"] = status + } + if progress, ok := dataMap["progress"]; ok { + jobsData["progress"] = progress + } + jobsMsg := map[string]interface{}{ + "type": updateType, + "job_id": jobID, + "data": jobsData, + "timestamp": time.Now().Unix(), + } + s.broadcastToAllClients("jobs", jobsMsg) + } + } else { + // job_created - send full data to all clients + s.broadcastToAllClients("jobs", msg) + } + } + + // Broadcast to client WebSocket if subscribed to job:{id} + channel := fmt.Sprintf("job:%d", jobID) + s.broadcastToClient(userID, channel, msg) + + // Also broadcast to old WebSocket connections (for backwards compatibility during migration) s.jobListConnsMu.RLock() if conn, exists := s.jobListConns[userID]; exists && conn != nil { s.jobListConnsMu.RUnlock() @@ -3840,15 +4700,10 @@ func (s 
*Server) broadcastJobUpdate(jobID int64, updateType string, data interfa writeMu.Unlock() if err != nil { log.Printf("Failed to broadcast %s to job %d WebSocket: %v", updateType, jobID, err) - } else { - log.Printf("Successfully broadcast %s to job %d WebSocket", updateType, jobID) } } else { conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) - err := conn.WriteJSON(msg) - if err != nil { - log.Printf("Failed to broadcast %s to job %d WebSocket: %v", updateType, jobID, err) - } + conn.WriteJSON(msg) } } } @@ -3857,8 +4712,11 @@ func (s *Server) broadcastJobUpdate(jobID int64, updateType string, data interfa func (s *Server) broadcastTaskUpdate(jobID int64, taskID int64, updateType string, data interface{}) { // Get user_id from job var userID int64 - err := s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&userID) + err := s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&userID) + }) if err != nil { + log.Printf("broadcastTaskUpdate: Failed to get user_id for job %d: %v", jobID, err) return } @@ -3878,7 +4736,14 @@ func (s *Server) broadcastTaskUpdate(jobID int64, taskID int64, updateType strin } } - // Broadcast to single job connection + // Broadcast to client WebSocket if subscribed to job:{id} + channel := fmt.Sprintf("job:%d", jobID) + if s.verboseWSLogging { + log.Printf("broadcastTaskUpdate: Broadcasting %s for task %d (job %d, user %d) on channel %s, data=%+v", updateType, taskID, jobID, userID, channel, data) + } + s.broadcastToClient(userID, channel, msg) + + // Also broadcast to old WebSocket connection (for backwards compatibility during migration) key := fmt.Sprintf("%d:%d", userID, jobID) s.jobConnsMu.RLock() conn, exists := s.jobConns[key] @@ -3901,6 +4766,118 @@ func (s *Server) broadcastTaskUpdate(jobID int64, taskID int64, updateType strin } } +// broadcastToClient sends a message to a specific client connection +func (s *Server) broadcastToClient(userID int64, 
channel string, msg map[string]interface{}) { + s.clientConnsMu.RLock() + clientConn, exists := s.clientConns[userID] + s.clientConnsMu.RUnlock() + + if !exists || clientConn == nil { + log.Printf("broadcastToClient: Client %d not connected (channel: %s)", userID, channel) + return + } + + // Check if client is subscribed to this channel (jobs channel is always sent) + if channel != "jobs" { + clientConn.SubsMu.RLock() + subscribed := clientConn.Subscriptions[channel] + allSubs := make([]string, 0, len(clientConn.Subscriptions)) + for ch := range clientConn.Subscriptions { + allSubs = append(allSubs, ch) + } + clientConn.SubsMu.RUnlock() + if !subscribed { + if s.verboseWSLogging { + log.Printf("broadcastToClient: Client %d not subscribed to channel %s (subscribed to: %v)", userID, channel, allSubs) + } + return + } + if s.verboseWSLogging { + log.Printf("broadcastToClient: Client %d is subscribed to channel %s", userID, channel) + } + } + + // Add channel to message + msg["channel"] = channel + + // Log what we're sending + if s.verboseWSLogging { + log.Printf("broadcastToClient: Sending to client %d on channel %s: type=%v, job_id=%v, task_id=%v", + userID, channel, msg["type"], msg["job_id"], msg["task_id"]) + } + + clientConn.WriteMu.Lock() + defer clientConn.WriteMu.Unlock() + + clientConn.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + if err := clientConn.Conn.WriteJSON(msg); err != nil { + log.Printf("Failed to send message to client %d on channel %s: %v", userID, channel, err) + } else { + if s.verboseWSLogging { + log.Printf("broadcastToClient: Successfully sent message to client %d on channel %s", userID, channel) + } + } +} + +// broadcastToAllClients sends a message to all connected clients (for jobs channel) +func (s *Server) broadcastToAllClients(channel string, msg map[string]interface{}) { + msg["channel"] = channel + + s.clientConnsMu.RLock() + clients := make([]*ClientConnection, 0, len(s.clientConns)) + for _, clientConn := range 
s.clientConns { + clients = append(clients, clientConn) + } + s.clientConnsMu.RUnlock() + + for _, clientConn := range clients { + clientConn.WriteMu.Lock() + clientConn.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + if err := clientConn.Conn.WriteJSON(msg); err != nil { + log.Printf("Failed to broadcast to client %d: %v", clientConn.UserID, err) + } + clientConn.WriteMu.Unlock() + } +} + +// broadcastUploadProgress broadcasts upload/processing progress to subscribed clients +func (s *Server) broadcastUploadProgress(sessionID string, progress float64, status, message string) { + s.uploadSessionsMu.RLock() + session, exists := s.uploadSessions[sessionID] + s.uploadSessionsMu.RUnlock() + + if !exists { + return + } + + // Update session + s.uploadSessionsMu.Lock() + session.Progress = progress + session.Status = status + session.Message = message + s.uploadSessionsMu.Unlock() + + // Determine message type + msgType := "upload_progress" + if status != "uploading" { + msgType = "processing_status" + } + + msg := map[string]interface{}{ + "type": msgType, + "session_id": sessionID, + "data": map[string]interface{}{ + "progress": progress, + "status": status, + "message": message, + }, + "timestamp": time.Now().Unix(), + } + + channel := fmt.Sprintf("upload:%s", sessionID) + s.broadcastToClient(session.UserID, channel, msg) +} + // truncateString truncates a string to a maximum length, appending "..." 
if truncated func truncateString(s string, maxLen int) string { if len(s) <= maxLen { diff --git a/internal/api/metadata.go b/internal/api/metadata.go index 44c72f8..9789a89 100644 --- a/internal/api/metadata.go +++ b/internal/api/metadata.go @@ -2,8 +2,6 @@ package api import ( "archive/tar" - "bufio" - "bytes" "database/sql" _ "embed" "encoding/json" @@ -13,10 +11,10 @@ import ( "log" "net/http" "os" - "os/exec" "path/filepath" "strings" + "jiggablend/pkg/executils" "jiggablend/pkg/scripts" "jiggablend/pkg/types" ) @@ -44,7 +42,9 @@ func (s *Server) handleSubmitMetadata(w http.ResponseWriter, r *http.Request) { // Verify job exists var jobUserID int64 - err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -57,30 +57,36 @@ func (s *Server) handleSubmitMetadata(w http.ResponseWriter, r *http.Request) { // Find the metadata extraction task for this job // First try to find task assigned to this runner, then fall back to any metadata task for this job var taskID int64 - err = s.db.QueryRow( + err = s.db.With(func(conn *sql.DB) error { + err := conn.QueryRow( `SELECT id FROM tasks WHERE job_id = ? AND task_type = ? AND runner_id = ?`, jobID, types.TaskTypeMetadata, runnerID, ).Scan(&taskID) if err == sql.ErrNoRows { // Fall back to any metadata task for this job (in case assignment changed) - err = s.db.QueryRow( + err = conn.QueryRow( `SELECT id FROM tasks WHERE job_id = ? AND task_type = ? ORDER BY created_at DESC LIMIT 1`, jobID, types.TaskTypeMetadata, ).Scan(&taskID) if err == sql.ErrNoRows { + return fmt.Errorf("metadata extraction task not found") + } + if err != nil { + return err + } + // Update the task to be assigned to this runner if it wasn't already + conn.Exec( + `UPDATE tasks SET runner_id = ? 
WHERE id = ? AND runner_id IS NULL`, + runnerID, taskID, + ) + } + return err + }) + if err != nil { + if err.Error() == "metadata extraction task not found" { s.respondError(w, http.StatusNotFound, "Metadata extraction task not found") return } - if err != nil { - s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to find task: %v", err)) - return - } - // Update the task to be assigned to this runner if it wasn't already - s.db.Exec( - `UPDATE tasks SET runner_id = ? WHERE id = ? AND runner_id IS NULL`, - runnerID, taskID, - ) - } else if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to find task: %v", err)) return } @@ -93,20 +99,27 @@ func (s *Server) handleSubmitMetadata(w http.ResponseWriter, r *http.Request) { } // Update job with metadata - _, err = s.db.Exec( + err = s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( `UPDATE jobs SET blend_metadata = ? WHERE id = ?`, string(metadataJSON), jobID, ) + return err + }) if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to update job metadata: %v", err)) return } // Mark task as completed - _, err = s.db.Exec( - `UPDATE tasks SET status = ?, completed_at = CURRENT_TIMESTAMP WHERE id = ?`, - types.TaskStatusCompleted, taskID, - ) + err = s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec(`UPDATE tasks SET status = ? 
WHERE id = ?`, types.TaskStatusCompleted, taskID) + if err != nil { + return err + } + _, err = conn.Exec(`UPDATE tasks SET completed_at = CURRENT_TIMESTAMP WHERE id = ?`, taskID) + return err + }) if err != nil { log.Printf("Failed to mark metadata task as completed: %v", err) } else { @@ -136,10 +149,12 @@ func (s *Server) handleGetJobMetadata(w http.ResponseWriter, r *http.Request) { // Verify job belongs to user var jobUserID int64 var blendMetadataJSON sql.NullString - err = s.db.QueryRow( + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( `SELECT user_id, blend_metadata FROM jobs WHERE id = ?`, jobID, ).Scan(&jobUserID, &blendMetadataJSON) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Job not found") return @@ -245,64 +260,23 @@ func (s *Server) extractMetadataFromContext(jobID int64) (*types.BlendMetadata, return nil, fmt.Errorf("failed to get relative path for blend file: %w", err) } - // Execute Blender with Python script - cmd := exec.Command("blender", "-b", blendFileRel, "--python", "extract_metadata.py") - cmd.Dir = tmpDir + // Execute Blender with Python script using executils + result, err := executils.RunCommand( + "blender", + []string{"-b", blendFileRel, "--python", "extract_metadata.py"}, + tmpDir, + nil, // inherit environment + jobID, + nil, // no process tracker needed for metadata extraction + ) - // Capture stdout and stderr - stdoutPipe, err := cmd.StdoutPipe() if err != nil { - return nil, fmt.Errorf("failed to create stdout pipe: %w", err) - } - - stderrPipe, err := cmd.StderrPipe() - if err != nil { - return nil, fmt.Errorf("failed to create stderr pipe: %w", err) - } - - // Buffer to collect stdout for JSON parsing - var stdoutBuffer bytes.Buffer - - // Start the command - if err := cmd.Start(); err != nil { - return nil, fmt.Errorf("failed to start blender: %w", err) - } - - // Stream stdout and collect for JSON parsing - stdoutDone := make(chan bool) - go func() { - defer close(stdoutDone) - 
scanner := bufio.NewScanner(stdoutPipe) - for scanner.Scan() { - line := scanner.Text() - stdoutBuffer.WriteString(line) - stdoutBuffer.WriteString("\n") + stderrOutput := "" + stdoutOutput := "" + if result != nil { + stderrOutput = strings.TrimSpace(result.Stderr) + stdoutOutput = strings.TrimSpace(result.Stdout) } - }() - - // Capture stderr for error reporting - var stderrBuffer bytes.Buffer - stderrDone := make(chan bool) - go func() { - defer close(stderrDone) - scanner := bufio.NewScanner(stderrPipe) - for scanner.Scan() { - line := scanner.Text() - stderrBuffer.WriteString(line) - stderrBuffer.WriteString("\n") - } - }() - - // Wait for command to complete - err = cmd.Wait() - - // Wait for streaming goroutines to finish - <-stdoutDone - <-stderrDone - - if err != nil { - stderrOutput := strings.TrimSpace(stderrBuffer.String()) - stdoutOutput := strings.TrimSpace(stdoutBuffer.String()) log.Printf("Blender metadata extraction failed for job %d:", jobID) if stderrOutput != "" { log.Printf("Blender stderr: %s", stderrOutput) @@ -317,7 +291,7 @@ func (s *Server) extractMetadataFromContext(jobID int64) (*types.BlendMetadata, } // Parse output (metadata is printed to stdout) - metadataJSON := strings.TrimSpace(stdoutBuffer.String()) + metadataJSON := strings.TrimSpace(result.Stdout) // Extract JSON from output (Blender may print other stuff) jsonStart := strings.Index(metadataJSON, "{") jsonEnd := strings.LastIndex(metadataJSON, "}") diff --git a/internal/api/runners.go b/internal/api/runners.go index 1a2b975..97c9cc2 100644 --- a/internal/api/runners.go +++ b/internal/api/runners.go @@ -69,7 +69,9 @@ func (s *Server) runnerAuthMiddleware(next http.HandlerFunc) http.HandlerFunc { if apiKeyID != -1 { // Verify runner exists and uses this API key var dbAPIKeyID sql.NullInt64 - err = s.db.QueryRow("SELECT api_key_id FROM runners WHERE id = ?", runnerID).Scan(&dbAPIKeyID) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT api_key_id FROM 
runners WHERE id = ?", runnerID).Scan(&dbAPIKeyID) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "runner not found") return @@ -86,7 +88,9 @@ func (s *Server) runnerAuthMiddleware(next http.HandlerFunc) http.HandlerFunc { } else { // No runner ID provided - find the runner for this API key // For simplicity, assume each API key has one runner - err = s.db.QueryRow("SELECT id FROM runners WHERE api_key_id = ?", apiKeyID).Scan(&runnerID) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT id FROM runners WHERE api_key_id = ?", apiKeyID).Scan(&runnerID) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "no runner found for this API key") return @@ -120,10 +124,33 @@ func (s *Server) handleRegisterRunner(w http.ResponseWriter, r *http.Request) { s.secrets.RegistrationMu.Lock() defer s.secrets.RegistrationMu.Unlock() + // Validate runner name if req.Name == "" { s.respondError(w, http.StatusBadRequest, "Runner name is required") return } + if len(req.Name) > 255 { + s.respondError(w, http.StatusBadRequest, "Runner name must be 255 characters or less") + return + } + + // Validate hostname + if req.Hostname != "" { + // Basic hostname validation (allow IP addresses and domain names) + if len(req.Hostname) > 253 { + s.respondError(w, http.StatusBadRequest, "Hostname must be 253 characters or less") + return + } + } + + // Validate capabilities JSON if provided + if req.Capabilities != "" { + var testCapabilities map[string]interface{} + if err := json.Unmarshal([]byte(req.Capabilities), &testCapabilities); err != nil { + s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid capabilities JSON: %v", err)) + return + } + } if req.APIKey == "" { s.respondError(w, http.StatusBadRequest, "API key is required") @@ -166,10 +193,12 @@ func (s *Server) handleRegisterRunner(w http.ResponseWriter, r *http.Request) { if apiKeyID != -1 && req.Fingerprint != "" { var existingRunnerID int64 var existingAPIKeyID 
sql.NullInt64 - err = s.db.QueryRow( + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( "SELECT id, api_key_id FROM runners WHERE fingerprint = ?", req.Fingerprint, ).Scan(&existingRunnerID, &existingAPIKeyID) + }) if err == nil { // Runner already exists with this fingerprint @@ -177,10 +206,13 @@ func (s *Server) handleRegisterRunner(w http.ResponseWriter, r *http.Request) { // Same API key - update and return existing runner log.Printf("Runner with fingerprint %s already exists (ID: %d), updating info", req.Fingerprint, existingRunnerID) - _, err = s.db.Exec( + err = s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( `UPDATE runners SET name = ?, hostname = ?, capabilities = ?, status = ?, last_heartbeat = ? WHERE id = ?`, req.Name, req.Hostname, req.Capabilities, types.RunnerStatusOnline, time.Now(), existingRunnerID, ) + return err + }) if err != nil { log.Printf("Warning: Failed to update existing runner info: %v", err) } @@ -203,14 +235,20 @@ func (s *Server) handleRegisterRunner(w http.ResponseWriter, r *http.Request) { } // Insert runner - err = s.db.QueryRow( + err = s.db.With(func(conn *sql.DB) error { + result, err := conn.Exec( `INSERT INTO runners (name, hostname, ip_address, status, last_heartbeat, capabilities, api_key_id, api_key_scope, priority, fingerprint) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- RETURNING id`, + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, req.Name, req.Hostname, "", types.RunnerStatusOnline, time.Now(), req.Capabilities, dbAPIKeyID, apiKeyScope, priority, fingerprint, - ).Scan(&runnerID) + ) + if err != nil { + return err + } + runnerID, err = result.LastInsertId() + return err + }) if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to register runner: %v", err)) return @@ -238,10 +276,13 @@ func (s *Server) handleRunnerPing(w http.ResponseWriter, r *http.Request) { } // Update last heartbeat - _, err := s.db.Exec( + err := s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( `UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`, time.Now(), types.RunnerStatusOnline, runnerID, ) + return err + }) if err != nil { log.Printf("Warning: Failed to update runner heartbeat: %v", err) } @@ -301,7 +342,9 @@ func (s *Server) handleUpdateTaskStep(w http.ResponseWriter, r *http.Request) { // Verify task belongs to runner var taskRunnerID sql.NullInt64 - err = s.db.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", taskID).Scan(&taskRunnerID) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", taskID).Scan(&taskRunnerID) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Task not found") return @@ -320,10 +363,12 @@ func (s *Server) handleUpdateTaskStep(w http.ResponseWriter, r *http.Request) { // Check if step already exists var existingStepID sql.NullInt64 - err = s.db.QueryRow( + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( `SELECT id FROM task_steps WHERE task_id = ? 
AND step_name = ?`, taskID, req.StepName, ).Scan(&existingStepID) + }) if err == sql.ErrNoRows || !existingStepID.Valid { // Create new step @@ -336,12 +381,18 @@ func (s *Server) handleUpdateTaskStep(w http.ResponseWriter, r *http.Request) { completedAt = &now } - err = s.db.QueryRow( + err = s.db.With(func(conn *sql.DB) error { + result, err := conn.Exec( `INSERT INTO task_steps (task_id, step_name, status, started_at, completed_at, duration_ms, error_message) - VALUES (?, ?, ?, ?, ?, ?, ?) - RETURNING id`, + VALUES (?, ?, ?, ?, ?, ?, ?)`, taskID, req.StepName, req.Status, startedAt, completedAt, req.DurationMs, req.ErrorMessage, - ).Scan(&stepID) + ) + if err != nil { + return err + } + stepID, err = result.LastInsertId() + return err + }) if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create step: %v", err)) return @@ -355,7 +406,9 @@ func (s *Server) handleUpdateTaskStep(w http.ResponseWriter, r *http.Request) { // Get existing started_at if status is running/completed/failed if req.Status == string(types.StepStatusRunning) || req.Status == string(types.StepStatusCompleted) || req.Status == string(types.StepStatusFailed) { var existingStartedAt sql.NullTime - s.db.QueryRow(`SELECT started_at FROM task_steps WHERE id = ?`, stepID).Scan(&existingStartedAt) + s.db.With(func(conn *sql.DB) error { + return conn.QueryRow(`SELECT started_at FROM task_steps WHERE id = ?`, stepID).Scan(&existingStartedAt) + }) if existingStartedAt.Valid { startedAt = &existingStartedAt.Time } else { @@ -367,11 +420,14 @@ func (s *Server) handleUpdateTaskStep(w http.ResponseWriter, r *http.Request) { completedAt = &now } - _, err = s.db.Exec( + err = s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( `UPDATE task_steps SET status = ?, started_at = ?, completed_at = ?, duration_ms = ?, error_message = ? 
WHERE id = ?`, req.Status, startedAt, completedAt, req.DurationMs, req.ErrorMessage, stepID, ) + return err + }) if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to update step: %v", err)) return @@ -380,7 +436,9 @@ func (s *Server) handleUpdateTaskStep(w http.ResponseWriter, r *http.Request) { // Get job ID for broadcasting var jobID int64 - err = s.db.QueryRow("SELECT job_id FROM tasks WHERE id = ?", taskID).Scan(&jobID) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT job_id FROM tasks WHERE id = ?", taskID).Scan(&jobID) + }) if err == nil { // Broadcast step update to frontend s.broadcastTaskUpdate(jobID, taskID, "step_update", map[string]interface{}{ @@ -462,12 +520,18 @@ func (s *Server) handleUploadFileFromRunner(w http.ResponseWriter, r *http.Reque // Record in database var fileID int64 - err = s.db.QueryRow( + err = s.db.With(func(conn *sql.DB) error { + result, err := conn.Exec( `INSERT INTO job_files (job_id, file_type, file_path, file_name, file_size) - VALUES (?, ?, ?, ?, ?) 
- RETURNING id`, + VALUES (?, ?, ?, ?, ?)`, jobID, types.JobFileTypeOutput, filePath, header.Filename, header.Size, - ).Scan(&fileID) + ) + if err != nil { + return err + } + fileID, err = result.LastInsertId() + return err + }) if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to record file: %v", err)) return @@ -503,7 +567,8 @@ func (s *Server) handleGetJobStatusForRunner(w http.ResponseWriter, r *http.Requ var frameStart, frameEnd sql.NullInt64 var outputFormat sql.NullString var allowParallelRunners sql.NullBool - err = s.db.QueryRow( + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( `SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format, allow_parallel_runners, created_at, started_at, completed_at, error_message FROM jobs WHERE id = ?`, @@ -513,6 +578,11 @@ func (s *Server) handleGetJobStatusForRunner(w http.ResponseWriter, r *http.Requ &frameStart, &frameEnd, &outputFormat, &allowParallelRunners, &job.CreatedAt, &startedAt, &completedAt, &errorMessage, ) + }) + if err != nil { + s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query job: %v", err)) + return + } job.JobType = types.JobType(jobType) if frameStart.Valid { @@ -560,11 +630,16 @@ func (s *Server) handleGetJobFilesForRunner(w http.ResponseWriter, r *http.Reque return } - rows, err := s.db.Query( + var rows *sql.Rows + err = s.db.With(func(conn *sql.DB) error { + var err error + rows, err = conn.Query( `SELECT id, job_id, file_type, file_path, file_name, file_size, created_at FROM job_files WHERE job_id = ? 
ORDER BY file_name`, jobID, ) + return err + }) if err != nil { s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query files: %v", err)) return @@ -597,10 +672,12 @@ func (s *Server) handleGetJobMetadataForRunner(w http.ResponseWriter, r *http.Re } var blendMetadataJSON sql.NullString - err = s.db.QueryRow( + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( `SELECT blend_metadata FROM jobs WHERE id = ?`, jobID, ).Scan(&blendMetadataJSON) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "Job not found") @@ -649,10 +726,12 @@ func (s *Server) handleDownloadFileForRunner(w http.ResponseWriter, r *http.Requ // Get file info from database var filePath string - err = s.db.QueryRow( + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( `SELECT file_path FROM job_files WHERE job_id = ? AND file_name = ?`, jobID, decodedFileName, ).Scan(&filePath) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "File not found") return @@ -766,7 +845,9 @@ func (s *Server) handleRunnerWebSocket(w http.ResponseWriter, r *http.Request) { // For fixed API keys, skip database verification if apiKeyID != -1 { var dbAPIKeyID sql.NullInt64 - err = s.db.QueryRow("SELECT api_key_id FROM runners WHERE id = ?", runnerID).Scan(&dbAPIKeyID) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT api_key_id FROM runners WHERE id = ?", runnerID).Scan(&dbAPIKeyID) + }) if err == sql.ErrNoRows { s.respondError(w, http.StatusNotFound, "runner not found") return @@ -782,7 +863,9 @@ func (s *Server) handleRunnerWebSocket(w http.ResponseWriter, r *http.Request) { } } else { // No runner ID provided - find the runner for this API key - err = s.db.QueryRow("SELECT id FROM runners WHERE api_key_id = ?", apiKeyID).Scan(&runnerID) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT id FROM runners WHERE api_key_id = ?", apiKeyID).Scan(&runnerID) + }) if err == sql.ErrNoRows { 
s.respondError(w, http.StatusNotFound, "no runner found for this API key") return @@ -802,37 +885,56 @@ func (s *Server) handleRunnerWebSocket(w http.ResponseWriter, r *http.Request) { defer conn.Close() // Register connection (must be done before any distribution checks) - // Close old connection outside lock to avoid blocking + // Fix race condition: Close old connection and create write mutex BEFORE registering new connection var oldConn *websocket.Conn - var hadExistingConnection bool s.runnerConnsMu.Lock() if existingConn, exists := s.runnerConns[runnerID]; exists { oldConn = existingConn - hadExistingConnection = true } - s.runnerConns[runnerID] = conn s.runnerConnsMu.Unlock() - // Close old connection outside lock (if it existed) + // Close old connection BEFORE registering new one to prevent race conditions if oldConn != nil { log.Printf("Runner %d: closing existing WebSocket connection (reconnection)", runnerID) oldConn.Close() - } else if hadExistingConnection { - log.Printf("Runner %d: replacing existing WebSocket connection", runnerID) } - log.Printf("Runner %d: WebSocket connection established successfully", runnerID) - - // Create a write mutex for this connection + // Create write mutex BEFORE registering connection to prevent race condition s.runnerConnsWriteMuMu.Lock() s.runnerConnsWriteMu[runnerID] = &sync.Mutex{} s.runnerConnsWriteMuMu.Unlock() + // Now register the new connection + s.runnerConnsMu.Lock() + s.runnerConns[runnerID] = conn + s.runnerConnsMu.Unlock() + + log.Printf("Runner %d: WebSocket connection established successfully", runnerID) + + // Check if runner was offline and had tasks assigned - redistribute them + // This handles the case where the manager restarted and marked the runner offline + // but tasks were still assigned to it + s.db.With(func(conn *sql.DB) error { + var count int + err := conn.QueryRow( + `SELECT COUNT(*) FROM tasks WHERE runner_id = ? 
AND status = ?`, + runnerID, types.TaskStatusRunning, + ).Scan(&count) + if err == nil && count > 0 { + log.Printf("Runner %d reconnected with %d running tasks assigned - redistributing them", runnerID, count) + s.redistributeRunnerTasks(runnerID) + } + return nil + }) + // Update runner status to online - _, _ = s.db.Exec( + s.db.With(func(conn *sql.DB) error { + _, _ = conn.Exec( `UPDATE runners SET status = ?, last_heartbeat = ? WHERE id = ?`, types.RunnerStatusOnline, time.Now(), runnerID, ) + return nil + }) // Immediately try to distribute pending tasks to this newly connected runner log.Printf("Runner %d connected, distributing pending tasks", runnerID) @@ -842,23 +944,36 @@ func (s *Server) handleRunnerWebSocket(w http.ResponseWriter, r *http.Request) { // Task assignment logging happens in distributeTasksToRunners // Cleanup on disconnect + // Fix race condition: Only cleanup if this is still the current connection (not replaced by reconnection) defer func() { log.Printf("Runner %d: WebSocket connection cleanup started", runnerID) - // Update database status first - _, err := s.db.Exec( + // Check if this is still the current connection before cleanup + s.runnerConnsMu.Lock() + currentConn, stillCurrent := s.runnerConns[runnerID] + if !stillCurrent || currentConn != conn { + // Connection was replaced by a newer one, don't cleanup + s.runnerConnsMu.Unlock() + log.Printf("Runner %d: Skipping cleanup - connection was replaced by newer connection", runnerID) + return + } + // Remove connection from map + delete(s.runnerConns, runnerID) + s.runnerConnsMu.Unlock() + + // Update database status + err := s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( `UPDATE runners SET status = ?, last_heartbeat = ? 
WHERE id = ?`, types.RunnerStatusOffline, time.Now(), runnerID, ) + return err + }) if err != nil { log.Printf("Warning: Failed to update runner %d status to offline: %v", runnerID, err) } - // Clean up connection maps - s.runnerConnsMu.Lock() - delete(s.runnerConns, runnerID) - s.runnerConnsMu.Unlock() - + // Clean up write mutex s.runnerConnsWriteMuMu.Lock() delete(s.runnerConnsWriteMu, runnerID) s.runnerConnsWriteMuMu.Unlock() @@ -874,10 +989,13 @@ func (s *Server) handleRunnerWebSocket(w http.ResponseWriter, r *http.Request) { // Also reset read deadline to keep connection alive conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(90 * time.Second)) // Increased to 90 seconds - _, _ = s.db.Exec( + s.db.With(func(conn *sql.DB) error { + _, _ = conn.Exec( `UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`, time.Now(), types.RunnerStatusOnline, runnerID, ) + return nil + }) return nil }) @@ -937,13 +1055,11 @@ func (s *Server) handleRunnerWebSocket(w http.ResponseWriter, r *http.Request) { switch msg.Type { case "heartbeat": - // Update heartbeat from explicit heartbeat message + // Heartbeat messages are handled by pong handler (manager-side) // Reset read deadline to keep connection alive conn.SetReadDeadline(time.Now().Add(90 * time.Second)) - _, _ = s.db.Exec( - `UPDATE runners SET last_heartbeat = ?, status = ? 
WHERE id = ?`, - time.Now(), types.RunnerStatusOnline, runnerID, - ) + // Note: Heartbeat updates are consolidated to pong handler to avoid race conditions + // The pong handler is the single source of truth for heartbeat updates case "log_entry": var logEntry WSLogEntry @@ -969,11 +1085,14 @@ func (s *Server) handleRunnerWebSocket(w http.ResponseWriter, r *http.Request) { // handleWebSocketLog handles log entries from WebSocket func (s *Server) handleWebSocketLog(runnerID int64, logEntry WSLogEntry) { // Store log in database - _, err := s.db.Exec( + err := s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( `INSERT INTO task_logs (task_id, runner_id, log_level, message, step_name, created_at) VALUES (?, ?, ?, ?, ?, ?)`, logEntry.TaskID, runnerID, logEntry.LogLevel, logEntry.Message, logEntry.StepName, time.Now(), ) + return err + }) if err != nil { log.Printf("Failed to store log: %v", err) return @@ -986,7 +1105,9 @@ func (s *Server) handleWebSocketLog(runnerID int64, logEntry WSLogEntry) { if strings.Contains(logEntry.Message, "Fra:") { // Get job ID from task var jobID int64 - err := s.db.QueryRow("SELECT job_id FROM tasks WHERE id = ?", logEntry.TaskID).Scan(&jobID) + err := s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT job_id FROM tasks WHERE id = ?", logEntry.TaskID).Scan(&jobID) + }) if err == nil { // Throttle progress updates (max once per 2 seconds per job) s.progressUpdateTimesMu.RLock() @@ -1017,7 +1138,9 @@ func (s *Server) handleWebSocketTaskUpdate(runnerID int64, taskUpdate WSTaskUpda func (s *Server) handleWebSocketTaskComplete(runnerID int64, taskUpdate WSTaskUpdate) { // Verify task belongs to runner var taskRunnerID sql.NullInt64 - err := s.db.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", taskUpdate.TaskID).Scan(&taskRunnerID) + err := s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", taskUpdate.TaskID).Scan(&taskRunnerID) + }) if err != nil || 
!taskRunnerID.Valid || taskRunnerID.Int64 != runnerID { log.Printf("Task %d does not belong to runner %d", taskUpdate.TaskID, runnerID) return @@ -1028,32 +1151,67 @@ func (s *Server) handleWebSocketTaskComplete(runnerID int64, taskUpdate WSTaskUp status = types.TaskStatusFailed } - now := time.Now() - _, err = s.db.Exec( - `UPDATE tasks SET status = ?, output_path = ?, completed_at = ?, error_message = ? WHERE id = ?`, - status, taskUpdate.OutputPath, now, taskUpdate.Error, taskUpdate.TaskID, - ) - if err != nil { - log.Printf("Failed to update task: %v", err) - return - } - - // Update job status and progress + // Get job ID first for atomic update var jobID int64 - err = s.db.QueryRow( + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( `SELECT job_id FROM tasks WHERE id = ?`, taskUpdate.TaskID, ).Scan(&jobID) - if err == nil { - // Broadcast task update - s.broadcastTaskUpdate(jobID, taskUpdate.TaskID, "task_update", map[string]interface{}{ - "status": status, - "output_path": taskUpdate.OutputPath, - "completed_at": now, - "error": taskUpdate.Error, - }) - s.updateJobStatusFromTasks(jobID) + }) + if err != nil { + log.Printf("Failed to get job ID for task %d: %v", taskUpdate.TaskID, err) + return } + + // Use transaction to update task and job status atomically + now := time.Now() + err = s.db.WithTx(func(tx *sql.Tx) error { + // Update columns individually + _, err := tx.Exec(`UPDATE tasks SET status = ? WHERE id = ?`, status, taskUpdate.TaskID) + if err != nil { + log.Printf("Failed to update task status: %v", err) + return err + } + + if taskUpdate.OutputPath != "" { + _, err = tx.Exec(`UPDATE tasks SET output_path = ? WHERE id = ?`, taskUpdate.OutputPath, taskUpdate.TaskID) + if err != nil { + log.Printf("Failed to update task output_path: %v", err) + return err + } + } + + _, err = tx.Exec(`UPDATE tasks SET completed_at = ? 
WHERE id = ?`, now, taskUpdate.TaskID) + if err != nil { + log.Printf("Failed to update task completed_at: %v", err) + return err + } + + if taskUpdate.Error != "" { + _, err = tx.Exec(`UPDATE tasks SET error_message = ? WHERE id = ?`, taskUpdate.Error, taskUpdate.TaskID) + if err != nil { + log.Printf("Failed to update task error_message: %v", err) + return err + } + } + + return nil // Commit on nil return + }) + if err != nil { + log.Printf("Failed to update task %d: %v", taskUpdate.TaskID, err) + return + } + + // Broadcast task update + s.broadcastTaskUpdate(jobID, taskUpdate.TaskID, "task_update", map[string]interface{}{ + "status": status, + "output_path": taskUpdate.OutputPath, + "completed_at": now, + "error": taskUpdate.Error, + }) + // Update job status and progress (this will query tasks and update job accordingly) + s.updateJobStatusFromTasks(jobID) } // parseBlenderFrame extracts the current frame number from Blender log messages @@ -1092,10 +1250,15 @@ func parseBlenderFrame(logMessage string) (int, bool) { // getCurrentFrameFromLogs gets the highest frame number found in logs for a job's render tasks func (s *Server) getCurrentFrameFromLogs(jobID int64) (int, bool) { // Get all render tasks for this job - rows, err := s.db.Query( + var rows *sql.Rows + err := s.db.With(func(conn *sql.DB) error { + var err error + rows, err = conn.Query( `SELECT id FROM tasks WHERE job_id = ? AND task_type = ? AND status = ?`, jobID, types.TaskTypeRender, types.TaskStatusRunning, ) + return err + }) if err != nil { return 0, false } @@ -1112,12 +1275,17 @@ func (s *Server) getCurrentFrameFromLogs(jobID int64) (int, bool) { } // Get the most recent log entries for this task (last 100 to avoid scanning all logs) - logRows, err := s.db.Query( + var logRows *sql.Rows + err := s.db.With(func(conn *sql.DB) error { + var err error + logRows, err = conn.Query( `SELECT message FROM task_logs WHERE task_id = ? 
AND message LIKE '%Fra:%' ORDER BY id DESC LIMIT 100`, taskID, ) + return err + }) if err != nil { continue } @@ -1141,6 +1309,62 @@ func (s *Server) getCurrentFrameFromLogs(jobID int64) (int, bool) { return maxFrame, found } +// resetFailedTasksAndRedistribute resets all failed tasks for a job to pending and redistributes them +func (s *Server) resetFailedTasksAndRedistribute(jobID int64) error { + // Reset all failed tasks to pending and clear their retry_count + err := s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( + `UPDATE tasks SET status = ?, retry_count = 0, runner_id = NULL, started_at = NULL, completed_at = NULL, error_message = NULL + WHERE job_id = ? AND status = ?`, + types.TaskStatusPending, jobID, types.TaskStatusFailed, + ) + if err != nil { + return fmt.Errorf("failed to reset failed tasks: %v", err) + } + + // Increment job retry_count + _, err = conn.Exec( + `UPDATE jobs SET retry_count = retry_count + 1 WHERE id = ?`, + jobID, + ) + if err != nil { + return fmt.Errorf("failed to increment job retry_count: %v", err) + } + return nil + }) + if err != nil { + return err + } + + log.Printf("Reset failed tasks for job %d and incremented retry_count", jobID) + + // Trigger task distribution to redistribute the reset tasks + s.triggerTaskDistribution() + + return nil +} + +// cancelActiveTasksForJob cancels all active (pending or running) tasks for a job +func (s *Server) cancelActiveTasksForJob(jobID int64) error { + // Tasks don't have a cancelled status - mark them as failed instead + err := s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( + `UPDATE tasks SET status = ?, error_message = ? WHERE job_id = ? 
AND status IN (?, ?)`, + types.TaskStatusFailed, "Job cancelled", jobID, types.TaskStatusPending, types.TaskStatusRunning, + ) + if err != nil { + return fmt.Errorf("failed to cancel active tasks: %v", err) + } + return nil + }) + if err != nil { + return err + } + + log.Printf("Cancelled all active tasks for job %d", jobID) + return nil +} + // updateJobStatusFromTasks updates job status and progress based on task states func (s *Server) updateJobStatusFromTasks(jobID int64) { now := time.Now() @@ -1149,10 +1373,12 @@ func (s *Server) updateJobStatusFromTasks(jobID int64) { var jobType string var frameStart, frameEnd sql.NullInt64 var allowParallelRunners sql.NullBool - err := s.db.QueryRow( + err := s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( `SELECT job_type, frame_start, frame_end, allow_parallel_runners FROM jobs WHERE id = ?`, jobID, ).Scan(&jobType, &frameStart, &frameEnd, &allowParallelRunners) + }) if err != nil { log.Printf("Failed to get job info for job %d: %v", jobID, err) return @@ -1165,7 +1391,9 @@ func (s *Server) updateJobStatusFromTasks(jobID int64) { // Get current job status to detect changes var currentStatus string - err = s.db.QueryRow(`SELECT status FROM jobs WHERE id = ?`, jobID).Scan(¤tStatus) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow(`SELECT status FROM jobs WHERE id = ?`, jobID).Scan(¤tStatus) + }) if err != nil { log.Printf("Failed to get current job status for job %d: %v", jobID, err) return @@ -1173,18 +1401,19 @@ func (s *Server) updateJobStatusFromTasks(jobID int64) { // Count total tasks and completed tasks var totalTasks, completedTasks int - err = s.db.QueryRow( + err = s.db.With(func(conn *sql.DB) error { + err := conn.QueryRow( `SELECT COUNT(*) FROM tasks WHERE job_id = ? 
AND status IN (?, ?, ?, ?)`, jobID, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted, types.TaskStatusFailed, ).Scan(&totalTasks) if err != nil { - log.Printf("Failed to count total tasks for job %d: %v", jobID, err) - return + return err } - err = s.db.QueryRow( + return conn.QueryRow( `SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`, jobID, types.TaskStatusCompleted, ).Scan(&completedTasks) + }) if err != nil { log.Printf("Failed to count completed tasks for job %d: %v", jobID, err) return @@ -1204,14 +1433,17 @@ func (s *Server) updateJobStatusFromTasks(jobID int64) { // Count non-render tasks (like video generation) separately var nonRenderTasks, nonRenderCompleted int - s.db.QueryRow( + s.db.With(func(conn *sql.DB) error { + conn.QueryRow( `SELECT COUNT(*) FROM tasks WHERE job_id = ? AND task_type != ? AND status IN (?, ?, ?, ?)`, jobID, types.TaskTypeRender, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted, types.TaskStatusFailed, ).Scan(&nonRenderTasks) - s.db.QueryRow( + conn.QueryRow( `SELECT COUNT(*) FROM tasks WHERE job_id = ? AND task_type != ? AND status = ?`, jobID, types.TaskTypeRender, types.TaskStatusCompleted, ).Scan(&nonRenderCompleted) + return nil + }) // Calculate render task progress from frames var renderProgress float64 @@ -1231,14 +1463,17 @@ func (s *Server) updateJobStatusFromTasks(jobID int64) { } else { // Fall back to task-based progress for render tasks var renderTasks, renderCompleted int - s.db.QueryRow( + s.db.With(func(conn *sql.DB) error { + conn.QueryRow( `SELECT COUNT(*) FROM tasks WHERE job_id = ? AND task_type = ? AND status IN (?, ?, ?, ?)`, jobID, types.TaskTypeRender, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted, types.TaskStatusFailed, ).Scan(&renderTasks) - s.db.QueryRow( + conn.QueryRow( `SELECT COUNT(*) FROM tasks WHERE job_id = ? AND task_type = ? 
AND status = ?`, jobID, types.TaskTypeRender, types.TaskStatusCompleted, ).Scan(&renderCompleted) + return nil + }) if renderTasks > 0 { renderProgress = float64(renderCompleted) / float64(renderTasks) * 100.0 } @@ -1266,7 +1501,10 @@ func (s *Server) updateJobStatusFromTasks(jobID int64) { var jobStatus string var outputFormat sql.NullString - s.db.QueryRow(`SELECT output_format FROM jobs WHERE id = ?`, jobID).Scan(&outputFormat) + s.db.With(func(conn *sql.DB) error { + conn.QueryRow(`SELECT output_format FROM jobs WHERE id = ?`, jobID).Scan(&outputFormat) + return nil + }) outputFormatStr := "" if outputFormat.Valid { outputFormatStr = outputFormat.String @@ -1274,11 +1512,13 @@ func (s *Server) updateJobStatusFromTasks(jobID int64) { // Check if all non-cancelled tasks are completed var pendingOrRunningTasks int - err = s.db.QueryRow( + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( `SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status IN (?, ?)`, jobID, types.TaskStatusPending, types.TaskStatusRunning, ).Scan(&pendingOrRunningTasks) + }) if err != nil { log.Printf("Failed to count pending/running tasks for job %d: %v", jobID, err) return @@ -1288,62 +1528,145 @@ func (s *Server) updateJobStatusFromTasks(jobID int64) { // All tasks are either completed or failed/cancelled // Check if any tasks failed var failedTasks int - s.db.QueryRow( + s.db.With(func(conn *sql.DB) error { + conn.QueryRow( `SELECT COUNT(*) FROM tasks WHERE job_id = ? 
AND status = ?`, jobID, types.TaskStatusFailed, ).Scan(&failedTasks) + return nil + }) if failedTasks > 0 { - // Some tasks failed - mark job as failed - jobStatus = string(types.JobStatusFailed) + // Some tasks failed - check if job has retries left + var retryCount, maxRetries int + err := s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( + `SELECT retry_count, max_retries FROM jobs WHERE id = ?`, + jobID, + ).Scan(&retryCount, &maxRetries) + }) + if err != nil { + log.Printf("Failed to get retry info for job %d: %v", jobID, err) + // Fall back to marking job as failed + jobStatus = string(types.JobStatusFailed) + } else if retryCount < maxRetries { + // Job has retries left - reset failed tasks and redistribute + if err := s.resetFailedTasksAndRedistribute(jobID); err != nil { + log.Printf("Failed to reset failed tasks for job %d: %v", jobID, err) + // If reset fails, mark job as failed + jobStatus = string(types.JobStatusFailed) + } else { + // Tasks reset successfully - job remains in running/pending state + // Don't update job status, just update progress + jobStatus = currentStatus // Keep current status + // Recalculate progress after reset (failed tasks are now pending again) + var newTotalTasks, newCompletedTasks int + s.db.With(func(conn *sql.DB) error { + conn.QueryRow( + `SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status IN (?, ?, ?, ?)`, + jobID, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted, types.TaskStatusFailed, + ).Scan(&newTotalTasks) + conn.QueryRow( + `SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`, + jobID, types.TaskStatusCompleted, + ).Scan(&newCompletedTasks) + return nil + }) + if newTotalTasks > 0 { + progress = float64(newCompletedTasks) / float64(newTotalTasks) * 100.0 + } + // Update progress only + err := s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( + `UPDATE jobs SET progress = ? 
WHERE id = ?`, + progress, jobID, + ) + return err + }) + if err != nil { + log.Printf("Failed to update job %d progress: %v", jobID, err) + } else { + // Broadcast job update via WebSocket + s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{ + "status": jobStatus, + "progress": progress, + }) + } + return // Exit early since we've handled the retry + } + } else { + // No retries left - mark job as failed and cancel active tasks + jobStatus = string(types.JobStatusFailed) + if err := s.cancelActiveTasksForJob(jobID); err != nil { + log.Printf("Failed to cancel active tasks for job %d: %v", jobID, err) + } + } } else { // All tasks completed successfully jobStatus = string(types.JobStatusCompleted) progress = 100.0 // Ensure progress is 100% when all tasks complete } - _, err := s.db.Exec( - `UPDATE jobs SET status = ?, progress = ?, completed_at = ? WHERE id = ?`, - jobStatus, progress, now, jobID, - ) - if err != nil { - log.Printf("Failed to update job %d status to %s: %v", jobID, jobStatus, err) - } else { - // Only log if status actually changed - if currentStatus != jobStatus { - log.Printf("Updated job %d status from %s to %s (progress: %.1f%%, completed tasks: %d/%d)", jobID, currentStatus, jobStatus, progress, completedTasks, totalTasks) - } - // Broadcast job update via WebSocket - s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{ - "status": jobStatus, - "progress": progress, - "completed_at": now, + + // Update job status (if we didn't return early from retry logic) + if jobStatus != "" { + err := s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( + `UPDATE jobs SET status = ?, progress = ?, completed_at = ? 
WHERE id = ?`, + jobStatus, progress, now, jobID, + ) + return err }) + if err != nil { + log.Printf("Failed to update job %d status to %s: %v", jobID, jobStatus, err) + } else { + // Only log if status actually changed + if currentStatus != jobStatus { + log.Printf("Updated job %d status from %s to %s (progress: %.1f%%, completed tasks: %d/%d)", jobID, currentStatus, jobStatus, progress, completedTasks, totalTasks) + } + // Broadcast job update via WebSocket + s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{ + "status": jobStatus, + "progress": progress, + "completed_at": now, + }) + } } if outputFormatStr == "EXR_264_MP4" || outputFormatStr == "EXR_AV1_MP4" { // Check if a video generation task already exists for this job (any status) var existingVideoTask int - s.db.QueryRow( + s.db.With(func(conn *sql.DB) error { + conn.QueryRow( `SELECT COUNT(*) FROM tasks WHERE job_id = ? AND task_type = ?`, jobID, types.TaskTypeVideoGeneration, ).Scan(&existingVideoTask) + return nil + }) if existingVideoTask == 0 { // Create a video generation task instead of calling generateMP4Video directly // This prevents race conditions when multiple runners complete frames simultaneously videoTaskTimeout := 86400 // 24 hours for video generation var videoTaskID int64 - err := s.db.QueryRow( + err := s.db.With(func(conn *sql.DB) error { + result, err := conn.Exec( `INSERT INTO tasks (job_id, frame_start, frame_end, task_type, status, timeout_seconds, max_retries) - VALUES (?, ?, ?, ?, ?, ?, ?) 
- RETURNING id`, + VALUES (?, ?, ?, ?, ?, ?, ?)`, jobID, 0, 0, types.TaskTypeVideoGeneration, types.TaskStatusPending, videoTaskTimeout, 1, - ).Scan(&videoTaskID) + ) + if err != nil { + return err + } + videoTaskID, err = result.LastInsertId() + return err + }) if err != nil { log.Printf("Failed to create video generation task for job %d: %v", jobID, err) } else { // Broadcast that a new task was added - log.Printf("Broadcasting task_added for job %d: video generation task %d", jobID, videoTaskID) + if s.verboseWSLogging { + log.Printf("Broadcasting task_added for job %d: video generation task %d", jobID, videoTaskID) + } s.broadcastTaskUpdate(jobID, videoTaskID, "task_added", map[string]interface{}{ "task_id": videoTaskID, "task_type": types.TaskTypeVideoGeneration, @@ -1360,28 +1683,37 @@ func (s *Server) updateJobStatusFromTasks(jobID int64) { } else { // Job has pending or running tasks - determine if it's running or still pending var runningTasks int - s.db.QueryRow( + s.db.With(func(conn *sql.DB) error { + conn.QueryRow( `SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`, jobID, types.TaskStatusRunning, ).Scan(&runningTasks) + return nil + }) if runningTasks > 0 { // Has running tasks - job is running jobStatus = string(types.JobStatusRunning) var startedAt sql.NullTime - s.db.QueryRow(`SELECT started_at FROM jobs WHERE id = ?`, jobID).Scan(&startedAt) + s.db.With(func(conn *sql.DB) error { + conn.QueryRow(`SELECT started_at FROM jobs WHERE id = ?`, jobID).Scan(&startedAt) if !startedAt.Valid { - s.db.Exec(`UPDATE jobs SET started_at = ? WHERE id = ?`, now, jobID) + conn.Exec(`UPDATE jobs SET started_at = ? WHERE id = ?`, now, jobID) } + return nil + }) } else { // All tasks are pending - job is pending jobStatus = string(types.JobStatusPending) } - _, err := s.db.Exec( + err := s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( `UPDATE jobs SET status = ?, progress = ? 
WHERE id = ?`, jobStatus, progress, jobID, ) + return err + }) if err != nil { log.Printf("Failed to update job %d status to %s: %v", jobID, jobStatus, err) } else { @@ -1389,57 +1721,102 @@ func (s *Server) updateJobStatusFromTasks(jobID int64) { if currentStatus != jobStatus { log.Printf("Updated job %d status from %s to %s (progress: %.1f%%, completed: %d/%d, pending: %d, running: %d)", jobID, currentStatus, jobStatus, progress, completedTasks, totalTasks, pendingOrRunningTasks-runningTasks, runningTasks) } + // Broadcast job update during execution (not just on completion) + s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{ + "status": jobStatus, + "progress": progress, + }) } } } // broadcastLogToFrontend broadcasts log to connected frontend clients func (s *Server) broadcastLogToFrontend(taskID int64, logEntry WSLogEntry) { - // Get job_id from task - var jobID int64 - err := s.db.QueryRow("SELECT job_id FROM tasks WHERE id = ?", taskID).Scan(&jobID) + // Get job_id, user_id, and task status from task + var jobID, userID int64 + var taskStatus string + var taskRunnerID sql.NullInt64 + var taskStartedAt sql.NullTime + err := s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( + `SELECT t.job_id, j.user_id, t.status, t.runner_id, t.started_at + FROM tasks t + JOIN jobs j ON t.job_id = j.id + WHERE t.id = ?`, + taskID, + ).Scan(&jobID, &userID, &taskStatus, &taskRunnerID, &taskStartedAt) + }) if err != nil { return } + // Get full log entry from database for consistency + // Use a more reliable query that gets the most recent log with matching message + // This avoids race conditions with concurrent inserts + var taskLog types.TaskLog + var runnerID sql.NullInt64 + var stepName sql.NullString + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( + `SELECT id, task_id, runner_id, log_level, message, step_name, created_at + FROM task_logs WHERE task_id = ? AND message = ? 
ORDER BY id DESC LIMIT 1`, + taskID, logEntry.Message, + ).Scan(&taskLog.ID, &taskLog.TaskID, &runnerID, &taskLog.LogLevel, &taskLog.Message, &stepName, &taskLog.CreatedAt) + }) + if err != nil { + return + } + if runnerID.Valid { + taskLog.RunnerID = &runnerID.Int64 + } + if stepName.Valid { + taskLog.StepName = stepName.String + } + + msg := map[string]interface{}{ + "type": "log", + "task_id": taskID, + "job_id": jobID, + "data": taskLog, + "timestamp": time.Now().Unix(), + } + + // Broadcast to client WebSocket if subscribed to logs:{jobId}:{taskId} + channel := fmt.Sprintf("logs:%d:%d", jobID, taskID) + if s.verboseWSLogging { + runnerIDStr := "none" + if taskRunnerID.Valid { + runnerIDStr = fmt.Sprintf("%d", taskRunnerID.Int64) + } + log.Printf("broadcastLogToFrontend: Broadcasting log for task %d (job %d, user %d) on channel %s, log_id=%d, task_status=%s, runner_id=%s", taskID, jobID, userID, channel, taskLog.ID, taskStatus, runnerIDStr) + } + s.broadcastToClient(userID, channel, msg) + + // If task status is pending but logs are coming in, log a warning + // This indicates the initial assignment broadcast may have been missed or the database update failed + if taskStatus == string(types.TaskStatusPending) { + log.Printf("broadcastLogToFrontend: ERROR - Task %d has logs but status is 'pending'. 
This indicates the initial task assignment failed or the task_update broadcast was missed.", taskID) + } + + // Also broadcast to old WebSocket connection (for backwards compatibility during migration) key := fmt.Sprintf("%d:%d", jobID, taskID) s.frontendConnsMu.RLock() conn, exists := s.frontendConns[key] s.frontendConnsMu.RUnlock() if exists && conn != nil { - // Get full log entry from database for consistency - // Use a more reliable query that gets the most recent log with matching message - // This avoids race conditions with concurrent inserts - var log types.TaskLog - var runnerID sql.NullInt64 - err := s.db.QueryRow( - `SELECT id, task_id, runner_id, log_level, message, step_name, created_at - FROM task_logs WHERE task_id = ? AND message = ? ORDER BY id DESC LIMIT 1`, - taskID, logEntry.Message, - ).Scan(&log.ID, &log.TaskID, &runnerID, &log.LogLevel, &log.Message, &log.StepName, &log.CreatedAt) - if err == nil { - if runnerID.Valid { - log.RunnerID = &runnerID.Int64 - } - msg := map[string]interface{}{ - "type": "log", - "data": log, - "timestamp": time.Now().Unix(), - } - // Serialize writes to prevent concurrent write panics - s.frontendConnsWriteMuMu.RLock() - writeMu, hasMu := s.frontendConnsWriteMu[key] - s.frontendConnsWriteMuMu.RUnlock() + // Serialize writes to prevent concurrent write panics + s.frontendConnsWriteMuMu.RLock() + writeMu, hasMu := s.frontendConnsWriteMu[key] + s.frontendConnsWriteMuMu.RUnlock() - if hasMu && writeMu != nil { - writeMu.Lock() - conn.WriteJSON(msg) - writeMu.Unlock() - } else { - // Fallback if mutex doesn't exist yet (shouldn't happen, but be safe) - conn.WriteJSON(msg) - } + if hasMu && writeMu != nil { + writeMu.Lock() + conn.WriteJSON(msg) + writeMu.Unlock() + } else { + // Fallback if mutex doesn't exist yet (shouldn't happen, but be safe) + conn.WriteJSON(msg) } } } @@ -1447,8 +1824,10 @@ func (s *Server) broadcastLogToFrontend(taskID int64, logEntry WSLogEntry) { // triggerTaskDistribution triggers task 
distribution in a serialized manner func (s *Server) triggerTaskDistribution() { go func() { - // Try to acquire lock - if already running, skip + // Try to acquire lock - if already running, log and skip if !s.taskDistMu.TryLock() { + // Log when distribution is skipped to help with debugging + log.Printf("Task distribution already in progress, skipping trigger") return // Distribution already in progress } defer s.taskDistMu.Unlock() @@ -1461,12 +1840,14 @@ func (s *Server) triggerTaskDistribution() { func (s *Server) distributeTasksToRunners() { // Quick check: if there are no pending tasks, skip the expensive query var pendingCount int - err := s.db.QueryRow( + err := s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( `SELECT COUNT(*) FROM tasks t JOIN jobs j ON t.job_id = j.id WHERE t.status = ? AND j.status != ?`, types.TaskStatusPending, types.JobStatusCancelled, ).Scan(&pendingCount) + }) if err != nil { log.Printf("Failed to check pending tasks count: %v", err) return @@ -1477,7 +1858,10 @@ func (s *Server) distributeTasksToRunners() { } // Get all pending tasks - rows, err := s.db.Query( + var rows *sql.Rows + err = s.db.With(func(conn *sql.DB) error { + var err error + rows, err = conn.Query( `SELECT t.id, t.job_id, t.frame_start, t.frame_end, t.task_type, j.allow_parallel_runners, j.status as job_status, j.name as job_name, j.user_id FROM tasks t JOIN jobs j ON t.job_id = j.id @@ -1485,6 +1869,8 @@ func (s *Server) distributeTasksToRunners() { ORDER BY t.created_at ASC`, types.TaskStatusPending, types.JobStatusCancelled, ) + return err + }) if err != nil { log.Printf("Failed to query pending tasks: %v", err) return @@ -1556,23 +1942,29 @@ func (s *Server) distributeTasksToRunners() { runnerScopes := make(map[int64]string) for _, runnerID := range connectedRunners { var priority int - var capabilitiesJSON string + var capabilitiesJSON sql.NullString var scope string - err := s.db.QueryRow("SELECT priority, capabilities, api_key_scope FROM 
runners WHERE id = ?", runnerID).Scan(&priority, &capabilitiesJSON, &scope) + err := s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT priority, capabilities, api_key_scope FROM runners WHERE id = ?", runnerID).Scan(&priority, &capabilitiesJSON, &scope) + }) if err != nil { // Default to 100 if priority not found priority = 100 - capabilitiesJSON = "{}" + capabilitiesJSON = sql.NullString{String: "{}", Valid: true} } runnerPriorities[runnerID] = priority runnerScopes[runnerID] = scope // Parse capabilities JSON (can contain both bools and numbers) + capabilitiesStr := "{}" + if capabilitiesJSON.Valid { + capabilitiesStr = capabilitiesJSON.String + } var capabilities map[string]interface{} - if err := json.Unmarshal([]byte(capabilitiesJSON), &capabilities); err != nil { + if err := json.Unmarshal([]byte(capabilitiesStr), &capabilities); err != nil { // If parsing fails, try old format (map[string]bool) for backward compatibility var oldCapabilities map[string]bool - if err2 := json.Unmarshal([]byte(capabilitiesJSON), &oldCapabilities); err2 == nil { + if err2 := json.Unmarshal([]byte(capabilitiesStr), &oldCapabilities); err2 == nil { // Convert old format to new format capabilities = make(map[string]interface{}) for k, v := range oldCapabilities { @@ -1590,10 +1982,13 @@ func (s *Server) distributeTasksToRunners() { for _, runnerID := range connectedRunners { // Ensure database status matches WebSocket connection // Update status to online if it's not already - _, _ = s.db.Exec( + s.db.With(func(conn *sql.DB) error { + _, _ = conn.Exec( `UPDATE runners SET status = ?, last_heartbeat = ? WHERE id = ? 
AND status != ?`, types.RunnerStatusOnline, time.Now(), runnerID, types.RunnerStatusOnline, ) + return nil + }) } if len(connectedRunners) == 0 { @@ -1645,13 +2040,10 @@ func (s *Server) distributeTasksToRunners() { // Find available runner var selectedRunnerID int64 var bestRunnerID int64 - var bestCapabilityMatch int = -1 // 0 = only required, 1 = required + others, 2 = no match var bestPriority int = -1 var bestTaskCount int = -1 var bestRandom float64 = -1 // Random tie-breaker - isMetadataTask := task.TaskType == string(types.TaskTypeMetadata) - // Try to find the best runner for this task for _, runnerID := range connectedRunners { // Check if runner's API key scope allows working on this job @@ -1663,9 +2055,13 @@ func (s *Server) distributeTasksToRunners() { if runnerScope == "user" { // Get the user who created this runner's API key var apiKeyID sql.NullInt64 - err := s.db.QueryRow("SELECT api_key_id FROM runners WHERE id = ?", runnerID).Scan(&apiKeyID) + err := s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT api_key_id FROM runners WHERE id = ?", runnerID).Scan(&apiKeyID) + }) if err == nil && apiKeyID.Valid { - err = s.db.QueryRow("SELECT created_by FROM runner_api_keys WHERE id = ?", apiKeyID.Int64).Scan(&apiKeyCreatedBy) + s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT created_by FROM runner_api_keys WHERE id = ?", apiKeyID.Int64).Scan(&apiKeyCreatedBy) + }) if err != nil { continue // Skip this runner if we can't determine API key ownership } @@ -1694,110 +2090,43 @@ func (s *Server) distributeTasksToRunners() { continue // Runner doesn't have required capability } - // For video generation tasks, check GPU availability and ensure no blender tasks are running - if task.TaskType == string(types.TaskTypeVideoGeneration) { - // Check if runner has any blender/render tasks running (mutual exclusion) - var runningBlenderTasks int - s.db.QueryRow( - `SELECT COUNT(*) FROM tasks WHERE runner_id = ? AND status = ? 
AND task_type = ?`, - runnerID, types.TaskStatusRunning, types.TaskTypeRender, - ).Scan(&runningBlenderTasks) - - if runningBlenderTasks > 0 { - continue // Runner is busy with blender tasks, cannot run video tasks simultaneously - } - - // Get GPU count from capabilities - var gpuCount int - if videoGPUs, ok := capabilities["video_gpu_count"]; ok { - if count, ok := videoGPUs.(float64); ok { - gpuCount = int(count) - } else if count, ok := videoGPUs.(int); ok { - gpuCount = count - } - } - - // Count how many video generation tasks are currently running on this runner - var runningVideoTasks int - s.db.QueryRow( - `SELECT COUNT(*) FROM tasks WHERE runner_id = ? AND status = ? AND task_type = ?`, - runnerID, types.TaskStatusRunning, types.TaskTypeVideoGeneration, - ).Scan(&runningVideoTasks) - - // If all GPUs are in use, skip this runner - if gpuCount > 0 && runningVideoTasks >= gpuCount { - continue // All GPUs are busy - } + // Check if runner has ANY tasks (pending or running) - one task at a time only + // This prevents any runner from doing more than one task at a time + var activeTaskCount int + err := s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( + `SELECT COUNT(*) FROM tasks WHERE runner_id = ? AND status IN (?, ?)`, + runnerID, types.TaskStatusPending, types.TaskStatusRunning, + ).Scan(&activeTaskCount) + }) + if err != nil { + log.Printf("Failed to check active tasks for runner %d: %v", runnerID, err) + continue } - // For render/blender tasks, check if runner is busy and ensure no video tasks are running - if !isMetadataTask && task.TaskType != string(types.TaskTypeVideoGeneration) { - // Check if runner has any video generation tasks running (mutual exclusion) - var runningVideoTasks int - s.db.QueryRow( - `SELECT COUNT(*) FROM tasks WHERE runner_id = ? AND status = ? 
AND task_type = ?`, - runnerID, types.TaskStatusRunning, types.TaskTypeVideoGeneration, - ).Scan(&runningVideoTasks) - - if runningVideoTasks > 0 { - continue // Runner is busy with video tasks, cannot run blender tasks simultaneously - } - - // Check if runner is busy (has running render tasks) - only for non-metadata, non-video tasks - var runningCount int - s.db.QueryRow( - `SELECT COUNT(*) FROM tasks WHERE runner_id = ? AND status = ? AND task_type NOT IN (?, ?)`, - runnerID, types.TaskStatusRunning, types.TaskTypeMetadata, types.TaskTypeVideoGeneration, - ).Scan(&runningCount) - - if runningCount > 0 { - continue // Runner is busy with render tasks - } + if activeTaskCount > 0 { + continue // Runner is busy with another task, cannot run any other tasks } // For non-parallel jobs, check if runner already has tasks from this job if !task.AllowParallelRunners { var jobTaskCount int - s.db.QueryRow( + err := s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( `SELECT COUNT(*) FROM tasks WHERE job_id = ? AND runner_id = ? 
AND status IN (?, ?)`, task.JobID, runnerID, types.TaskStatusPending, types.TaskStatusRunning, ).Scan(&jobTaskCount) + }) + if err != nil { + log.Printf("Failed to check job tasks for runner %d: %v", runnerID, err) + continue + } if jobTaskCount > 0 { continue // Another runner is working on this job } } - // Determine capability match type - // Count how many capabilities the runner has - capabilityCount := 0 - hasBlender := false - hasFFmpeg := false - if blenderVal, ok := capabilities["blender"]; ok { - if b, ok := blenderVal.(bool); ok { - hasBlender = b - } - } - if ffmpegVal, ok := capabilities["ffmpeg"]; ok { - if f, ok := ffmpegVal.(bool); ok { - hasFFmpeg = f - } - } - if hasBlender { - capabilityCount++ - } - if hasFFmpeg { - capabilityCount++ - } - - // Determine match type: 0 = only required capability, 1 = required + others - var capabilityMatch int - if capabilityCount == 1 { - capabilityMatch = 0 // Only has the required capability - } else { - capabilityMatch = 1 // Has required + other capabilities - } - // Get runner priority and task count priority := runnerPriorities[runnerID] currentTaskCount := runnerTaskCounts[runnerID] @@ -1805,43 +2134,36 @@ func (s *Server) distributeTasksToRunners() { randomValue := rand.Float64() // Selection priority: - // 1. Capability match (0 = only required, 1 = required + others) - // 2. Priority (higher is better) - // 3. Task count (fewer is better) - // 4. Random value (absolute tie-breaker) + // 1. Priority (higher is better) + // 2. Task count (fewer is better) + // 3. 
Random value (absolute tie-breaker) isBetter := false if bestRunnerID == 0 { isBetter = true - } else if capabilityMatch < bestCapabilityMatch { - // Prefer runners with only the required capability + } else if priority > bestPriority { + // Higher priority isBetter = true - } else if capabilityMatch == bestCapabilityMatch { - if priority > bestPriority { - // Same capability match, but higher priority + } else if priority == bestPriority { + if currentTaskCount < bestTaskCount { + // Same priority, but fewer tasks assigned in this cycle isBetter = true - } else if priority == bestPriority { - if currentTaskCount < bestTaskCount { - // Same capability match and priority, but fewer tasks + } else if currentTaskCount == bestTaskCount { + // Absolute tie - use random value as tie-breaker + if randomValue > bestRandom { isBetter = true - } else if currentTaskCount == bestTaskCount { - // Absolute tie - use random value as tie-breaker - if randomValue > bestRandom { - isBetter = true - } } } } if isBetter { bestRunnerID = runnerID - bestCapabilityMatch = capabilityMatch bestPriority = priority bestTaskCount = currentTaskCount bestRandom = randomValue } } - // Use the best runner we found (prioritized by capability match, then priority, then load balanced) + // Use the best runner we found (prioritized by priority, then load balanced) if bestRunnerID != 0 { selectedRunnerID = bestRunnerID } @@ -1860,53 +2182,74 @@ func (s *Server) distributeTasksToRunners() { // Atomically assign task to runner using UPDATE with WHERE runner_id IS NULL // This prevents race conditions when multiple goroutines try to assign the same task - // Use a transaction to ensure atomicity and handle DuckDB's foreign key constraints + // Use a transaction to ensure atomicity now := time.Now() - tx, err := s.db.Begin() - if err != nil { - log.Printf("Failed to begin transaction for task %d: %v", task.TaskID, err) - continue - } - + var rowsAffected int64 + var verifyStatus string + var 
verifyRunnerID sql.NullInt64 + var verifyStartedAt sql.NullTime + err := s.db.WithTx(func(tx *sql.Tx) error { result, err := tx.Exec( `UPDATE tasks SET runner_id = ?, status = ?, started_at = ? WHERE id = ? AND runner_id IS NULL AND status = ?`, selectedRunnerID, types.TaskStatusRunning, now, task.TaskID, types.TaskStatusPending, ) if err != nil { - tx.Rollback() + return err + } + + // Check if the update actually affected a row (task was successfully assigned) + rowsAffected, err = result.RowsAffected() + if err != nil { + return err + } + + if rowsAffected == 0 { + return sql.ErrNoRows // Task was already assigned + } + + // Verify the update within the transaction before committing + // This ensures we catch any issues before the transaction is committed + err = tx.QueryRow( + `SELECT status, runner_id, started_at FROM tasks WHERE id = ?`, + task.TaskID, + ).Scan(&verifyStatus, &verifyRunnerID, &verifyStartedAt) + if err != nil { + return err + } + if verifyStatus != string(types.TaskStatusRunning) { + return fmt.Errorf("task status is %s after assignment, expected running", verifyStatus) + } + if !verifyRunnerID.Valid || verifyRunnerID.Int64 != selectedRunnerID { + return fmt.Errorf("task runner_id is %v after assignment, expected %d", verifyRunnerID, selectedRunnerID) + } + + return nil // Commit on nil return + }) + if err == sql.ErrNoRows { + // Task was already assigned by another goroutine, skip + continue + } + if err != nil { log.Printf("Failed to atomically assign task %d: %v", task.TaskID, err) continue } - // Check if the update actually affected a row (task was successfully assigned) - rowsAffected, err := result.RowsAffected() - if err != nil { - tx.Rollback() - log.Printf("Failed to get rows affected for task %d: %v", task.TaskID, err) - continue - } + log.Printf("Verified and committed task %d assignment: status=%s, runner_id=%d, started_at=%v", task.TaskID, verifyStatus, verifyRunnerID.Int64, verifyStartedAt) - if rowsAffected == 0 { - // Task 
was already assigned by another goroutine, skip - tx.Rollback() - continue - } - - // Commit the assignment before attempting WebSocket send - // If send fails, we'll rollback in a separate transaction - err = tx.Commit() - if err != nil { - log.Printf("Failed to commit transaction for task %d: %v", task.TaskID, err) - continue - } - - // Broadcast task assignment - s.broadcastTaskUpdate(task.JobID, task.TaskID, "task_update", map[string]interface{}{ + // Broadcast task assignment - include all fields to ensure frontend has complete info + updateData := map[string]interface{}{ "status": types.TaskStatusRunning, "runner_id": selectedRunnerID, - "started_at": now, - }) + "started_at": verifyStartedAt.Time, + } + if !verifyStartedAt.Valid { + updateData["started_at"] = now + } + if s.verboseWSLogging { + log.Printf("Broadcasting task_update for task %d (job %d, user %d): status=%s, runner_id=%d, started_at=%v", task.TaskID, task.JobID, task.JobUserID, types.TaskStatusRunning, selectedRunnerID, now) + } + s.broadcastTaskUpdate(task.JobID, task.TaskID, "task_update", updateData) // Task was successfully assigned in database, now send via WebSocket log.Printf("Assigned task %d (type: %s, job: %d) to runner %d", task.TaskID, task.TaskType, task.JobID, selectedRunnerID) @@ -1919,26 +2262,54 @@ func (s *Server) distributeTasksToRunners() { log.Printf("Failed to send task %d to runner %d: %v", task.TaskID, selectedRunnerID, err) // Log assignment failure s.logTaskEvent(task.TaskID, nil, types.LogLevelError, fmt.Sprintf("Failed to send task to runner %d: %v", selectedRunnerID, err), "") - // Rollback the assignment if WebSocket send fails using a new transaction - rollbackTx, rollbackErr := s.db.Begin() - if rollbackErr == nil { - _, rollbackErr = rollbackTx.Exec( + // Rollback the assignment if WebSocket send fails with retry mechanism + rollbackSuccess := false + for retry := 0; retry < 3; retry++ { + rollbackErr := s.db.WithTx(func(tx *sql.Tx) error { + _, err := tx.Exec( 
`UPDATE tasks SET runner_id = NULL, status = ?, started_at = NULL WHERE id = ? AND runner_id = ?`, types.TaskStatusPending, task.TaskID, selectedRunnerID, ) - if rollbackErr == nil { - rollbackTx.Commit() - // Log rollback - s.logTaskEvent(task.TaskID, nil, types.LogLevelWarn, fmt.Sprintf("Task assignment rolled back - runner %d connection failed", selectedRunnerID), "") - // Update job status after rollback - s.updateJobStatusFromTasks(task.JobID) - // Trigger redistribution + return err + }) + if rollbackErr != nil { + log.Printf("Failed to rollback task %d assignment (attempt %d/3): %v", task.TaskID, retry+1, rollbackErr) + if retry < 2 { + time.Sleep(time.Duration(retry+1) * 100 * time.Millisecond) // Exponential backoff + continue + } + // Final attempt failed + log.Printf("CRITICAL: Failed to rollback task %d after 3 attempts - task may be in inconsistent state", task.TaskID) s.triggerTaskDistribution() - } else { - rollbackTx.Rollback() - log.Printf("Failed to rollback task %d assignment: %v", task.TaskID, rollbackErr) + break } + // Rollback succeeded + rollbackSuccess = true + s.logTaskEvent(task.TaskID, nil, types.LogLevelWarn, fmt.Sprintf("Task assignment rolled back - runner %d connection failed", selectedRunnerID), "") + s.updateJobStatusFromTasks(task.JobID) + s.triggerTaskDistribution() + break + } + if !rollbackSuccess { + // Schedule background cleanup for inconsistent state + go func() { + time.Sleep(5 * time.Second) + // Retry rollback one more time in background + err := s.db.WithTx(func(tx *sql.Tx) error { + _, err := tx.Exec( + `UPDATE tasks SET runner_id = NULL, status = ?, started_at = NULL + WHERE id = ? AND runner_id = ? 
AND status = ?`, + types.TaskStatusPending, task.TaskID, selectedRunnerID, types.TaskStatusRunning, + ) + return err + }) + if err == nil { + log.Printf("Background cleanup: Successfully rolled back task %d", task.TaskID) + s.updateJobStatusFromTasks(task.JobID) + s.triggerTaskDistribution() + } + }() } } else { // WebSocket send succeeded, update job status @@ -1949,24 +2320,28 @@ func (s *Server) distributeTasksToRunners() { // assignTaskToRunner sends a task to a runner via WebSocket func (s *Server) assignTaskToRunner(runnerID int64, taskID int64) error { + // Hold read lock during entire operation to prevent connection from being replaced s.runnerConnsMu.RLock() conn, exists := s.runnerConns[runnerID] - s.runnerConnsMu.RUnlock() - if !exists { + s.runnerConnsMu.RUnlock() return fmt.Errorf("runner %d not connected", runnerID) } + // Keep lock held to prevent connection replacement during operation + defer s.runnerConnsMu.RUnlock() // Get task details var task WSTaskAssignment var jobName string var outputFormat sql.NullString var taskType string - err := s.db.QueryRow( + err := s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( `SELECT t.job_id, t.frame_start, t.frame_end, t.task_type, j.name, j.output_format FROM tasks t JOIN jobs j ON t.job_id = j.id WHERE t.id = ?`, taskID, ).Scan(&task.JobID, &task.FrameStart, &task.FrameEnd, &taskType, &jobName, &outputFormat) + }) if err != nil { return err } @@ -1982,10 +2357,15 @@ func (s *Server) assignTaskToRunner(runnerID int64, taskID int64) error { task.TaskType = taskType // Get input files - rows, err := s.db.Query( + var rows *sql.Rows + err = s.db.With(func(conn *sql.DB) error { + var err error + rows, err = conn.Query( `SELECT file_path FROM job_files WHERE job_id = ? 
AND file_type = ?`, task.JobID, types.JobFileTypeInput, ) + return err + }) if err == nil { defer rows.Close() for rows.Next() { @@ -2005,11 +2385,14 @@ func (s *Server) assignTaskToRunner(runnerID int64, taskID int64) error { log.Printf("ERROR: %s", errMsg) // Don't send the task - it will fail anyway // Rollback the assignment - s.db.Exec( + s.db.With(func(conn *sql.DB) error { + _, _ = conn.Exec( `UPDATE tasks SET runner_id = NULL, status = ?, started_at = NULL WHERE id = ?`, types.TaskStatusPending, taskID, ) + return nil + }) s.logTaskEvent(taskID, nil, types.LogLevelError, errMsg, "") return errors.New(errMsg) } @@ -2017,7 +2400,9 @@ func (s *Server) assignTaskToRunner(runnerID int64, taskID int64) error { // Note: Task is already assigned in database by the atomic update in distributeTasksToRunners // We just need to verify it's still assigned to this runner var assignedRunnerID sql.NullInt64 - err = s.db.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", taskID).Scan(&assignedRunnerID) + err = s.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT runner_id FROM tasks WHERE id = ?", taskID).Scan(&assignedRunnerID) + }) if err != nil { return fmt.Errorf("task not found: %w", err) } @@ -2041,14 +2426,8 @@ func (s *Server) assignTaskToRunner(runnerID int64, taskID int64) error { return fmt.Errorf("runner %d write mutex not found", runnerID) } - // Re-check connection is still valid before writing - s.runnerConnsMu.RLock() - _, stillExists := s.runnerConns[runnerID] - s.runnerConnsMu.RUnlock() - if !stillExists { - return fmt.Errorf("runner %d disconnected", runnerID) - } - + // Connection is still valid (we're holding the read lock) + // Write to connection with mutex protection writeMu.Lock() err = conn.WriteJSON(msg) writeMu.Unlock() @@ -2060,11 +2439,16 @@ func (s *Server) redistributeRunnerTasks(runnerID int64) { log.Printf("Starting task redistribution for disconnected runner %d", runnerID) // Get tasks assigned to this runner that are 
still running - taskRows, err := s.db.Query( + var taskRows *sql.Rows + err := s.db.With(func(conn *sql.DB) error { + var err error + taskRows, err = conn.Query( `SELECT id, retry_count, max_retries, job_id FROM tasks WHERE runner_id = ? AND status = ?`, runnerID, types.TaskStatusRunning, ) + return err + }) if err != nil { log.Printf("Failed to query tasks for runner %d: %v", runnerID, err) return @@ -2106,11 +2490,22 @@ func (s *Server) redistributeRunnerTasks(runnerID int64) { for _, task := range tasksToReset { if task.RetryCount >= task.MaxRetries { // Mark as failed - _, err = s.db.Exec( - `UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL, completed_at = ? - WHERE id = ? AND runner_id = ?`, - types.TaskStatusFailed, "Runner disconnected, max retries exceeded", time.Now(), task.ID, runnerID, - ) + err = s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec(`UPDATE tasks SET status = ? WHERE id = ? AND runner_id = ?`, types.TaskStatusFailed, task.ID, runnerID) + if err != nil { + return err + } + _, err = conn.Exec(`UPDATE tasks SET error_message = ? WHERE id = ? AND runner_id = ?`, "Runner disconnected, max retries exceeded", task.ID, runnerID) + if err != nil { + return err + } + _, err = conn.Exec(`UPDATE tasks SET runner_id = NULL WHERE id = ? AND runner_id = ?`, task.ID, runnerID) + if err != nil { + return err + } + _, err = conn.Exec(`UPDATE tasks SET completed_at = ? WHERE id = ?`, time.Now(), task.ID) + return err + }) if err != nil { log.Printf("Failed to mark task %d as failed: %v", task.ID, err) } else { @@ -2121,11 +2516,14 @@ func (s *Server) redistributeRunnerTasks(runnerID int64) { } } else { // Reset to pending so it can be redistributed - _, err = s.db.Exec( + err = s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( `UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL, retry_count = retry_count + 1, started_at = NULL WHERE id = ? 
AND runner_id = ?`, types.TaskStatusPending, task.ID, runnerID, ) + return err + }) if err != nil { log.Printf("Failed to reset task %d: %v", task.ID, err) } else { @@ -2164,11 +2562,14 @@ func (s *Server) logTaskEvent(taskID int64, runnerID *int64, logLevel types.LogL runnerIDValue = *runnerID } - _, err := s.db.Exec( + err := s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( `INSERT INTO task_logs (task_id, runner_id, log_level, message, step_name, created_at) VALUES (?, ?, ?, ?, ?, ?)`, taskID, runnerIDValue, logLevel, message, stepName, time.Now(), ) + return err + }) if err != nil { log.Printf("Failed to log task event for task %d: %v", taskID, err) return @@ -2182,3 +2583,120 @@ func (s *Server) logTaskEvent(taskID int64, runnerID *int64, logLevel types.LogL StepName: stepName, }) } + +// cleanupOldOfflineRunners periodically deletes runners that have been offline for more than 1 month +func (s *Server) cleanupOldOfflineRunners() { + // Run cleanup every 24 hours + ticker := time.NewTicker(24 * time.Hour) + defer ticker.Stop() + + // Run once immediately on startup + s.cleanupOldOfflineRunnersOnce() + + for range ticker.C { + s.cleanupOldOfflineRunnersOnce() + } +} + +// cleanupOldOfflineRunnersOnce finds and deletes runners that have been offline for more than 1 month +func (s *Server) cleanupOldOfflineRunnersOnce() { + defer func() { + if r := recover(); r != nil { + log.Printf("Panic in cleanupOldOfflineRunners: %v", r) + } + }() + + // Find runners that: + // 1. Are offline + // 2. Haven't had a heartbeat in over 1 month + // 3. Are not currently connected via WebSocket + var rows *sql.Rows + err := s.db.With(func(conn *sql.DB) error { + var err error + rows, err = conn.Query( + `SELECT id, name FROM runners + WHERE status = ? 
+ AND last_heartbeat < datetime('now', '-1 month')`, + types.RunnerStatusOffline, + ) + return err + }) + if err != nil { + log.Printf("Failed to query old offline runners: %v", err) + return + } + defer rows.Close() + + type runnerInfo struct { + ID int64 + Name string + } + var runnersToDelete []runnerInfo + + s.runnerConnsMu.RLock() + for rows.Next() { + var info runnerInfo + if err := rows.Scan(&info.ID, &info.Name); err == nil { + // Double-check runner is not connected via WebSocket + if _, connected := s.runnerConns[info.ID]; !connected { + runnersToDelete = append(runnersToDelete, info) + } + } + } + s.runnerConnsMu.RUnlock() + rows.Close() + + if len(runnersToDelete) == 0 { + return + } + + log.Printf("Cleaning up %d old offline runners (offline for more than 1 month)", len(runnersToDelete)) + + // Delete each runner + for _, runner := range runnersToDelete { + // First, check if there are any tasks still assigned to this runner + // If so, reset them to pending before deleting the runner + var assignedTaskCount int + err := s.db.With(func(conn *sql.DB) error { + return conn.QueryRow( + `SELECT COUNT(*) FROM tasks WHERE runner_id = ? AND status IN (?, ?)`, + runner.ID, types.TaskStatusRunning, types.TaskStatusPending, + ).Scan(&assignedTaskCount) + }) + if err != nil { + log.Printf("Failed to check assigned tasks for runner %d: %v", runner.ID, err) + continue + } + + if assignedTaskCount > 0 { + // Reset any tasks assigned to this runner + log.Printf("Resetting %d tasks assigned to runner %d before deletion", assignedTaskCount, runner.ID) + err = s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( + `UPDATE tasks SET runner_id = NULL, status = ? WHERE runner_id = ? 
AND status IN (?, ?)`, + types.TaskStatusPending, runner.ID, types.TaskStatusRunning, types.TaskStatusPending, + ) + return err + }) + if err != nil { + log.Printf("Failed to reset tasks for runner %d: %v", runner.ID, err) + continue + } + } + + // Delete the runner + err = s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec("DELETE FROM runners WHERE id = ?", runner.ID) + return err + }) + if err != nil { + log.Printf("Failed to delete runner %d (%s): %v", runner.ID, runner.Name, err) + continue + } + + log.Printf("Deleted old offline runner: %d (%s)", runner.ID, runner.Name) + } + + // Trigger task distribution if any tasks were reset + s.triggerTaskDistribution() +} diff --git a/internal/api/server.go b/internal/api/server.go index f2965fd..06639cf 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -10,15 +10,18 @@ import ( "net/http" "os" "path/filepath" + "runtime" "strconv" "strings" "sync" "time" authpkg "jiggablend/internal/auth" + "jiggablend/internal/config" "jiggablend/internal/database" "jiggablend/internal/storage" "jiggablend/pkg/types" + "jiggablend/web" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" @@ -26,79 +29,299 @@ import ( "github.com/gorilla/websocket" ) +// Configuration constants +const ( + // WebSocket timeouts + WSReadDeadline = 90 * time.Second + WSPingInterval = 30 * time.Second + WSWriteDeadline = 10 * time.Second + + // Task timeouts + DefaultTaskTimeout = 300 // 5 minutes for frame rendering + VideoGenerationTimeout = 86400 // 24 hours for video generation + DefaultJobTimeout = 86400 // 24 hours + + // Limits + MaxFrameRange = 10000 + MaxUploadSize = 50 << 30 // 50 GB + RunnerHeartbeatTimeout = 90 * time.Second + TaskDistributionInterval = 10 * time.Second + ProgressUpdateThrottle = 2 * time.Second + + // Cookie settings + SessionCookieMaxAge = 86400 // 24 hours +) + // Server represents the API server type Server struct { db *database.DB + cfg *config.Config auth *authpkg.Auth secrets 
*authpkg.Secrets storage *storage.Storage router *chi.Mux // WebSocket connections - wsUpgrader websocket.Upgrader - runnerConns map[int64]*websocket.Conn - runnerConnsMu sync.RWMutex + wsUpgrader websocket.Upgrader + runnerConns map[int64]*websocket.Conn + runnerConnsMu sync.RWMutex // Mutexes for each runner connection to serialize writes runnerConnsWriteMu map[int64]*sync.Mutex runnerConnsWriteMuMu sync.RWMutex - frontendConns map[string]*websocket.Conn // key: "jobId:taskId" - frontendConnsMu sync.RWMutex - // Mutexes for each frontend connection to serialize writes - frontendConnsWriteMu map[string]*sync.Mutex // key: "jobId:taskId" + + // DEPRECATED: Old WebSocket connection maps (kept for backwards compatibility) + // These will be removed in a future release. Use clientConns instead. + frontendConns map[string]*websocket.Conn // key: "jobId:taskId" + frontendConnsMu sync.RWMutex + frontendConnsWriteMu map[string]*sync.Mutex frontendConnsWriteMuMu sync.RWMutex - // Job list WebSocket connections (key: userID) - jobListConns map[int64]*websocket.Conn - jobListConnsMu sync.RWMutex - // Single job WebSocket connections (key: "userId:jobId") - jobConns map[string]*websocket.Conn - jobConnsMu sync.RWMutex - // Mutexes for job WebSocket connections - jobConnsWriteMu map[string]*sync.Mutex - jobConnsWriteMuMu sync.RWMutex + jobListConns map[int64]*websocket.Conn + jobListConnsMu sync.RWMutex + jobConns map[string]*websocket.Conn + jobConnsMu sync.RWMutex + jobConnsWriteMu map[string]*sync.Mutex + jobConnsWriteMuMu sync.RWMutex + // Throttling for progress updates (per job) progressUpdateTimes map[int64]time.Time // key: jobID progressUpdateTimesMu sync.RWMutex + // Throttling for task status updates (per task) + taskUpdateTimes map[int64]time.Time // key: taskID + taskUpdateTimesMu sync.RWMutex // Task distribution serialization taskDistMu sync.Mutex // Mutex to prevent concurrent distribution + + // Client WebSocket connections (new unified WebSocket) + 
clientConns map[int64]*ClientConnection // key: userID + clientConnsMu sync.RWMutex + + // Upload session tracking + uploadSessions map[string]*UploadSession // sessionId -> session info + uploadSessionsMu sync.RWMutex + + // Verbose WebSocket logging (set to true to enable detailed WebSocket logs) + verboseWSLogging bool + + // Server start time for health checks + startTime time.Time +} + +// ClientConnection represents a client WebSocket connection with subscriptions +type ClientConnection struct { + Conn *websocket.Conn + UserID int64 + IsAdmin bool + Subscriptions map[string]bool // channel -> subscribed + SubsMu sync.RWMutex // Protects Subscriptions map + WriteMu *sync.Mutex +} + +// UploadSession tracks upload and processing progress +type UploadSession struct { + SessionID string + UserID int64 + Progress float64 + Status string // "uploading", "processing", "extracting_metadata", "creating_context", "completed", "error" + Message string + CreatedAt time.Time } // NewServer creates a new API server -func NewServer(db *database.DB, auth *authpkg.Auth, storage *storage.Storage) (*Server, error) { - secrets, err := authpkg.NewSecrets(db.DB) +func NewServer(db *database.DB, cfg *config.Config, auth *authpkg.Auth, storage *storage.Storage) (*Server, error) { + secrets, err := authpkg.NewSecrets(db, cfg) if err != nil { return nil, fmt.Errorf("failed to initialize secrets: %w", err) } s := &Server{ - db: db, - auth: auth, - secrets: secrets, - storage: storage, - router: chi.NewRouter(), + db: db, + cfg: cfg, + auth: auth, + secrets: secrets, + storage: storage, + router: chi.NewRouter(), + startTime: time.Now(), wsUpgrader: websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true // Allow all origins for now - }, + CheckOrigin: checkWebSocketOrigin, ReadBufferSize: 1024, WriteBufferSize: 1024, }, - runnerConns: make(map[int64]*websocket.Conn), - runnerConnsWriteMu: make(map[int64]*sync.Mutex), + runnerConns: make(map[int64]*websocket.Conn), 
+ runnerConnsWriteMu: make(map[int64]*sync.Mutex), + // DEPRECATED: Initialize old WebSocket maps for backward compatibility frontendConns: make(map[string]*websocket.Conn), frontendConnsWriteMu: make(map[string]*sync.Mutex), jobListConns: make(map[int64]*websocket.Conn), jobConns: make(map[string]*websocket.Conn), jobConnsWriteMu: make(map[string]*sync.Mutex), progressUpdateTimes: make(map[int64]time.Time), + taskUpdateTimes: make(map[int64]time.Time), + clientConns: make(map[int64]*ClientConnection), + uploadSessions: make(map[string]*UploadSession), } s.setupMiddleware() s.setupRoutes() s.StartBackgroundTasks() + // On startup, check for runners that are marked online but not actually connected + // This handles the case where the manager restarted and lost track of connections + go s.recoverRunnersOnStartup() + return s, nil } +// checkWebSocketOrigin validates WebSocket connection origins +// In production mode, only allows same-origin connections or configured allowed origins +func checkWebSocketOrigin(r *http.Request) bool { + origin := r.Header.Get("Origin") + if origin == "" { + // No origin header - allow (could be non-browser client like runner) + return true + } + + // In development mode, allow all origins + // Note: This function doesn't have access to Server, so we use authpkg.IsProductionMode() + // which checks environment variable. The server setup uses s.cfg.IsProductionMode() for consistency. 
+ if !authpkg.IsProductionMode() { + return true + } + + // In production, check against allowed origins + allowedOrigins := os.Getenv("ALLOWED_ORIGINS") + if allowedOrigins == "" { + // Default to same-origin only + host := r.Host + return strings.HasSuffix(origin, "://"+host) || strings.HasSuffix(origin, "://"+strings.Split(host, ":")[0]) + } + + // Check against configured allowed origins + for _, allowed := range strings.Split(allowedOrigins, ",") { + allowed = strings.TrimSpace(allowed) + if allowed == "*" { + return true + } + if origin == allowed { + return true + } + } + + log.Printf("WebSocket origin rejected: %s (allowed: %s)", origin, allowedOrigins) + return false +} + +// RateLimiter provides simple in-memory rate limiting per IP +type RateLimiter struct { + requests map[string][]time.Time + mu sync.RWMutex + limit int // max requests + window time.Duration // time window +} + +// NewRateLimiter creates a new rate limiter +func NewRateLimiter(limit int, window time.Duration) *RateLimiter { + rl := &RateLimiter{ + requests: make(map[string][]time.Time), + limit: limit, + window: window, + } + // Start cleanup goroutine + go rl.cleanup() + return rl +} + +// Allow checks if a request from the given IP is allowed +func (rl *RateLimiter) Allow(ip string) bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + cutoff := now.Add(-rl.window) + + // Get existing requests and filter old ones + reqs := rl.requests[ip] + validReqs := make([]time.Time, 0, len(reqs)) + for _, t := range reqs { + if t.After(cutoff) { + validReqs = append(validReqs, t) + } + } + + // Check if under limit + if len(validReqs) >= rl.limit { + rl.requests[ip] = validReqs + return false + } + + // Add this request + validReqs = append(validReqs, now) + rl.requests[ip] = validReqs + return true +} + +// cleanup periodically removes old entries +func (rl *RateLimiter) cleanup() { + ticker := time.NewTicker(5 * time.Minute) + for range ticker.C { + rl.mu.Lock() + cutoff := 
time.Now().Add(-rl.window) + for ip, reqs := range rl.requests { + validReqs := make([]time.Time, 0, len(reqs)) + for _, t := range reqs { + if t.After(cutoff) { + validReqs = append(validReqs, t) + } + } + if len(validReqs) == 0 { + delete(rl.requests, ip) + } else { + rl.requests[ip] = validReqs + } + } + rl.mu.Unlock() + } +} + +// Global rate limiters for different endpoint types +var ( + // General API rate limiter: 100 requests per minute per IP + apiRateLimiter = NewRateLimiter(100, time.Minute) + // Auth rate limiter: 10 requests per minute per IP (stricter for login attempts) + authRateLimiter = NewRateLimiter(10, time.Minute) +) + +// rateLimitMiddleware applies rate limiting based on client IP +func rateLimitMiddleware(limiter *RateLimiter) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Get client IP (handle proxied requests) + ip := r.RemoteAddr + if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" { + // Take the first IP in the chain + if idx := strings.Index(forwarded, ","); idx != -1 { + ip = strings.TrimSpace(forwarded[:idx]) + } else { + ip = strings.TrimSpace(forwarded) + } + } else if realIP := r.Header.Get("X-Real-IP"); realIP != "" { + ip = strings.TrimSpace(realIP) + } + + if !limiter.Allow(ip) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Retry-After", "60") + w.WriteHeader(http.StatusTooManyRequests) + json.NewEncoder(w).Encode(map[string]string{ + "error": "Rate limit exceeded. 
Please try again later.", + }) + return + } + + next.ServeHTTP(w, r) + }) + } +} + // setupMiddleware configures middleware func (s *Server) setupMiddleware() { s.router.Use(middleware.Logger) @@ -106,17 +329,47 @@ func (s *Server) setupMiddleware() { // Note: Timeout middleware is NOT applied globally to avoid conflicts with WebSocket connections // WebSocket connections are long-lived and should not have HTTP timeouts + // Check production mode from config + isProduction := s.cfg.IsProductionMode() + + // Add rate limiting (applied in production mode only, or when explicitly enabled) + if isProduction || os.Getenv("ENABLE_RATE_LIMITING") == "true" { + s.router.Use(rateLimitMiddleware(apiRateLimiter)) + log.Printf("Rate limiting enabled: 100 requests/minute per IP") + } + // Add gzip compression for JSON responses s.router.Use(gzipMiddleware) - s.router.Use(cors.Handler(cors.Options{ - AllowedOrigins: []string{"*"}, - AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, + // Configure CORS based on environment + corsOptions := cors.Options{ + AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"}, AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "Range", "If-None-Match"}, ExposedHeaders: []string{"Link", "Content-Range", "Accept-Ranges", "Content-Length", "ETag"}, AllowCredentials: true, MaxAge: 300, - })) + } + + // In production, restrict CORS origins + if isProduction { + allowedOrigins := s.cfg.AllowedOrigins() + if allowedOrigins != "" { + corsOptions.AllowedOrigins = strings.Split(allowedOrigins, ",") + for i := range corsOptions.AllowedOrigins { + corsOptions.AllowedOrigins[i] = strings.TrimSpace(corsOptions.AllowedOrigins[i]) + } + } else { + // Default to no origins in production if not configured + // This effectively disables cross-origin requests + corsOptions.AllowedOrigins = []string{} + } + log.Printf("Production mode: CORS restricted to origins: %v", corsOptions.AllowedOrigins) + } else { + 
// Development mode: allow all origins + corsOptions.AllowedOrigins = []string{"*"} + } + + s.router.Use(cors.Handler(corsOptions)) } // gzipMiddleware compresses responses with gzip if client supports it @@ -164,8 +417,15 @@ func (w *gzipResponseWriter) WriteHeader(statusCode int) { // setupRoutes configures routes func (s *Server) setupRoutes() { - // Public routes + // Health check endpoint (unauthenticated) + s.router.Get("/api/health", s.handleHealthCheck) + + // Public routes (with stricter rate limiting for auth endpoints) s.router.Route("/api/auth", func(r chi.Router) { + // Apply stricter rate limiting to auth endpoints in production + if s.cfg.IsProductionMode() || os.Getenv("ENABLE_RATE_LIMITING") == "true" { + r.Use(rateLimitMiddleware(authRateLimiter)) + } r.Get("/providers", s.handleGetAuthProviders) r.Get("/google/login", s.handleGoogleLogin) r.Get("/google/callback", s.handleGoogleCallback) @@ -203,16 +463,10 @@ func (s *Server) setupRoutes() { r.Get("/{id}/tasks/summary", s.handleListJobTasksSummary) r.Post("/{id}/tasks/batch", s.handleBatchGetTasks) r.Get("/{id}/tasks/{taskId}/logs", s.handleGetTaskLogs) - // WebSocket route - no timeout middleware (long-lived connection) - r.With(func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Remove timeout middleware for WebSocket - next.ServeHTTP(w, r) - }) - }).Get("/{id}/tasks/{taskId}/logs/ws", s.handleStreamTaskLogsWebSocket) + // Old WebSocket route removed - use client WebSocket with subscriptions instead r.Get("/{id}/tasks/{taskId}/steps", s.handleGetTaskSteps) r.Post("/{id}/tasks/{taskId}/retry", s.handleRetryTask) - // WebSocket routes for real-time updates + // WebSocket route for unified client WebSocket r.With(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Apply authentication middleware first @@ -221,16 +475,7 @@ func (s *Server) setupRoutes() { next.ServeHTTP(w, 
r) })(w, r) }) - }).Get("/ws", s.handleJobsWebSocket) - r.With(func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Apply authentication middleware first - s.auth.Middleware(func(w http.ResponseWriter, r *http.Request) { - // Remove timeout middleware for WebSocket - next.ServeHTTP(w, r) - })(w, r) - }) - }).Get("/{id}/ws", s.handleJobWebSocket) + }).Get("/ws", s.handleClientWebSocket) }) // Admin routes @@ -286,8 +531,8 @@ func (s *Server) setupRoutes() { }) }) - // Serve static files (built React app) - s.router.Handle("/*", http.FileServer(http.Dir("./web/dist"))) + // Serve static files (embedded React app with SPA fallback) + s.router.Handle("/*", web.SPAHandler()) } // ServeHTTP implements http.Handler @@ -308,6 +553,76 @@ func (s *Server) respondError(w http.ResponseWriter, status int, message string) s.respondJSON(w, status, map[string]string{"error": message}) } +// createSessionCookie creates a secure session cookie with appropriate flags for the environment +func createSessionCookie(sessionID string) *http.Cookie { + cookie := &http.Cookie{ + Name: "session_id", + Value: sessionID, + Path: "/", + MaxAge: SessionCookieMaxAge, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + } + + // In production mode, set Secure flag to require HTTPS + if authpkg.IsProductionMode() { + cookie.Secure = true + } + + return cookie +} + +// handleHealthCheck returns server health status +func (s *Server) handleHealthCheck(w http.ResponseWriter, r *http.Request) { + // Check database connectivity + dbHealthy := true + if err := s.db.Ping(); err != nil { + dbHealthy = false + log.Printf("Health check: database ping failed: %v", err) + } + + // Count connected runners + s.runnerConnsMu.RLock() + runnerCount := len(s.runnerConns) + s.runnerConnsMu.RUnlock() + + // Count connected clients + s.clientConnsMu.RLock() + clientCount := len(s.clientConns) + s.clientConnsMu.RUnlock() + + // Calculate uptime + uptime := 
time.Since(s.startTime) + + // Get memory stats + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + + status := "healthy" + statusCode := http.StatusOK + if !dbHealthy { + status = "degraded" + statusCode = http.StatusServiceUnavailable + } + + response := map[string]interface{}{ + "status": status, + "uptime_seconds": int64(uptime.Seconds()), + "database": dbHealthy, + "connected_runners": runnerCount, + "connected_clients": clientCount, + "memory": map[string]interface{}{ + "alloc_mb": memStats.Alloc / 1024 / 1024, + "total_alloc_mb": memStats.TotalAlloc / 1024 / 1024, + "sys_mb": memStats.Sys / 1024 / 1024, + "num_gc": memStats.NumGC, + }, + "timestamp": time.Now().Unix(), + } + + s.respondJSON(w, statusCode, response) +} + // Auth handlers func (s *Server) handleGoogleLogin(w http.ResponseWriter, r *http.Request) { url, err := s.auth.GoogleLoginURL() @@ -337,14 +652,7 @@ func (s *Server) handleGoogleCallback(w http.ResponseWriter, r *http.Request) { } sessionID := s.auth.CreateSession(session) - http.SetCookie(w, &http.Cookie{ - Name: "session_id", - Value: sessionID, - Path: "/", - MaxAge: 86400, - HttpOnly: true, - SameSite: http.SameSiteLaxMode, - }) + http.SetCookie(w, createSessionCookie(sessionID)) http.Redirect(w, r, "/", http.StatusFound) } @@ -377,14 +685,7 @@ func (s *Server) handleDiscordCallback(w http.ResponseWriter, r *http.Request) { } sessionID := s.auth.CreateSession(session) - http.SetCookie(w, &http.Cookie{ - Name: "session_id", - Value: sessionID, - Path: "/", - MaxAge: 86400, - HttpOnly: true, - SameSite: http.SameSiteLaxMode, - }) + http.SetCookie(w, createSessionCookie(sessionID)) http.Redirect(w, r, "/", http.StatusFound) } @@ -394,13 +695,20 @@ func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) { if err == nil { s.auth.DeleteSession(cookie.Value) } - http.SetCookie(w, &http.Cookie{ + // Create an expired cookie to clear the session + expiredCookie := &http.Cookie{ Name: "session_id", Value: "", Path: 
"/", MaxAge: -1, HttpOnly: true, - }) + SameSite: http.SameSiteLaxMode, + } + // Use s.cfg.IsProductionMode() for consistency with other server methods + if s.cfg.IsProductionMode() { + expiredCookie.Secure = true + } + http.SetCookie(w, expiredCookie) s.respondJSON(w, http.StatusOK, map[string]string{"message": "Logged out"}) } @@ -470,14 +778,7 @@ func (s *Server) handleLocalRegister(w http.ResponseWriter, r *http.Request) { } sessionID := s.auth.CreateSession(session) - http.SetCookie(w, &http.Cookie{ - Name: "session_id", - Value: sessionID, - Path: "/", - MaxAge: 86400, - HttpOnly: true, - SameSite: http.SameSiteLaxMode, - }) + http.SetCookie(w, createSessionCookie(sessionID)) s.respondJSON(w, http.StatusCreated, map[string]interface{}{ "message": "Registration successful", @@ -514,14 +815,7 @@ func (s *Server) handleLocalLogin(w http.ResponseWriter, r *http.Request) { } sessionID := s.auth.CreateSession(session) - http.SetCookie(w, &http.Cookie{ - Name: "session_id", - Value: sessionID, - Path: "/", - MaxAge: 86400, - HttpOnly: true, - SameSite: http.SameSiteLaxMode, - }) + http.SetCookie(w, createSessionCookie(sessionID)) s.respondJSON(w, http.StatusOK, map[string]interface{}{ "message": "Login successful", @@ -612,6 +906,87 @@ func (s *Server) StartBackgroundTasks() { go s.recoverStuckTasks() go s.cleanupOldRenderJobs() go s.cleanupOldTempDirectories() + go s.cleanupOldOfflineRunners() + go s.cleanupOldUploadSessions() +} + +// recoverRunnersOnStartup checks for runners marked as online but not actually connected +// This runs once on startup to handle manager restarts where we lose track of connections +func (s *Server) recoverRunnersOnStartup() { + // Wait a short time for runners to reconnect after manager restart + // This gives runners a chance to reconnect before we mark them as dead + time.Sleep(5 * time.Second) + + log.Printf("Recovering runners on startup: checking for disconnected runners...") + + var onlineRunnerIDs []int64 + err := 
s.db.With(func(conn *sql.DB) error { + rows, err := conn.Query( + `SELECT id FROM runners WHERE status = ?`, + types.RunnerStatusOnline, + ) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var runnerID int64 + if err := rows.Scan(&runnerID); err == nil { + onlineRunnerIDs = append(onlineRunnerIDs, runnerID) + } + } + return nil + }) + + if err != nil { + log.Printf("Failed to query online runners on startup: %v", err) + return + } + + if len(onlineRunnerIDs) == 0 { + log.Printf("No runners marked as online on startup") + return + } + + log.Printf("Found %d runners marked as online, checking actual connections...", len(onlineRunnerIDs)) + + // Check which runners are actually connected + s.runnerConnsMu.RLock() + deadRunnerIDs := make([]int64, 0) + for _, runnerID := range onlineRunnerIDs { + if _, connected := s.runnerConns[runnerID]; !connected { + deadRunnerIDs = append(deadRunnerIDs, runnerID) + } + } + s.runnerConnsMu.RUnlock() + + if len(deadRunnerIDs) == 0 { + log.Printf("All runners marked as online are actually connected") + return + } + + log.Printf("Found %d runners marked as online but not connected, redistributing their tasks...", len(deadRunnerIDs)) + + // Redistribute tasks for disconnected runners + for _, runnerID := range deadRunnerIDs { + log.Printf("Recovering runner %d: redistributing tasks and marking as offline", runnerID) + s.redistributeRunnerTasks(runnerID) + + // Mark runner as offline + s.db.With(func(conn *sql.DB) error { + _, _ = conn.Exec( + `UPDATE runners SET status = ?, last_heartbeat = ? 
WHERE id = ?`, + types.RunnerStatusOffline, time.Now(), runnerID, + ) + return nil + }) + } + + log.Printf("Startup recovery complete: redistributed tasks from %d disconnected runners", len(deadRunnerIDs)) + + // Trigger task distribution to assign recovered tasks to available runners + s.triggerTaskDistribution() } // recoverStuckTasks periodically checks for dead runners and stuck tasks @@ -639,34 +1014,53 @@ func (s *Server) recoverStuckTasks() { // Find dead runners (no heartbeat for 90 seconds) // But only mark as dead if they're not actually connected via WebSocket - rows, err := s.db.Query( - `SELECT id FROM runners - WHERE last_heartbeat < CURRENT_TIMESTAMP - INTERVAL '90 seconds' + var deadRunnerIDs []int64 + var stillConnectedIDs []int64 + err := s.db.With(func(conn *sql.DB) error { + rows, err := conn.Query( + `SELECT id FROM runners + WHERE last_heartbeat < datetime('now', '-90 seconds') AND status = ?`, - types.RunnerStatusOnline, - ) + types.RunnerStatusOnline, + ) + if err != nil { + return err + } + defer rows.Close() + + s.runnerConnsMu.RLock() + for rows.Next() { + var runnerID int64 + if err := rows.Scan(&runnerID); err == nil { + // Only mark as dead if not actually connected via WebSocket + // The WebSocket connection is the source of truth + if _, stillConnected := s.runnerConns[runnerID]; !stillConnected { + deadRunnerIDs = append(deadRunnerIDs, runnerID) + } else { + // Runner is still connected but heartbeat is stale - update it + stillConnectedIDs = append(stillConnectedIDs, runnerID) + } + } + } + s.runnerConnsMu.RUnlock() + return nil + }) if err != nil { log.Printf("Failed to query dead runners: %v", err) return } - defer rows.Close() - var deadRunnerIDs []int64 - s.runnerConnsMu.RLock() - for rows.Next() { - var runnerID int64 - if err := rows.Scan(&runnerID); err == nil { - // Only mark as dead if not actually connected via WebSocket - // The WebSocket connection is the source of truth - if _, stillConnected := 
s.runnerConns[runnerID]; !stillConnected { - deadRunnerIDs = append(deadRunnerIDs, runnerID) - } - // If still connected, heartbeat should be updated by pong handler or heartbeat message - // No need to manually update here - if it's stale, the pong handler isn't working - } + // Update heartbeat for runners that are still connected but have stale heartbeats + // This ensures the database stays in sync with actual connection state + for _, runnerID := range stillConnectedIDs { + s.db.With(func(conn *sql.DB) error { + _, _ = conn.Exec( + `UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`, + time.Now(), types.RunnerStatusOnline, runnerID, + ) + return nil + }) } - s.runnerConnsMu.RUnlock() - rows.Close() if len(deadRunnerIDs) == 0 { // Check for task timeouts @@ -679,10 +1073,13 @@ func (s *Server) recoverStuckTasks() { s.redistributeRunnerTasks(runnerID) // Mark runner as offline - _, _ = s.db.Exec( - `UPDATE runners SET status = ? WHERE id = ?`, - types.RunnerStatusOffline, runnerID, - ) + s.db.With(func(conn *sql.DB) error { + _, _ = conn.Exec( + `UPDATE runners SET status = ? WHERE id = ?`, + types.RunnerStatusOffline, runnerID, + ) + return nil + }) } // Check for task timeouts @@ -697,33 +1094,59 @@ func (s *Server) recoverStuckTasks() { // recoverTaskTimeouts handles tasks that have exceeded their timeout func (s *Server) recoverTaskTimeouts() { // Find tasks running longer than their timeout - rows, err := s.db.Query( - `SELECT t.id, t.runner_id, t.retry_count, t.max_retries, t.timeout_seconds, t.started_at + var tasks []struct { + taskID int64 + runnerID sql.NullInt64 + retryCount int + maxRetries int + timeoutSeconds sql.NullInt64 + startedAt time.Time + } + + err := s.db.With(func(conn *sql.DB) error { + rows, err := conn.Query( + `SELECT t.id, t.runner_id, t.retry_count, t.max_retries, t.timeout_seconds, t.started_at FROM tasks t WHERE t.status = ? 
AND t.started_at IS NOT NULL AND (t.timeout_seconds IS NULL OR - t.started_at + INTERVAL (t.timeout_seconds || ' seconds') < CURRENT_TIMESTAMP)`, - types.TaskStatusRunning, - ) + (julianday('now') - julianday(t.started_at)) * 86400 > t.timeout_seconds)`, + types.TaskStatusRunning, + ) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var task struct { + taskID int64 + runnerID sql.NullInt64 + retryCount int + maxRetries int + timeoutSeconds sql.NullInt64 + startedAt time.Time + } + err := rows.Scan(&task.taskID, &task.runnerID, &task.retryCount, &task.maxRetries, &task.timeoutSeconds, &task.startedAt) + if err != nil { + log.Printf("Failed to scan task row in recoverTaskTimeouts: %v", err) + continue + } + tasks = append(tasks, task) + } + return nil + }) if err != nil { log.Printf("Failed to query timed out tasks: %v", err) return } - defer rows.Close() - for rows.Next() { - var taskID int64 - var runnerID sql.NullInt64 - var retryCount, maxRetries int - var timeoutSeconds sql.NullInt64 - var startedAt time.Time - - err := rows.Scan(&taskID, &runnerID, &retryCount, &maxRetries, &timeoutSeconds, &startedAt) - if err != nil { - log.Printf("Failed to scan task row in recoverTaskTimeouts: %v", err) - continue - } + for _, task := range tasks { + taskID := task.taskID + retryCount := task.retryCount + maxRetries := task.maxRetries + timeoutSeconds := task.timeoutSeconds + startedAt := task.startedAt // Use default timeout if not set (5 minutes for frame tasks, 24 hours for FFmpeg) timeout := 300 // 5 minutes default @@ -738,21 +1161,39 @@ func (s *Server) recoverTaskTimeouts() { if retryCount >= maxRetries { // Mark as failed - _, err = s.db.Exec( - `UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL - WHERE id = ?`, - types.TaskStatusFailed, "Task timeout exceeded, max retries reached", taskID, - ) + err = s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec(`UPDATE tasks SET status = ? 
WHERE id = ?`, types.TaskStatusFailed, taskID) + if err != nil { + return err + } + _, err = conn.Exec(`UPDATE tasks SET error_message = ? WHERE id = ?`, "Task timeout exceeded, max retries reached", taskID) + if err != nil { + return err + } + _, err = conn.Exec(`UPDATE tasks SET runner_id = NULL WHERE id = ?`, taskID) + return err + }) if err != nil { log.Printf("Failed to mark task %d as failed: %v", taskID, err) } } else { // Reset to pending - _, err = s.db.Exec( - `UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL, - retry_count = retry_count + 1 WHERE id = ?`, - types.TaskStatusPending, taskID, - ) + err = s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec(`UPDATE tasks SET status = ? WHERE id = ?`, types.TaskStatusPending, taskID) + if err != nil { + return err + } + _, err = conn.Exec(`UPDATE tasks SET runner_id = NULL WHERE id = ?`, taskID) + if err != nil { + return err + } + _, err = conn.Exec(`UPDATE tasks SET current_step = NULL WHERE id = ?`, taskID) + if err != nil { + return err + } + _, err = conn.Exec(`UPDATE tasks SET retry_count = retry_count + 1 WHERE id = ?`, taskID) + return err + }) if err == nil { // Add log entry using the helper function s.logTaskEvent(taskID, nil, types.LogLevelWarn, fmt.Sprintf("Task timeout exceeded, resetting (retry %d/%d)", retryCount+1, maxRetries), "") @@ -784,7 +1225,7 @@ func (s *Server) cleanupOldTempDirectoriesOnce() { }() tempPath := filepath.Join(s.storage.BasePath(), "temp") - + // Check if temp directory exists if _, err := os.Stat(tempPath); os.IsNotExist(err) { return @@ -799,21 +1240,34 @@ func (s *Server) cleanupOldTempDirectoriesOnce() { now := time.Now() cleanedCount := 0 - + + // Check upload sessions to avoid deleting active uploads + s.uploadSessionsMu.RLock() + activeSessions := make(map[string]bool) + for sessionID := range s.uploadSessions { + activeSessions[sessionID] = true + } + s.uploadSessionsMu.RUnlock() + for _, entry := range entries { if !entry.IsDir() { 
continue } entryPath := filepath.Join(tempPath, entry.Name()) - + + // Skip if this directory has an active upload session + if activeSessions[entryPath] { + continue + } + // Get directory info to check modification time info, err := entry.Info() if err != nil { continue } - // Remove directories older than 1 hour + // Remove directories older than 1 hour (only if no active session) age := now.Sub(info.ModTime()) if age > 1*time.Hour { if err := os.RemoveAll(entryPath); err != nil { @@ -829,3 +1283,47 @@ func (s *Server) cleanupOldTempDirectoriesOnce() { log.Printf("Cleaned up %d old temp directories", cleanedCount) } } + +// cleanupOldUploadSessions periodically cleans up abandoned upload sessions +func (s *Server) cleanupOldUploadSessions() { + // Run cleanup every 10 minutes + ticker := time.NewTicker(10 * time.Minute) + defer ticker.Stop() + + // Run once immediately on startup + s.cleanupOldUploadSessionsOnce() + + for range ticker.C { + s.cleanupOldUploadSessionsOnce() + } +} + +// cleanupOldUploadSessionsOnce removes upload sessions older than 1 hour +func (s *Server) cleanupOldUploadSessionsOnce() { + defer func() { + if r := recover(); r != nil { + log.Printf("Panic in cleanupOldUploadSessions: %v", r) + } + }() + + s.uploadSessionsMu.Lock() + defer s.uploadSessionsMu.Unlock() + + now := time.Now() + cleanedCount := 0 + + for sessionID, session := range s.uploadSessions { + // Remove sessions older than 1 hour + age := now.Sub(session.CreatedAt) + if age > 1*time.Hour { + delete(s.uploadSessions, sessionID) + cleanedCount++ + log.Printf("Cleaned up abandoned upload session: %s (user: %d, status: %s, age: %v)", + sessionID, session.UserID, session.Status, age) + } + } + + if cleanedCount > 0 { + log.Printf("Cleaned up %d abandoned upload sessions", cleanedCount) + } +} diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 3044d91..d353a23 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -5,10 +5,13 @@ import ( "database/sql" 
"encoding/json" "fmt" + "jiggablend/internal/config" + "jiggablend/internal/database" "log" "net/http" "os" "strings" + "sync" "time" "github.com/google/uuid" @@ -17,12 +20,31 @@ import ( "golang.org/x/oauth2/google" ) +// Context key types to avoid collisions (typed keys are safer than string keys) +type contextKey int + +const ( + contextKeyUserID contextKey = iota + contextKeyUserEmail + contextKeyUserName + contextKeyIsAdmin +) + +// Configuration constants +const ( + SessionDuration = 24 * time.Hour + SessionCleanupInterval = 1 * time.Hour +) + // Auth handles authentication type Auth struct { - db *sql.DB + db *database.DB + cfg *config.Config googleConfig *oauth2.Config discordConfig *oauth2.Config - sessionStore map[string]*Session + sessionCache map[string]*Session // In-memory cache for performance + cacheMu sync.RWMutex + stopCleanup chan struct{} } // Session represents a user session @@ -35,41 +57,53 @@ type Session struct { } // NewAuth creates a new auth instance -func NewAuth(db *sql.DB) (*Auth, error) { +func NewAuth(db *database.DB, cfg *config.Config) (*Auth, error) { auth := &Auth{ db: db, - sessionStore: make(map[string]*Session), + cfg: cfg, + sessionCache: make(map[string]*Session), + stopCleanup: make(chan struct{}), } - // Initialize Google OAuth - googleClientID := os.Getenv("GOOGLE_CLIENT_ID") - googleClientSecret := os.Getenv("GOOGLE_CLIENT_SECRET") + // Initialize Google OAuth from database config + googleClientID := cfg.GoogleClientID() + googleClientSecret := cfg.GoogleClientSecret() if googleClientID != "" && googleClientSecret != "" { auth.googleConfig = &oauth2.Config{ ClientID: googleClientID, ClientSecret: googleClientSecret, - RedirectURL: os.Getenv("GOOGLE_REDIRECT_URL"), + RedirectURL: cfg.GoogleRedirectURL(), Scopes: []string{"openid", "profile", "email"}, Endpoint: google.Endpoint, } + log.Printf("Google OAuth configured") } - // Initialize Discord OAuth - discordClientID := os.Getenv("DISCORD_CLIENT_ID") - 
discordClientSecret := os.Getenv("DISCORD_CLIENT_SECRET") + // Initialize Discord OAuth from database config + discordClientID := cfg.DiscordClientID() + discordClientSecret := cfg.DiscordClientSecret() if discordClientID != "" && discordClientSecret != "" { auth.discordConfig = &oauth2.Config{ ClientID: discordClientID, ClientSecret: discordClientSecret, - RedirectURL: os.Getenv("DISCORD_REDIRECT_URL"), + RedirectURL: cfg.DiscordRedirectURL(), Scopes: []string{"identify", "email"}, Endpoint: oauth2.Endpoint{ AuthURL: "https://discord.com/api/oauth2/authorize", TokenURL: "https://discord.com/api/oauth2/token", }, } + log.Printf("Discord OAuth configured") } + // Load existing sessions from database into cache + if err := auth.loadSessionsFromDB(); err != nil { + log.Printf("Warning: Failed to load sessions from database: %v", err) + } + + // Start background cleanup goroutine + go auth.cleanupExpiredSessions() + // Initialize admin settings on startup to ensure they persist between boots if err := auth.initializeSettings(); err != nil { log.Printf("Warning: Failed to initialize admin settings: %v", err) @@ -85,19 +119,119 @@ func NewAuth(db *sql.DB) (*Auth, error) { return auth, nil } +// Close stops background goroutines +func (a *Auth) Close() { + close(a.stopCleanup) +} + +// loadSessionsFromDB loads all valid sessions from database into cache +func (a *Auth) loadSessionsFromDB() error { + var sessions []struct { + sessionID string + session Session + } + + err := a.db.With(func(conn *sql.DB) error { + rows, err := conn.Query( + `SELECT session_id, user_id, email, name, is_admin, expires_at + FROM sessions WHERE expires_at > CURRENT_TIMESTAMP`, + ) + if err != nil { + return fmt.Errorf("failed to query sessions: %w", err) + } + defer rows.Close() + + for rows.Next() { + var s struct { + sessionID string + session Session + } + err := rows.Scan(&s.sessionID, &s.session.UserID, &s.session.Email, &s.session.Name, &s.session.IsAdmin, &s.session.ExpiresAt) + if err 
!= nil { + log.Printf("Warning: Failed to scan session row: %v", err) + continue + } + sessions = append(sessions, s) + } + return nil + }) + if err != nil { + return err + } + + a.cacheMu.Lock() + defer a.cacheMu.Unlock() + + for _, s := range sessions { + a.sessionCache[s.sessionID] = &s.session + } + + if len(sessions) > 0 { + log.Printf("Loaded %d active sessions from database", len(sessions)) + } + return nil +} + +// cleanupExpiredSessions periodically removes expired sessions from database and cache +func (a *Auth) cleanupExpiredSessions() { + ticker := time.NewTicker(SessionCleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + // Delete expired sessions from database + var deleted int64 + err := a.db.With(func(conn *sql.DB) error { + result, err := conn.Exec(`DELETE FROM sessions WHERE expires_at < CURRENT_TIMESTAMP`) + if err != nil { + return err + } + deleted, _ = result.RowsAffected() + return nil + }) + if err != nil { + log.Printf("Warning: Failed to cleanup expired sessions: %v", err) + continue + } + + // Clean up cache + a.cacheMu.Lock() + now := time.Now() + for sessionID, session := range a.sessionCache { + if now.After(session.ExpiresAt) { + delete(a.sessionCache, sessionID) + } + } + a.cacheMu.Unlock() + + if deleted > 0 { + log.Printf("Cleaned up %d expired sessions", deleted) + } + case <-a.stopCleanup: + return + } + } +} + // initializeSettings ensures all admin settings are initialized with defaults if they don't exist func (a *Auth) initializeSettings() error { // Initialize registration_enabled setting (default: true) if it doesn't exist var settingCount int - err := a.db.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount) + err := a.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount) + }) if err != nil { return fmt.Errorf("failed to check registration_enabled setting: 
%w", err) } if settingCount == 0 { - _, err = a.db.Exec( + err = a.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( `INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)`, "registration_enabled", "true", ) + return err + }) if err != nil { return fmt.Errorf("failed to initialize registration_enabled setting: %w", err) } @@ -118,7 +252,9 @@ func (a *Auth) initializeTestUser() error { // Check if user already exists var exists bool - err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND oauth_provider = 'local')", testEmail).Scan(&exists) + err := a.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND oauth_provider = 'local')", testEmail).Scan(&exists) + }) if err != nil { return fmt.Errorf("failed to check if test user exists: %w", err) } @@ -137,7 +273,12 @@ func (a *Auth) initializeTestUser() error { // Check if this is the first user (make them admin) var userCount int - a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount) + err = a.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount) + }) + if err != nil { + return fmt.Errorf("failed to check user count: %w", err) + } isAdmin := userCount == 0 // Create test user (use email as name if no name is provided) @@ -147,10 +288,13 @@ func (a *Auth) initializeTestUser() error { } // Create test user - _, err = a.db.Exec( + err = a.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( "INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?)", testEmail, testName, testEmail, string(hashedPassword), isAdmin, ) + return err + }) if err != nil { return fmt.Errorf("failed to create test user: %w", err) } @@ -242,7 +386,9 @@ func (a *Auth) DiscordCallback(ctx context.Context, code string) (*Session, erro // IsRegistrationEnabled checks if new user registration is enabled func (a *Auth) 
IsRegistrationEnabled() (bool, error) { var value string - err := a.db.QueryRow("SELECT value FROM settings WHERE key = ?", "registration_enabled").Scan(&value) + err := a.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT value FROM settings WHERE key = ?", "registration_enabled").Scan(&value) + }) if err == sql.ErrNoRows { // Default to enabled if setting doesn't exist return true, nil @@ -262,29 +408,31 @@ func (a *Auth) SetRegistrationEnabled(enabled bool) error { // Check if setting exists var exists bool - err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM settings WHERE key = ?)", "registration_enabled").Scan(&exists) + err := a.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM settings WHERE key = ?)", "registration_enabled").Scan(&exists) + }) if err != nil { return fmt.Errorf("failed to check if setting exists: %w", err) } + err = a.db.With(func(conn *sql.DB) error { if exists { // Update existing setting - _, err = a.db.Exec( + _, err = conn.Exec( "UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?", value, "registration_enabled", ) - if err != nil { - return fmt.Errorf("failed to update setting: %w", err) - } } else { // Insert new setting - _, err = a.db.Exec( + _, err = conn.Exec( "INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)", "registration_enabled", value, ) - if err != nil { - return fmt.Errorf("failed to insert setting: %w", err) } + return err + }) + if err != nil { + return fmt.Errorf("failed to set registration_enabled: %w", err) } return nil @@ -299,17 +447,21 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session, var dbProvider, dbOAuthID string // First, try to find by provider + oauth_id - err := a.db.QueryRow( + err := a.db.With(func(conn *sql.DB) error { + return conn.QueryRow( "SELECT id, email, name, is_admin, oauth_provider, oauth_id FROM users WHERE oauth_provider = ? 
AND oauth_id = ?", provider, oauthID, ).Scan(&userID, &dbEmail, &dbName, &isAdmin, &dbProvider, &dbOAuthID) + }) if err == sql.ErrNoRows { // Not found by provider+oauth_id, check by email for account linking - err = a.db.QueryRow( + err = a.db.With(func(conn *sql.DB) error { + return conn.QueryRow( "SELECT id, email, name, is_admin, oauth_provider, oauth_id FROM users WHERE email = ?", email, ).Scan(&userID, &dbEmail, &dbName, &isAdmin, &dbProvider, &dbOAuthID) + }) if err == sql.ErrNoRows { // User doesn't exist, check if registration is enabled @@ -323,14 +475,26 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session, // Check if this is the first user var userCount int - a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount) + err = a.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount) + }) + if err != nil { + return nil, fmt.Errorf("failed to check user count: %w", err) + } isAdmin = userCount == 0 // Create new user - err = a.db.QueryRow( - "INSERT INTO users (email, name, oauth_provider, oauth_id, is_admin) VALUES (?, ?, ?, ?, ?) RETURNING id", + err = a.db.With(func(conn *sql.DB) error { + result, err := conn.Exec( + "INSERT INTO users (email, name, oauth_provider, oauth_id, is_admin) VALUES (?, ?, ?, ?, ?)", email, name, provider, oauthID, isAdmin, - ).Scan(&userID) + ) + if err != nil { + return err + } + userID, err = result.LastInsertId() + return err + }) if err != nil { return nil, fmt.Errorf("failed to create user: %w", err) } @@ -339,10 +503,13 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session, } else { // User exists with same email but different provider - link accounts by updating provider info // This allows the user to log in with any provider that has the same email - _, err = a.db.Exec( + err = a.db.With(func(conn *sql.DB) error { + _, err = conn.Exec( "UPDATE users SET oauth_provider = ?, oauth_id = ?, name = ? 
WHERE id = ?", provider, oauthID, name, userID, ) + return err + }) if err != nil { return nil, fmt.Errorf("failed to link account: %w", err) } @@ -353,10 +520,13 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session, } else { // User found by provider+oauth_id, update info if changed if dbEmail != email || dbName != name { - _, err = a.db.Exec( + err = a.db.With(func(conn *sql.DB) error { + _, err = conn.Exec( "UPDATE users SET email = ?, name = ? WHERE id = ?", email, name, userID, ) + return err + }) if err != nil { return nil, fmt.Errorf("failed to update user: %w", err) } @@ -368,41 +538,134 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session, Email: email, Name: name, IsAdmin: isAdmin, - ExpiresAt: time.Now().Add(24 * time.Hour), + ExpiresAt: time.Now().Add(SessionDuration), } return session, nil } // CreateSession creates a new session and returns a session ID +// Sessions are persisted to database and cached in memory func (a *Auth) CreateSession(session *Session) string { sessionID := uuid.New().String() - a.sessionStore[sessionID] = session + + // Store in database first + err := a.db.With(func(conn *sql.DB) error { + _, err := conn.Exec( + `INSERT INTO sessions (session_id, user_id, email, name, is_admin, expires_at) + VALUES (?, ?, ?, ?, ?, ?)`, + sessionID, session.UserID, session.Email, session.Name, session.IsAdmin, session.ExpiresAt, + ) + return err + }) + if err != nil { + log.Printf("Warning: Failed to persist session to database: %v", err) + // Continue anyway - session will work from cache but won't survive restart + } + + // Store in cache + a.cacheMu.Lock() + a.sessionCache[sessionID] = session + a.cacheMu.Unlock() + return sessionID } // GetSession retrieves a session by ID +// First checks cache, then database if not found func (a *Auth) GetSession(sessionID string) (*Session, bool) { - session, ok := a.sessionStore[sessionID] - if !ok { + // Check cache first + a.cacheMu.RLock() + 
session, ok := a.sessionCache[sessionID] + a.cacheMu.RUnlock() + + if ok { + if time.Now().After(session.ExpiresAt) { + a.DeleteSession(sessionID) + return nil, false + } + // Refresh admin status from database + var isAdmin bool + err := a.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT is_admin FROM users WHERE id = ?", session.UserID).Scan(&isAdmin) + }) + if err == nil { + session.IsAdmin = isAdmin + } + return session, true + } + + // Not in cache, check database + session = &Session{} + err := a.db.With(func(conn *sql.DB) error { + return conn.QueryRow( + `SELECT user_id, email, name, is_admin, expires_at + FROM sessions WHERE session_id = ?`, + sessionID, + ).Scan(&session.UserID, &session.Email, &session.Name, &session.IsAdmin, &session.ExpiresAt) + }) + + if err == sql.ErrNoRows { return nil, false } + if err != nil { + log.Printf("Warning: Failed to query session from database: %v", err) + return nil, false + } + if time.Now().After(session.ExpiresAt) { - delete(a.sessionStore, sessionID) + a.DeleteSession(sessionID) return nil, false } + // Refresh admin status from database var isAdmin bool - err := a.db.QueryRow("SELECT is_admin FROM users WHERE id = ?", session.UserID).Scan(&isAdmin) + err = a.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT is_admin FROM users WHERE id = ?", session.UserID).Scan(&isAdmin) + }) if err == nil { session.IsAdmin = isAdmin } + + // Add to cache + a.cacheMu.Lock() + a.sessionCache[sessionID] = session + a.cacheMu.Unlock() + return session, true } -// DeleteSession deletes a session +// DeleteSession deletes a session from both cache and database func (a *Auth) DeleteSession(sessionID string) { - delete(a.sessionStore, sessionID) + // Delete from cache + a.cacheMu.Lock() + delete(a.sessionCache, sessionID) + a.cacheMu.Unlock() + + // Delete from database + err := a.db.With(func(conn *sql.DB) error { + _, err := conn.Exec("DELETE FROM sessions WHERE session_id = ?", sessionID) + return err 
+ }) + if err != nil { + log.Printf("Warning: Failed to delete session from database: %v", err) + } +} + +// IsProductionMode returns true if running in production mode +// This is a package-level function that checks the environment variable +// For config-based checks, use Config.IsProductionMode() +func IsProductionMode() bool { + // Check environment variable first for backwards compatibility + if os.Getenv("PRODUCTION") == "true" { + return true + } + return false +} + +// IsProductionModeFromConfig returns true if production mode is enabled in config +func (a *Auth) IsProductionModeFromConfig() bool { + return a.cfg.IsProductionMode() } // Middleware creates an authentication middleware @@ -426,24 +689,24 @@ func (a *Auth) Middleware(next http.HandlerFunc) http.HandlerFunc { return } - // Add user info to request context - ctx := context.WithValue(r.Context(), "user_id", session.UserID) - ctx = context.WithValue(ctx, "user_email", session.Email) - ctx = context.WithValue(ctx, "user_name", session.Name) - ctx = context.WithValue(ctx, "is_admin", session.IsAdmin) + // Add user info to request context using typed keys + ctx := context.WithValue(r.Context(), contextKeyUserID, session.UserID) + ctx = context.WithValue(ctx, contextKeyUserEmail, session.Email) + ctx = context.WithValue(ctx, contextKeyUserName, session.Name) + ctx = context.WithValue(ctx, contextKeyIsAdmin, session.IsAdmin) next(w, r.WithContext(ctx)) } } // GetUserID gets the user ID from context func GetUserID(ctx context.Context) (int64, bool) { - userID, ok := ctx.Value("user_id").(int64) + userID, ok := ctx.Value(contextKeyUserID).(int64) return userID, ok } // IsAdmin checks if the user in context is an admin func IsAdmin(ctx context.Context) bool { - isAdmin, ok := ctx.Value("is_admin").(bool) + isAdmin, ok := ctx.Value(contextKeyIsAdmin).(bool) return ok && isAdmin } @@ -478,19 +741,19 @@ func (a *Auth) AdminMiddleware(next http.HandlerFunc) http.HandlerFunc { return } - // Add user info to 
request context - ctx := context.WithValue(r.Context(), "user_id", session.UserID) - ctx = context.WithValue(ctx, "user_email", session.Email) - ctx = context.WithValue(ctx, "user_name", session.Name) - ctx = context.WithValue(ctx, "is_admin", session.IsAdmin) + // Add user info to request context using typed keys + ctx := context.WithValue(r.Context(), contextKeyUserID, session.UserID) + ctx = context.WithValue(ctx, contextKeyUserEmail, session.Email) + ctx = context.WithValue(ctx, contextKeyUserName, session.Name) + ctx = context.WithValue(ctx, contextKeyIsAdmin, session.IsAdmin) next(w, r.WithContext(ctx)) } } // IsLocalLoginEnabled returns whether local login is enabled -// Local login is enabled when ENABLE_LOCAL_AUTH environment variable is set to "true" +// Checks database config first, falls back to environment variable func (a *Auth) IsLocalLoginEnabled() bool { - return os.Getenv("ENABLE_LOCAL_AUTH") == "true" + return a.cfg.IsLocalAuthEnabled() } // IsGoogleOAuthConfigured returns whether Google OAuth is configured @@ -511,10 +774,12 @@ func (a *Auth) LocalLogin(username, password string) (*Session, error) { var dbEmail, dbName, passwordHash string var isAdmin bool - err := a.db.QueryRow( + err := a.db.With(func(conn *sql.DB) error { + return conn.QueryRow( "SELECT id, email, name, password_hash, is_admin FROM users WHERE email = ? 
AND oauth_provider = 'local'", email, ).Scan(&userID, &dbEmail, &dbName, &passwordHash, &isAdmin) + }) if err == sql.ErrNoRows { return nil, fmt.Errorf("invalid credentials") @@ -539,7 +804,7 @@ func (a *Auth) LocalLogin(username, password string) (*Session, error) { Email: dbEmail, Name: dbName, IsAdmin: isAdmin, - ExpiresAt: time.Now().Add(24 * time.Hour), + ExpiresAt: time.Now().Add(SessionDuration), } return session, nil @@ -558,7 +823,9 @@ func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error) // Check if user already exists var exists bool - err = a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ?)", email).Scan(&exists) + err = a.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ?)", email).Scan(&exists) + }) if err != nil { return nil, fmt.Errorf("failed to check if user exists: %w", err) } @@ -574,15 +841,27 @@ func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error) // Check if this is the first user (make them admin) var userCount int - a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount) + err = a.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount) + }) + if err != nil { + return nil, fmt.Errorf("failed to check user count: %w", err) + } isAdmin := userCount == 0 // Create user var userID int64 - err = a.db.QueryRow( - "INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?) 
RETURNING id", + err = a.db.With(func(conn *sql.DB) error { + result, err := conn.Exec( + "INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?)", email, name, email, string(hashedPassword), isAdmin, - ).Scan(&userID) + ) + if err != nil { + return err + } + userID, err = result.LastInsertId() + return err + }) if err != nil { return nil, fmt.Errorf("failed to create user: %w", err) } @@ -593,7 +872,7 @@ func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error) Email: email, Name: name, IsAdmin: isAdmin, - ExpiresAt: time.Now().Add(24 * time.Hour), + ExpiresAt: time.Now().Add(SessionDuration), } return session, nil @@ -603,7 +882,9 @@ func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error) func (a *Auth) ChangePassword(userID int64, oldPassword, newPassword string) error { // Get current password hash var passwordHash string - err := a.db.QueryRow("SELECT password_hash FROM users WHERE id = ? AND oauth_provider = 'local'", userID).Scan(&passwordHash) + err := a.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT password_hash FROM users WHERE id = ? AND oauth_provider = 'local'", userID).Scan(&passwordHash) + }) if err == sql.ErrNoRows { return fmt.Errorf("user not found or not a local user") } @@ -628,7 +909,10 @@ func (a *Auth) ChangePassword(userID int64, oldPassword, newPassword string) err } // Update password - _, err = a.db.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), userID) + err = a.db.With(func(conn *sql.DB) error { + _, err := conn.Exec("UPDATE users SET password_hash = ? 
WHERE id = ?", string(hashedPassword), userID) + return err + }) if err != nil { return fmt.Errorf("failed to update password: %w", err) } @@ -640,7 +924,9 @@ func (a *Auth) ChangePassword(userID int64, oldPassword, newPassword string) err func (a *Auth) AdminChangePassword(targetUserID int64, newPassword string) error { // Verify user exists and is a local user var exists bool - err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ? AND oauth_provider = 'local')", targetUserID).Scan(&exists) + err := a.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ? AND oauth_provider = 'local')", targetUserID).Scan(&exists) + }) if err != nil { return fmt.Errorf("failed to check if user exists: %w", err) } @@ -655,7 +941,10 @@ func (a *Auth) AdminChangePassword(targetUserID int64, newPassword string) error } // Update password - _, err = a.db.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), targetUserID) + err = a.db.With(func(conn *sql.DB) error { + _, err := conn.Exec("UPDATE users SET password_hash = ? 
WHERE id = ?", string(hashedPassword), targetUserID) + return err + }) if err != nil { return fmt.Errorf("failed to update password: %w", err) } @@ -666,7 +955,9 @@ func (a *Auth) AdminChangePassword(targetUserID int64, newPassword string) error // GetFirstUserID returns the ID of the first user (user with the lowest ID) func (a *Auth) GetFirstUserID() (int64, error) { var firstUserID int64 - err := a.db.QueryRow("SELECT id FROM users ORDER BY id ASC LIMIT 1").Scan(&firstUserID) + err := a.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT id FROM users ORDER BY id ASC LIMIT 1").Scan(&firstUserID) + }) if err == sql.ErrNoRows { return 0, fmt.Errorf("no users found") } @@ -680,7 +971,9 @@ func (a *Auth) GetFirstUserID() (int64, error) { func (a *Auth) SetUserAdminStatus(targetUserID int64, isAdmin bool) error { // Verify user exists var exists bool - err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", targetUserID).Scan(&exists) + err := a.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", targetUserID).Scan(&exists) + }) if err != nil { return fmt.Errorf("failed to check if user exists: %w", err) } @@ -698,7 +991,10 @@ func (a *Auth) SetUserAdminStatus(targetUserID int64, isAdmin bool) error { } // Update admin status - _, err = a.db.Exec("UPDATE users SET is_admin = ? WHERE id = ?", isAdmin, targetUserID) + err = a.db.With(func(conn *sql.DB) error { + _, err := conn.Exec("UPDATE users SET is_admin = ? 
WHERE id = ?", isAdmin, targetUserID) + return err + }) if err != nil { return fmt.Errorf("failed to update admin status: %w", err) } diff --git a/internal/auth/secrets.go b/internal/auth/secrets.go index 5ce9713..f3a895d 100644 --- a/internal/auth/secrets.go +++ b/internal/auth/secrets.go @@ -6,7 +6,8 @@ import ( "database/sql" "encoding/hex" "fmt" - "os" + "jiggablend/internal/config" + "jiggablend/internal/database" "strings" "sync" "time" @@ -14,21 +15,14 @@ import ( // Secrets handles API key management type Secrets struct { - db *sql.DB + db *database.DB + cfg *config.Config RegistrationMu sync.Mutex // Protects concurrent runner registrations - fixedAPIKey string // Fixed API key from environment variable (optional) } // NewSecrets creates a new secrets manager -func NewSecrets(db *sql.DB) (*Secrets, error) { - s := &Secrets{db: db} - - // Check for fixed API key from environment - if fixedKey := os.Getenv("FIXED_API_KEY"); fixedKey != "" { - s.fixedAPIKey = fixedKey - } - - return s, nil +func NewSecrets(db *database.DB, cfg *config.Config) (*Secrets, error) { + return &Secrets{db: db, cfg: cfg}, nil } // APIKeyInfo represents information about an API key @@ -61,25 +55,34 @@ func (s *Secrets) GenerateRunnerAPIKey(createdBy int64, name, description string keyHash := sha256.Sum256([]byte(key)) keyHashStr := hex.EncodeToString(keyHash[:]) - _, err = s.db.Exec( + var keyInfo APIKeyInfo + err = s.db.With(func(conn *sql.DB) error { + result, err := conn.Exec( `INSERT INTO runner_api_keys (key_prefix, key_hash, name, description, scope, is_active, created_by) VALUES (?, ?, ?, ?, ?, ?, ?)`, keyPrefix, keyHashStr, name, description, scope, true, createdBy, ) if err != nil { - return nil, fmt.Errorf("failed to store API key: %w", err) + return fmt.Errorf("failed to store API key: %w", err) + } + + keyID, err := result.LastInsertId() + if err != nil { + return fmt.Errorf("failed to get inserted key ID: %w", err) } // Get the inserted key info - var keyInfo APIKeyInfo 
- err = s.db.QueryRow( + err = conn.QueryRow( `SELECT id, name, description, scope, is_active, created_at, created_by - FROM runner_api_keys WHERE key_prefix = ?`, - keyPrefix, + FROM runner_api_keys WHERE id = ?`, + keyID, ).Scan(&keyInfo.ID, &keyInfo.Name, &keyInfo.Description, &keyInfo.Scope, &keyInfo.IsActive, &keyInfo.CreatedAt, &keyInfo.CreatedBy) + return err + }) + if err != nil { - return nil, fmt.Errorf("failed to retrieve created API key: %w", err) + return nil, fmt.Errorf("failed to create API key: %w", err) } keyInfo.Key = key @@ -91,18 +94,25 @@ func (s *Secrets) generateAPIKey() (string, error) { // Generate random suffix randomBytes := make([]byte, 16) if _, err := rand.Read(randomBytes); err != nil { - return "", err + return "", fmt.Errorf("failed to generate random bytes: %w", err) } randomStr := hex.EncodeToString(randomBytes) // Generate a unique prefix (jk_r followed by 1 random digit) prefixDigit := make([]byte, 1) if _, err := rand.Read(prefixDigit); err != nil { - return "", err + return "", fmt.Errorf("failed to generate prefix digit: %w", err) } prefix := fmt.Sprintf("jk_r%d", prefixDigit[0]%10) - return fmt.Sprintf("%s_%s", prefix, randomStr), nil + key := fmt.Sprintf("%s_%s", prefix, randomStr) + + // Validate generated key format + if !strings.HasPrefix(key, "jk_r") { + return "", fmt.Errorf("generated invalid API key format: %s", key) + } + + return key, nil } // ValidateRunnerAPIKey validates an API key and returns the key ID and scope if valid @@ -111,8 +121,9 @@ func (s *Secrets) ValidateRunnerAPIKey(apiKey string) (int64, string, error) { return 0, "", fmt.Errorf("API key is required") } - // Check fixed API key first (for testing/development) - if s.fixedAPIKey != "" && apiKey == s.fixedAPIKey { + // Check fixed API key first (from database config) + fixedKey := s.cfg.FixedAPIKey() + if fixedKey != "" && apiKey == fixedKey { // Return a special ID for fixed API key (doesn't exist in database) return -1, "manager", nil } @@ 
-137,42 +148,50 @@ func (s *Secrets) ValidateRunnerAPIKey(apiKey string) (int64, string, error) { var scope string var isActive bool - err := s.db.QueryRow( + err := s.db.With(func(conn *sql.DB) error { + err := conn.QueryRow( `SELECT id, scope, is_active FROM runner_api_keys WHERE key_prefix = ? AND key_hash = ?`, keyPrefix, keyHashStr, ).Scan(&keyID, &scope, &isActive) if err == sql.ErrNoRows { - return 0, "", fmt.Errorf("API key not found or invalid - please check that the key is correct and active") + return fmt.Errorf("API key not found or invalid - please check that the key is correct and active") } if err != nil { - return 0, "", fmt.Errorf("failed to validate API key: %w", err) + return fmt.Errorf("failed to validate API key: %w", err) } if !isActive { - return 0, "", fmt.Errorf("API key is inactive") + return fmt.Errorf("API key is inactive") } // Update last_used_at (don't fail if this update fails) - s.db.Exec(`UPDATE runner_api_keys SET last_used_at = ? WHERE id = ?`, time.Now(), keyID) + conn.Exec(`UPDATE runner_api_keys SET last_used_at = ? 
WHERE id = ?`, time.Now(), keyID) + return nil + }) + + if err != nil { + return 0, "", err + } return keyID, scope, nil } // ListRunnerAPIKeys lists all runner API keys func (s *Secrets) ListRunnerAPIKeys() ([]APIKeyInfo, error) { - rows, err := s.db.Query( + var keys []APIKeyInfo + err := s.db.With(func(conn *sql.DB) error { + rows, err := conn.Query( `SELECT id, key_prefix, name, description, scope, is_active, created_at, created_by FROM runner_api_keys ORDER BY created_at DESC`, ) if err != nil { - return nil, fmt.Errorf("failed to query API keys: %w", err) + return fmt.Errorf("failed to query API keys: %w", err) } defer rows.Close() - var keys []APIKeyInfo for rows.Next() { var key APIKeyInfo var description sql.NullString @@ -188,20 +207,28 @@ func (s *Secrets) ListRunnerAPIKeys() ([]APIKeyInfo, error) { keys = append(keys, key) } - + return nil + }) + if err != nil { + return nil, err + } return keys, nil } // RevokeRunnerAPIKey revokes (deactivates) a runner API key func (s *Secrets) RevokeRunnerAPIKey(keyID int64) error { - _, err := s.db.Exec("UPDATE runner_api_keys SET is_active = false WHERE id = ?", keyID) + return s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec("UPDATE runner_api_keys SET is_active = false WHERE id = ?", keyID) return err + }) } // DeleteRunnerAPIKey deletes a runner API key func (s *Secrets) DeleteRunnerAPIKey(keyID int64) error { - _, err := s.db.Exec("DELETE FROM runner_api_keys WHERE id = ?", keyID) + return s.db.With(func(conn *sql.DB) error { + _, err := conn.Exec("DELETE FROM runner_api_keys WHERE id = ?", keyID) return err + }) } diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..652ee8f --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,303 @@ +package config + +import ( + "database/sql" + "fmt" + "jiggablend/internal/database" + "log" + "os" + "strconv" +) + +// Config keys stored in database +const ( + KeyGoogleClientID = "google_client_id" + 
KeyGoogleClientSecret = "google_client_secret" + KeyGoogleRedirectURL = "google_redirect_url" + KeyDiscordClientID = "discord_client_id" + KeyDiscordClientSecret = "discord_client_secret" + KeyDiscordRedirectURL = "discord_redirect_url" + KeyEnableLocalAuth = "enable_local_auth" + KeyFixedAPIKey = "fixed_api_key" + KeyRegistrationEnabled = "registration_enabled" + KeyProductionMode = "production_mode" + KeyAllowedOrigins = "allowed_origins" +) + +// Config manages application configuration stored in the database +type Config struct { + db *database.DB +} + +// NewConfig creates a new config manager +func NewConfig(db *database.DB) *Config { + return &Config{db: db} +} + +// InitializeFromEnv loads configuration from environment variables on first run +// Environment variables take precedence only if the config key doesn't exist in the database +// This allows first-run setup via env vars, then subsequent runs use database values +func (c *Config) InitializeFromEnv() error { + envMappings := []struct { + envKey string + configKey string + sensitive bool + }{ + {"GOOGLE_CLIENT_ID", KeyGoogleClientID, false}, + {"GOOGLE_CLIENT_SECRET", KeyGoogleClientSecret, true}, + {"GOOGLE_REDIRECT_URL", KeyGoogleRedirectURL, false}, + {"DISCORD_CLIENT_ID", KeyDiscordClientID, false}, + {"DISCORD_CLIENT_SECRET", KeyDiscordClientSecret, true}, + {"DISCORD_REDIRECT_URL", KeyDiscordRedirectURL, false}, + {"ENABLE_LOCAL_AUTH", KeyEnableLocalAuth, false}, + {"FIXED_API_KEY", KeyFixedAPIKey, true}, + {"PRODUCTION", KeyProductionMode, false}, + {"ALLOWED_ORIGINS", KeyAllowedOrigins, false}, + } + + for _, mapping := range envMappings { + envValue := os.Getenv(mapping.envKey) + if envValue == "" { + continue + } + + // Check if config already exists in database + exists, err := c.Exists(mapping.configKey) + if err != nil { + return fmt.Errorf("failed to check config %s: %w", mapping.configKey, err) + } + + if !exists { + // Store env value in database + if err := c.Set(mapping.configKey, 
envValue); err != nil { + return fmt.Errorf("failed to store config %s: %w", mapping.configKey, err) + } + if mapping.sensitive { + log.Printf("Stored config from env: %s = [REDACTED]", mapping.configKey) + } else { + log.Printf("Stored config from env: %s = %s", mapping.configKey, envValue) + } + } + } + + return nil +} + +// Get retrieves a config value from the database +func (c *Config) Get(key string) (string, error) { + var value string + err := c.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT value FROM settings WHERE key = ?", key).Scan(&value) + }) + if err == sql.ErrNoRows { + return "", nil + } + if err != nil { + return "", fmt.Errorf("failed to get config %s: %w", key, err) + } + return value, nil +} + +// GetWithDefault retrieves a config value or returns a default if not set +func (c *Config) GetWithDefault(key, defaultValue string) string { + value, err := c.Get(key) + if err != nil || value == "" { + return defaultValue + } + return value +} + +// GetBool retrieves a boolean config value +func (c *Config) GetBool(key string) (bool, error) { + value, err := c.Get(key) + if err != nil { + return false, err + } + return value == "true" || value == "1", nil +} + +// GetBoolWithDefault retrieves a boolean config value or returns a default +func (c *Config) GetBoolWithDefault(key string, defaultValue bool) bool { + value, err := c.GetBool(key) + if err != nil { + return defaultValue + } + // If the key doesn't exist, Get returns empty string which becomes false + // Check if key exists to distinguish between "false" and "not set" + exists, _ := c.Exists(key) + if !exists { + return defaultValue + } + return value +} + +// GetInt retrieves an integer config value +func (c *Config) GetInt(key string) (int, error) { + value, err := c.Get(key) + if err != nil { + return 0, err + } + if value == "" { + return 0, nil + } + return strconv.Atoi(value) +} + +// GetIntWithDefault retrieves an integer config value or returns a default +func (c 
*Config) GetIntWithDefault(key string, defaultValue int) int { + value, err := c.GetInt(key) + if err != nil { + return defaultValue + } + exists, _ := c.Exists(key) + if !exists { + return defaultValue + } + return value +} + +// Set stores a config value in the database +func (c *Config) Set(key, value string) error { + // Use upsert pattern + exists, err := c.Exists(key) + if err != nil { + return err + } + + err = c.db.With(func(conn *sql.DB) error { + if exists { + _, err = conn.Exec( + "UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?", + value, key, + ) + } else { + _, err = conn.Exec( + "INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)", + key, value, + ) + } + return err + }) + + if err != nil { + return fmt.Errorf("failed to set config %s: %w", key, err) + } + return nil +} + +// SetBool stores a boolean config value +func (c *Config) SetBool(key string, value bool) error { + strValue := "false" + if value { + strValue = "true" + } + return c.Set(key, strValue) +} + +// SetInt stores an integer config value +func (c *Config) SetInt(key string, value int) error { + return c.Set(key, strconv.Itoa(value)) +} + +// Delete removes a config value from the database +func (c *Config) Delete(key string) error { + err := c.db.With(func(conn *sql.DB) error { + _, err := conn.Exec("DELETE FROM settings WHERE key = ?", key) + return err + }) + if err != nil { + return fmt.Errorf("failed to delete config %s: %w", key, err) + } + return nil +} + +// Exists checks if a config key exists in the database +func (c *Config) Exists(key string) (bool, error) { + var exists bool + err := c.db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM settings WHERE key = ?)", key).Scan(&exists) + }) + if err != nil { + return false, fmt.Errorf("failed to check config existence %s: %w", key, err) + } + return exists, nil +} + +// GetAll returns all config values (for debugging/admin purposes) +func 
(c *Config) GetAll() (map[string]string, error) { + var result map[string]string + err := c.db.With(func(conn *sql.DB) error { + rows, err := conn.Query("SELECT key, value FROM settings") + if err != nil { + return fmt.Errorf("failed to get all config: %w", err) + } + defer rows.Close() + + result = make(map[string]string) + for rows.Next() { + var key, value string + if err := rows.Scan(&key, &value); err != nil { + return fmt.Errorf("failed to scan config row: %w", err) + } + result[key] = value + } + return nil + }) + if err != nil { + return nil, err + } + return result, nil +} + +// --- Convenience methods for specific config values --- + +// GoogleClientID returns the Google OAuth client ID +func (c *Config) GoogleClientID() string { + return c.GetWithDefault(KeyGoogleClientID, "") +} + +// GoogleClientSecret returns the Google OAuth client secret +func (c *Config) GoogleClientSecret() string { + return c.GetWithDefault(KeyGoogleClientSecret, "") +} + +// GoogleRedirectURL returns the Google OAuth redirect URL +func (c *Config) GoogleRedirectURL() string { + return c.GetWithDefault(KeyGoogleRedirectURL, "") +} + +// DiscordClientID returns the Discord OAuth client ID +func (c *Config) DiscordClientID() string { + return c.GetWithDefault(KeyDiscordClientID, "") +} + +// DiscordClientSecret returns the Discord OAuth client secret +func (c *Config) DiscordClientSecret() string { + return c.GetWithDefault(KeyDiscordClientSecret, "") +} + +// DiscordRedirectURL returns the Discord OAuth redirect URL +func (c *Config) DiscordRedirectURL() string { + return c.GetWithDefault(KeyDiscordRedirectURL, "") +} + +// IsLocalAuthEnabled returns whether local authentication is enabled +func (c *Config) IsLocalAuthEnabled() bool { + return c.GetBoolWithDefault(KeyEnableLocalAuth, false) +} + +// FixedAPIKey returns the fixed API key for testing +func (c *Config) FixedAPIKey() string { + return c.GetWithDefault(KeyFixedAPIKey, "") +} + +// IsProductionMode returns whether 
production mode is enabled +func (c *Config) IsProductionMode() bool { + return c.GetBoolWithDefault(KeyProductionMode, false) +} + +// AllowedOrigins returns the allowed CORS origins +func (c *Config) AllowedOrigins() string { + return c.GetWithDefault(KeyAllowedOrigins, "") +} + diff --git a/internal/database/schema.go b/internal/database/schema.go index f67ed82..5c9c3a6 100644 --- a/internal/database/schema.go +++ b/internal/database/schema.go @@ -4,18 +4,20 @@ import ( "database/sql" "fmt" "log" + "sync" - _ "github.com/marcboeker/go-duckdb/v2" + _ "github.com/mattn/go-sqlite3" ) -// DB wraps the database connection +// DB wraps the database connection with mutex protection type DB struct { - *sql.DB + db *sql.DB + mu sync.Mutex } // NewDB creates a new database connection func NewDB(dbPath string) (*DB, error) { - db, err := sql.Open("duckdb", dbPath) + db, err := sql.Open("sqlite3", dbPath) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } @@ -24,7 +26,12 @@ func NewDB(dbPath string) (*DB, error) { return nil, fmt.Errorf("failed to ping database: %w", err) } - database := &DB{DB: db} + // Enable foreign keys for SQLite + if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil { + return nil, fmt.Errorf("failed to enable foreign keys: %w", err) + } + + database := &DB{db: db} if err := database.migrate(); err != nil { return nil, fmt.Errorf("failed to migrate database: %w", err) } @@ -32,58 +39,74 @@ func NewDB(dbPath string) (*DB, error) { return database, nil } +// With executes a function with mutex-protected access to the database +// The function receives the underlying *sql.DB connection +func (db *DB) With(fn func(*sql.DB) error) error { + db.mu.Lock() + defer db.mu.Unlock() + return fn(db.db) +} + +// WithTx executes a function within a transaction with mutex protection +// The function receives a *sql.Tx transaction +// If the function returns an error, the transaction is rolled back +// If the function returns nil, the 
transaction is committed +func (db *DB) WithTx(fn func(*sql.Tx) error) error { + db.mu.Lock() + defer db.mu.Unlock() + + tx, err := db.db.Begin() + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + + if err := fn(tx); err != nil { + if rbErr := tx.Rollback(); rbErr != nil { + return fmt.Errorf("transaction error: %w, rollback error: %v", err, rbErr) + } + return err + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + + return nil +} + // migrate runs database migrations func (db *DB) migrate() error { - // Create sequences for auto-incrementing primary keys - sequences := []string{ - `CREATE SEQUENCE IF NOT EXISTS seq_users_id START 1`, - `CREATE SEQUENCE IF NOT EXISTS seq_jobs_id START 1`, - `CREATE SEQUENCE IF NOT EXISTS seq_runners_id START 1`, - `CREATE SEQUENCE IF NOT EXISTS seq_tasks_id START 1`, - `CREATE SEQUENCE IF NOT EXISTS seq_job_files_id START 1`, - `CREATE SEQUENCE IF NOT EXISTS seq_manager_secrets_id START 1`, - `CREATE SEQUENCE IF NOT EXISTS seq_registration_tokens_id START 1`, - `CREATE SEQUENCE IF NOT EXISTS seq_runner_api_keys_id START 1`, - `CREATE SEQUENCE IF NOT EXISTS seq_task_logs_id START 1`, - `CREATE SEQUENCE IF NOT EXISTS seq_task_steps_id START 1`, - } - - for _, seq := range sequences { - if _, err := db.Exec(seq); err != nil { - return fmt.Errorf("failed to create sequence: %w", err) - } - } - + // SQLite uses INTEGER PRIMARY KEY AUTOINCREMENT instead of sequences schema := ` CREATE TABLE IF NOT EXISTS users ( - id BIGINT PRIMARY KEY DEFAULT nextval('seq_users_id'), + id INTEGER PRIMARY KEY AUTOINCREMENT, email TEXT UNIQUE NOT NULL, name TEXT NOT NULL, oauth_provider TEXT NOT NULL, oauth_id TEXT NOT NULL, password_hash TEXT, - is_admin BOOLEAN NOT NULL DEFAULT false, + is_admin INTEGER NOT NULL DEFAULT 0, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, UNIQUE(oauth_provider, oauth_id) ); CREATE TABLE IF NOT EXISTS runner_api_keys ( - 
id BIGINT PRIMARY KEY DEFAULT nextval('seq_runner_api_keys_id'), - key_prefix TEXT NOT NULL, -- First part of API key (e.g., "jk_r1") - key_hash TEXT NOT NULL, -- SHA256 hash of full API key - name TEXT NOT NULL, -- Human-readable name - description TEXT, -- Optional description - scope TEXT NOT NULL DEFAULT 'user', -- 'manager' or 'user' - manager scope allows all jobs, user scope only allows jobs from key owner - is_active BOOLEAN NOT NULL DEFAULT true, + id INTEGER PRIMARY KEY AUTOINCREMENT, + key_prefix TEXT NOT NULL, + key_hash TEXT NOT NULL, + name TEXT NOT NULL, + description TEXT, + scope TEXT NOT NULL DEFAULT 'user', + is_active INTEGER NOT NULL DEFAULT 1, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - created_by BIGINT, + created_by INTEGER, FOREIGN KEY (created_by) REFERENCES users(id), UNIQUE(key_prefix) ); CREATE TABLE IF NOT EXISTS jobs ( - id BIGINT PRIMARY KEY DEFAULT nextval('seq_jobs_id'), - user_id BIGINT NOT NULL, + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, job_type TEXT NOT NULL DEFAULT 'render', name TEXT NOT NULL, status TEXT NOT NULL DEFAULT 'pending', @@ -91,9 +114,11 @@ func (db *DB) migrate() error { frame_start INTEGER, frame_end INTEGER, output_format TEXT, - allow_parallel_runners BOOLEAN NOT NULL DEFAULT true, + allow_parallel_runners INTEGER NOT NULL DEFAULT 1, timeout_seconds INTEGER DEFAULT 86400, blend_metadata TEXT, + retry_count INTEGER NOT NULL DEFAULT 0, + max_retries INTEGER NOT NULL DEFAULT 3, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, started_at TIMESTAMP, completed_at TIMESTAMP, @@ -102,25 +127,25 @@ func (db *DB) migrate() error { ); CREATE TABLE IF NOT EXISTS runners ( - id BIGINT PRIMARY KEY DEFAULT nextval('seq_runners_id'), + id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, hostname TEXT NOT NULL, ip_address TEXT NOT NULL, status TEXT NOT NULL DEFAULT 'offline', last_heartbeat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, capabilities TEXT, - api_key_id 
BIGINT, -- Reference to the API key used for this runner - api_key_scope TEXT NOT NULL DEFAULT 'user', -- Scope of the API key ('manager' or 'user') + api_key_id INTEGER, + api_key_scope TEXT NOT NULL DEFAULT 'user', priority INTEGER NOT NULL DEFAULT 100, - fingerprint TEXT, -- Hardware fingerprint (NULL for fixed API keys) + fingerprint TEXT, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, FOREIGN KEY (api_key_id) REFERENCES runner_api_keys(id) ); CREATE TABLE IF NOT EXISTS tasks ( - id BIGINT PRIMARY KEY DEFAULT nextval('seq_tasks_id'), - job_id BIGINT NOT NULL, - runner_id BIGINT, + id INTEGER PRIMARY KEY AUTOINCREMENT, + job_id INTEGER NOT NULL, + runner_id INTEGER, frame_start INTEGER NOT NULL, frame_end INTEGER NOT NULL, status TEXT NOT NULL DEFAULT 'pending', @@ -133,44 +158,50 @@ func (db *DB) migrate() error { created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, started_at TIMESTAMP, completed_at TIMESTAMP, - error_message TEXT + error_message TEXT, + FOREIGN KEY (job_id) REFERENCES jobs(id), + FOREIGN KEY (runner_id) REFERENCES runners(id) ); CREATE TABLE IF NOT EXISTS job_files ( - id BIGINT PRIMARY KEY DEFAULT nextval('seq_job_files_id'), - job_id BIGINT NOT NULL, + id INTEGER PRIMARY KEY AUTOINCREMENT, + job_id INTEGER NOT NULL, file_type TEXT NOT NULL, file_path TEXT NOT NULL, file_name TEXT NOT NULL, - file_size BIGINT NOT NULL, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + file_size INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (job_id) REFERENCES jobs(id) ); CREATE TABLE IF NOT EXISTS manager_secrets ( - id BIGINT PRIMARY KEY DEFAULT nextval('seq_manager_secrets_id'), + id INTEGER PRIMARY KEY AUTOINCREMENT, secret TEXT UNIQUE NOT NULL, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ); CREATE TABLE IF NOT EXISTS task_logs ( - id BIGINT PRIMARY KEY DEFAULT nextval('seq_task_logs_id'), - task_id BIGINT NOT NULL, - runner_id BIGINT, + id INTEGER PRIMARY KEY AUTOINCREMENT, 
+ task_id INTEGER NOT NULL, + runner_id INTEGER, log_level TEXT NOT NULL, message TEXT NOT NULL, step_name TEXT, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (task_id) REFERENCES tasks(id), + FOREIGN KEY (runner_id) REFERENCES runners(id) ); CREATE TABLE IF NOT EXISTS task_steps ( - id BIGINT PRIMARY KEY DEFAULT nextval('seq_task_steps_id'), - task_id BIGINT NOT NULL, + id INTEGER PRIMARY KEY AUTOINCREMENT, + task_id INTEGER NOT NULL, step_name TEXT NOT NULL, status TEXT NOT NULL DEFAULT 'pending', started_at TIMESTAMP, completed_at TIMESTAMP, duration_ms INTEGER, - error_message TEXT + error_message TEXT, + FOREIGN KEY (task_id) REFERENCES tasks(id) ); CREATE INDEX IF NOT EXISTS idx_jobs_user_id ON jobs(user_id); @@ -184,6 +215,7 @@ func (db *DB) migrate() error { CREATE INDEX IF NOT EXISTS idx_job_files_job_id ON job_files(job_id); CREATE INDEX IF NOT EXISTS idx_runner_api_keys_prefix ON runner_api_keys(key_prefix); CREATE INDEX IF NOT EXISTS idx_runner_api_keys_active ON runner_api_keys(is_active); + CREATE INDEX IF NOT EXISTS idx_runner_api_keys_created_by ON runner_api_keys(created_by); CREATE INDEX IF NOT EXISTS idx_runners_api_key_id ON runners(api_key_id); CREATE INDEX IF NOT EXISTS idx_task_logs_task_id_created_at ON task_logs(task_id, created_at); CREATE INDEX IF NOT EXISTS idx_task_logs_task_id_id ON task_logs(task_id, id DESC); @@ -196,9 +228,28 @@ func (db *DB) migrate() error { value TEXT NOT NULL, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ); + + CREATE TABLE IF NOT EXISTS sessions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT UNIQUE NOT NULL, + user_id INTEGER NOT NULL, + email TEXT NOT NULL, + name TEXT NOT NULL, + is_admin INTEGER NOT NULL DEFAULT 0, + expires_at TIMESTAMP NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users(id) + ); + + CREATE INDEX IF NOT EXISTS 
idx_sessions_session_id ON sessions(session_id); + CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id); + CREATE INDEX IF NOT EXISTS idx_sessions_expires_at ON sessions(expires_at); ` - if _, err := db.Exec(schema); err != nil { + if err := db.With(func(conn *sql.DB) error { + _, err := conn.Exec(schema) + return err + }); err != nil { return fmt.Errorf("failed to create schema: %w", err) } @@ -212,19 +263,26 @@ func (db *DB) migrate() error { } for _, migration := range migrations { - // DuckDB supports IF NOT EXISTS for ALTER TABLE, so we can safely execute - if _, err := db.Exec(migration); err != nil { + if err := db.With(func(conn *sql.DB) error { + _, err := conn.Exec(migration) + return err + }); err != nil { // Log but don't fail - column might already exist or table might not exist yet // This is fine for migrations that run after schema creation - // For the file_size migration, if it fails (e.g., already BIGINT), that's fine + log.Printf("Migration warning: %v", err) } } // Initialize registration_enabled setting (default: true) if it doesn't exist var settingCount int - err := db.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount) + err := db.With(func(conn *sql.DB) error { + return conn.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount) + }) if err == nil && settingCount == 0 { - _, err = db.Exec("INSERT INTO settings (key, value) VALUES (?, ?)", "registration_enabled", "true") + err = db.With(func(conn *sql.DB) error { + _, err := conn.Exec("INSERT INTO settings (key, value) VALUES (?, ?)", "registration_enabled", "true") + return err + }) if err != nil { // Log but don't fail - setting might have been created by another process log.Printf("Note: Could not initialize registration_enabled setting: %v", err) @@ -234,7 +292,16 @@ func (db *DB) migrate() error { return nil } +// Ping checks the database connection +func (db *DB) Ping() error { 
+ db.mu.Lock() + defer db.mu.Unlock() + return db.db.Ping() +} + // Close closes the database connection func (db *DB) Close() error { - return db.DB.Close() + db.mu.Lock() + defer db.mu.Unlock() + return db.db.Close() } diff --git a/internal/logger/logger.go b/internal/logger/logger.go index 3c5db35..d2ec81e 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -1,31 +1,87 @@ package logger import ( + "fmt" "io" "log" "os" "path/filepath" "sync" - - "gopkg.in/natefinch/lumberjack.v2" ) +// Level represents log severity +type Level int + +const ( + LevelDebug Level = iota + LevelInfo + LevelWarn + LevelError +) + +var levelNames = map[Level]string{ + LevelDebug: "DEBUG", + LevelInfo: "INFO", + LevelWarn: "WARN", + LevelError: "ERROR", +} + +// ParseLevel parses a level string into a Level +func ParseLevel(s string) Level { + switch s { + case "debug", "DEBUG": + return LevelDebug + case "info", "INFO": + return LevelInfo + case "warn", "WARN", "warning", "WARNING": + return LevelWarn + case "error", "ERROR": + return LevelError + default: + return LevelInfo + } +} + var ( defaultLogger *Logger once sync.Once + currentLevel Level = LevelInfo ) -// Logger wraps the standard log.Logger with file and stdout output +// Logger wraps the standard log.Logger with optional file output and levels type Logger struct { *log.Logger fileWriter io.WriteCloser } -// Init initializes the default logger with both file and stdout output -func Init(logDir, logFileName string, maxSizeMB int, maxBackups int, maxAgeDays int) error { +// SetLevel sets the global log level +func SetLevel(level Level) { + currentLevel = level +} + +// GetLevel returns the current log level +func GetLevel() Level { + return currentLevel +} + +// InitStdout initializes the logger to only write to stdout +func InitStdout() { + once.Do(func() { + log.SetOutput(os.Stdout) + log.SetFlags(log.LstdFlags | log.Lshortfile) + defaultLogger = &Logger{ + Logger: log.Default(), + fileWriter: nil, + } 
+ }) +} + +// InitWithFile initializes the logger with both file and stdout output +// The file is truncated on each start +func InitWithFile(logPath string) error { var err error once.Do(func() { - defaultLogger, err = New(logDir, logFileName, maxSizeMB, maxBackups, maxAgeDays) + defaultLogger, err = NewWithFile(logPath) if err != nil { return } @@ -37,22 +93,19 @@ func Init(logDir, logFileName string, maxSizeMB int, maxBackups int, maxAgeDays return err } -// New creates a new logger that writes to both stdout and a log file -func New(logDir, logFileName string, maxSizeMB int, maxBackups int, maxAgeDays int) (*Logger, error) { +// NewWithFile creates a new logger that writes to both stdout and a log file +// The file is truncated on each start +func NewWithFile(logPath string) (*Logger, error) { // Ensure log directory exists + logDir := filepath.Dir(logPath) if err := os.MkdirAll(logDir, 0755); err != nil { return nil, err } - logPath := filepath.Join(logDir, logFileName) - - // Create file writer with rotation - fileWriter := &lumberjack.Logger{ - Filename: logPath, - MaxSize: maxSizeMB, // megabytes - MaxBackups: maxBackups, // number of backup files - MaxAge: maxAgeDays, // days - Compress: true, // compress old log files + // Create/truncate the log file + fileWriter, err := os.Create(logPath) + if err != nil { + return nil, err } // Create multi-writer that writes to both stdout and file @@ -80,48 +133,91 @@ func GetDefault() *Logger { return defaultLogger } -// Printf logs a formatted message -func Printf(format string, v ...interface{}) { - if defaultLogger != nil { - defaultLogger.Printf(format, v...) - } else { - log.Printf(format, v...) +// logf logs a formatted message at the given level +func logf(level Level, format string, v ...interface{}) { + if level < currentLevel { + return } + prefix := fmt.Sprintf("[%s] ", levelNames[level]) + msg := fmt.Sprintf(format, v...) 
+ log.Print(prefix + msg) } -// Print logs a message -func Print(v ...interface{}) { - if defaultLogger != nil { - defaultLogger.Print(v...) - } else { - log.Print(v...) +// logln logs a message at the given level +func logln(level Level, v ...interface{}) { + if level < currentLevel { + return } + prefix := fmt.Sprintf("[%s] ", levelNames[level]) + msg := fmt.Sprint(v...) + log.Print(prefix + msg) } -// Println logs a message with newline -func Println(v ...interface{}) { - if defaultLogger != nil { - defaultLogger.Println(v...) - } else { - log.Println(v...) - } +// Debug logs a debug message +func Debug(v ...interface{}) { + logln(LevelDebug, v...) } -// Fatal logs a message and exits +// Debugf logs a formatted debug message +func Debugf(format string, v ...interface{}) { + logf(LevelDebug, format, v...) +} + +// Info logs an info message +func Info(v ...interface{}) { + logln(LevelInfo, v...) +} + +// Infof logs a formatted info message +func Infof(format string, v ...interface{}) { + logf(LevelInfo, format, v...) +} + +// Warn logs a warning message +func Warn(v ...interface{}) { + logln(LevelWarn, v...) +} + +// Warnf logs a formatted warning message +func Warnf(format string, v ...interface{}) { + logf(LevelWarn, format, v...) +} + +// Error logs an error message +func Error(v ...interface{}) { + logln(LevelError, v...) +} + +// Errorf logs a formatted error message +func Errorf(format string, v ...interface{}) { + logf(LevelError, format, v...) +} + +// Fatal logs an error message and exits func Fatal(v ...interface{}) { - if defaultLogger != nil { - defaultLogger.Fatal(v...) - } else { - log.Fatal(v...) - } + logln(LevelError, v...) + os.Exit(1) } -// Fatalf logs a formatted message and exits +// Fatalf logs a formatted error message and exits func Fatalf(format string, v ...interface{}) { - if defaultLogger != nil { - defaultLogger.Fatalf(format, v...) - } else { - log.Fatalf(format, v...) - } + logf(LevelError, format, v...) 
+ os.Exit(1) } +// --- Backwards compatibility (maps to Info level) --- + +// Printf logs a formatted message at Info level +func Printf(format string, v ...interface{}) { + logf(LevelInfo, format, v...) +} + +// Print logs a message at Info level +func Print(v ...interface{}) { + logln(LevelInfo, v...) +} + +// Println logs a message at Info level +func Println(v ...interface{}) { + logln(LevelInfo, v...) + } diff --git a/internal/runner/client.go b/internal/runner/client.go index 2fce75b..eae4160 100644 --- a/internal/runner/client.go +++ b/internal/runner/client.go @@ -25,6 +25,7 @@ import ( "sync" "time" + "jiggablend/pkg/executils" "jiggablend/pkg/scripts" "jiggablend/pkg/types" @@ -45,19 +46,19 @@ type Client struct { stopChan chan struct{} stepStartTimes map[string]time.Time // key: "taskID:stepName" stepTimesMu sync.RWMutex - workspaceDir string // Persistent workspace directory for this runner - runningProcs sync.Map // map[int64]*exec.Cmd - tracks running processes by task ID - capabilities map[string]interface{} // Cached capabilities from initial probe (includes bools and numbers) - capabilitiesMu sync.RWMutex // Protects capabilities - hwAccelCache map[string]bool // Cached hardware acceleration detection results - hwAccelCacheMu sync.RWMutex // Protects hwAccelCache - vaapiDevices []string // Cached VAAPI device paths (all available devices) - vaapiDevicesMu sync.RWMutex // Protects vaapiDevices - allocatedDevices map[int64]string // map[taskID]device - tracks which device is allocated to which task - allocatedDevicesMu sync.RWMutex // Protects allocatedDevices - longRunningClient *http.Client // HTTP client for long-running operations (no timeout) - fingerprint string // Unique hardware fingerprint for this runner - fingerprintMu sync.RWMutex // Protects fingerprint + workspaceDir string // Persistent workspace directory for this runner + processTracker *executils.ProcessTracker // Tracks running processes for cleanup + capabilities 
map[string]interface{} // Cached capabilities from initial probe (includes bools and numbers) + capabilitiesMu sync.RWMutex // Protects capabilities + hwAccelCache map[string]bool // Cached hardware acceleration detection results + hwAccelCacheMu sync.RWMutex // Protects hwAccelCache + vaapiDevices []string // Cached VAAPI device paths (all available devices) + vaapiDevicesMu sync.RWMutex // Protects vaapiDevices + allocatedDevices map[int64]string // map[taskID]device - tracks which device is allocated to which task + allocatedDevicesMu sync.RWMutex // Protects allocatedDevices + longRunningClient *http.Client // HTTP client for long-running operations (no timeout) + fingerprint string // Unique hardware fingerprint for this runner + fingerprintMu sync.RWMutex // Protects fingerprint } // NewClient creates a new runner client @@ -70,6 +71,7 @@ func NewClient(managerURL, name, hostname string) *Client { longRunningClient: &http.Client{Timeout: 0}, // No timeout for long-running operations (context downloads, file uploads/downloads) stopChan: make(chan struct{}), stepStartTimes: make(map[string]time.Time), + processTracker: executils.NewProcessTracker(), } // Generate fingerprint immediately client.generateFingerprint() @@ -226,12 +228,6 @@ func (c *Client) probeCapabilities() map[string]interface{} { c.probeGPUCapabilities(capabilities) } else { capabilities["ffmpeg"] = false - // Set defaults when ffmpeg is not available - capabilities["vaapi"] = false - capabilities["vaapi_gpu_count"] = 0 - capabilities["nvenc"] = false - capabilities["nvenc_gpu_count"] = 0 - capabilities["video_gpu_count"] = 0 } return capabilities @@ -256,60 +252,6 @@ func (c *Client) probeGPUCapabilities(capabilities map[string]interface{}) { log.Printf("Available hardware encoders: %v", getKeys(hwEncoders)) } - // Check for VAAPI devices and count them - log.Printf("Checking for VAAPI hardware acceleration...") - - // First check if encoder is listed (more reliable than testing) - cmd := 
exec.Command("ffmpeg", "-hide_banner", "-encoders") - output, err := cmd.CombinedOutput() - hasVAAPIEncoder := false - if err == nil { - encoderOutput := string(output) - if strings.Contains(encoderOutput, "h264_vaapi") { - hasVAAPIEncoder = true - log.Printf("VAAPI encoder (h264_vaapi) found in ffmpeg encoders list") - } - } - - if hasVAAPIEncoder { - // Try to find and test devices - vaapiDevices := c.findVAAPIDevices() - capabilities["vaapi_gpu_count"] = len(vaapiDevices) - if len(vaapiDevices) > 0 { - capabilities["vaapi"] = true - log.Printf("VAAPI detected: %d GPU device(s) available: %v", len(vaapiDevices), vaapiDevices) - } else { - capabilities["vaapi"] = false - log.Printf("VAAPI encoder available but no working devices found") - log.Printf(" This might indicate:") - log.Printf(" - Missing or incorrect GPU drivers") - log.Printf(" - Missing libva or mesa-va-drivers packages") - log.Printf(" - Permission issues accessing /dev/dri devices") - log.Printf(" - GPU not properly initialized") - } - } else { - capabilities["vaapi"] = false - capabilities["vaapi_gpu_count"] = 0 - log.Printf("VAAPI encoder not available in ffmpeg") - log.Printf(" This might indicate:") - log.Printf(" - FFmpeg was not compiled with VAAPI support") - log.Printf(" - Missing libva development libraries during FFmpeg compilation") - } - - // Check for NVENC (NVIDIA) - try to detect multiple GPUs - log.Printf("Checking for NVENC hardware acceleration...") - if c.checkEncoderAvailable("h264_nvenc") { - capabilities["nvenc"] = true - // Try to detect actual GPU count using nvidia-smi if available - nvencCount := c.detectNVENCCount() - capabilities["nvenc_gpu_count"] = nvencCount - log.Printf("NVENC detected: %d GPU(s)", nvencCount) - } else { - capabilities["nvenc"] = false - capabilities["nvenc_gpu_count"] = 0 - log.Printf("NVENC encoder not available") - } - // Check for other hardware encoders (for completeness) log.Printf("Checking for other hardware encoders...") if 
c.checkEncoderAvailable("h264_qsv") { @@ -368,73 +310,6 @@ func (c *Client) probeGPUCapabilities(capabilities map[string]interface{}) { capabilities["mediacodec"] = false capabilities["mediacodec_gpu_count"] = 0 } - - // Calculate total GPU count for video encoding - // Priority: VAAPI > NVENC > QSV > VideoToolbox > AMF > others - vaapiCount := 0 - if count, ok := capabilities["vaapi_gpu_count"].(int); ok { - vaapiCount = count - } - nvencCount := 0 - if count, ok := capabilities["nvenc_gpu_count"].(int); ok { - nvencCount = count - } - qsvCount := 0 - if count, ok := capabilities["qsv_gpu_count"].(int); ok { - qsvCount = count - } - videotoolboxCount := 0 - if count, ok := capabilities["videotoolbox_gpu_count"].(int); ok { - videotoolboxCount = count - } - amfCount := 0 - if count, ok := capabilities["amf_gpu_count"].(int); ok { - amfCount = count - } - - // Total GPU count - use the best available (they can't be used simultaneously) - totalGPUs := vaapiCount - if totalGPUs == 0 { - totalGPUs = nvencCount - } - if totalGPUs == 0 { - totalGPUs = qsvCount - } - if totalGPUs == 0 { - totalGPUs = videotoolboxCount - } - if totalGPUs == 0 { - totalGPUs = amfCount - } - capabilities["video_gpu_count"] = totalGPUs - - if totalGPUs > 0 { - log.Printf("Total video GPU count: %d", totalGPUs) - } else { - log.Printf("No hardware-accelerated video encoding GPUs detected (will use software encoding)") - } -} - -// detectNVENCCount tries to detect the actual number of NVIDIA GPUs using nvidia-smi -func (c *Client) detectNVENCCount() int { - // Try to use nvidia-smi to count GPUs - cmd := exec.Command("nvidia-smi", "--list-gpus") - output, err := cmd.CombinedOutput() - if err == nil { - // Count lines that contain "GPU" (each GPU is listed on a separate line) - lines := strings.Split(string(output), "\n") - count := 0 - for _, line := range lines { - if strings.Contains(line, "GPU") { - count++ - } - } - if count > 0 { - return count - } - } - // Fallback to 1 if nvidia-smi is 
not available - return 1 } // getKeys returns all keys from a map as a slice (helper function) @@ -926,29 +801,13 @@ func (c *Client) sendLog(taskID int64, logLevel types.LogLevel, message, stepNam // KillAllProcesses kills all running processes tracked by this client func (c *Client) KillAllProcesses() { log.Printf("Killing all running processes...") - var killedCount int - c.runningProcs.Range(func(key, value interface{}) bool { - taskID := key.(int64) - cmd := value.(*exec.Cmd) - if cmd.Process != nil { - log.Printf("Killing process for task %d (PID: %d)", taskID, cmd.Process.Pid) - // Try graceful kill first (SIGTERM) - if err := cmd.Process.Signal(os.Interrupt); err != nil { - log.Printf("Failed to send SIGINT to process %d: %v", cmd.Process.Pid, err) - } - // Give it a moment to clean up - time.Sleep(100 * time.Millisecond) - // Force kill if still running - if err := cmd.Process.Kill(); err != nil { - log.Printf("Failed to kill process %d: %v", cmd.Process.Pid, err) - } else { - killedCount++ - } - } - // Release any allocated device for this task - c.releaseVAAPIDevice(taskID) - return true - }) + killedCount := c.processTracker.KillAll() + // Release all allocated VAAPI devices + c.allocatedDevicesMu.Lock() + for taskID := range c.allocatedDevices { + delete(c.allocatedDevices, taskID) + } + c.allocatedDevicesMu.Unlock() log.Printf("Killed %d process(es)", killedCount) } @@ -1272,55 +1131,33 @@ func (c *Client) processTask(task map[string]interface{}, jobName string, output // Run Blender with GPU enabled via Python script // Use -s (start) and -e (end) for frame ranges, or -f for single frame + // Use Blender's automatic frame numbering with #### pattern var cmd *exec.Cmd args := []string{"-b", blendFile, "--python", scriptPath} if enableExecution { args = append(args, "--enable-autoexec") } - // Always render frames individually for precise control over file naming - // This avoids Blender's automatic frame numbering quirks - for frame := frameStart; 
frame <= frameEnd; frame++ { - // Create temp output pattern for this frame - tempPattern := filepath.Join(outputDir, fmt.Sprintf("temp_frame.%s", strings.ToLower(renderFormat))) - tempAbsPattern, _ := filepath.Abs(tempPattern) - // Build args for this specific frame - frameArgs := []string{"-b", blendFile, "--python", scriptPath} - if enableExecution { - frameArgs = append(frameArgs, "--enable-autoexec") - } - frameArgs = append(frameArgs, "-o", tempAbsPattern, "-f", fmt.Sprintf("%d", frame)) + // Output pattern uses #### which Blender will replace with frame numbers + outputPattern := filepath.Join(outputDir, fmt.Sprintf("frame_####.%s", strings.ToLower(renderFormat))) + outputAbsPattern, _ := filepath.Abs(outputPattern) + args = append(args, "-o", outputAbsPattern) - c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Rendering frame %d...", frame), "render_blender") - - frameCmd := exec.Command("blender", frameArgs...) - frameCmd.Dir = workDir - frameCmd.Env = os.Environ() - - // Run this frame - if output, err := frameCmd.CombinedOutput(); err != nil { - errMsg := fmt.Sprintf("blender failed on frame %d: %v (output: %s)", frame, err, string(output)) - c.sendLog(taskID, types.LogLevelError, errMsg, "render_blender") - return errors.New(errMsg) - } - - // Immediately rename the temp file to the proper frame-numbered name - finalName := fmt.Sprintf("frame_%04d.%s", frame, strings.ToLower(renderFormat)) - finalPath := filepath.Join(outputDir, finalName) - tempPath := filepath.Join(outputDir, fmt.Sprintf("temp_frame.%s", strings.ToLower(renderFormat))) - - if err := os.Rename(tempPath, finalPath); err != nil { - errMsg := fmt.Sprintf("failed to rename temp file for frame %d: %v", frame, err) - c.sendLog(taskID, types.LogLevelError, errMsg, "render_blender") - return errors.New(errMsg) - } - - c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Completed frame %d -> %s", frame, finalName), "render_blender") + if frameStart == frameEnd { + // Single frame + args = 
append(args, "-f", fmt.Sprintf("%d", frameStart)) + c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Rendering frame %d...", frameStart), "render_blender") + } else { + // Frame range + args = append(args, "-s", fmt.Sprintf("%d", frameStart), "-e", fmt.Sprintf("%d", frameEnd)) + c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Rendering frames %d-%d...", frameStart, frameEnd), "render_blender") } - // Skip the rest of the function since we handled all frames above - c.sendStepUpdate(taskID, "render_blender", types.StepStatusCompleted, "") - return nil + // Create and run Blender command + cmd = exec.Command("blender", args...) + cmd.Dir = workDir + cmd.Env = os.Environ() + // Blender will handle headless rendering automatically // We preserve the environment to allow GPU access if available @@ -1350,8 +1187,8 @@ func (c *Client) processTask(task map[string]interface{}, jobName string, output } // Register process for cleanup on shutdown - c.runningProcs.Store(taskID, cmd) - defer c.runningProcs.Delete(taskID) + c.processTracker.Track(taskID, cmd) + defer c.processTracker.Untrack(taskID) // Stream stdout line by line stdoutDone := make(chan bool) @@ -1396,15 +1233,23 @@ func (c *Client) processTask(task map[string]interface{}, jobName string, output <-stdoutDone <-stderrDone if err != nil { - errMsg := fmt.Sprintf("blender failed: %v", err) + var errMsg string + if exitErr, ok := err.(*exec.ExitError); ok { + if exitErr.ExitCode() == 137 { + errMsg = "Blender was killed due to excessive memory usage (OOM)" + } else { + errMsg = fmt.Sprintf("blender failed: %v", err) + } + } else { + errMsg = fmt.Sprintf("blender failed: %v", err) + } c.sendLog(taskID, types.LogLevelError, errMsg, "render_blender") c.sendStepUpdate(taskID, "render_blender", types.StepStatusFailed, errMsg) return errors.New(errMsg) } - // For frame ranges, we rendered each frame individually with temp naming - // The files are already properly named during the individual frame rendering - // No 
additional renaming needed + // Blender has rendered frames with automatic numbering using the #### pattern + // Files will be named like frame_0001.png, frame_0002.png, etc. // Find rendered output file(s) // For frame ranges, we'll find all frames in the upload step @@ -1748,8 +1593,8 @@ func (c *Client) processVideoGenerationTask(task map[string]interface{}, jobID i // Extract frame number pattern (e.g., frame_2470.exr -> frame_%04d.exr) baseName := filepath.Base(firstFrame) // Find the numeric part and replace it with %04d pattern - // Use regex to find digits (including negative) after underscore and before extension - re := regexp.MustCompile(`_(-?\d+)\.`) + // Use regex to find digits (positive only, negative frames not supported) after underscore and before extension + re := regexp.MustCompile(`_(\d+)\.`) var pattern string var startNumber int frameNumStr := re.FindStringSubmatch(baseName) @@ -1763,6 +1608,7 @@ func (c *Client) processVideoGenerationTask(task map[string]interface{}, jobID i startNumber = extractFrameNumber(baseName) pattern = strings.Replace(baseName, fmt.Sprintf("%d", startNumber), "%04d", 1) } + // Pattern path should be in workDir where the frame files are downloaded patternPath := filepath.Join(workDir, pattern) // Allocate a VAAPI device for this task (if available) @@ -1891,8 +1737,8 @@ func (c *Client) processVideoGenerationTask(task map[string]interface{}, jobID i } // Register process for cleanup on shutdown - c.runningProcs.Store(taskID, cmd) - defer c.runningProcs.Delete(taskID) + c.processTracker.Track(taskID, cmd) + defer c.processTracker.Untrack(taskID) // Stream stdout line by line stdoutDone := make(chan bool) @@ -1959,26 +1805,25 @@ func (c *Client) processVideoGenerationTask(task map[string]interface{}, jobID i <-stderrDone if err != nil { + var errMsg string + if exitErr, ok := err.(*exec.ExitError); ok { + if exitErr.ExitCode() == 137 { + errMsg = "FFmpeg was killed due to excessive memory usage (OOM)" + } else { + 
errMsg = fmt.Sprintf("ffmpeg encoding failed: %v", err) + } + } else { + errMsg = fmt.Sprintf("ffmpeg encoding failed: %v", err) + } // Check for size-related errors and provide helpful messages - if sizeErr := c.checkFFmpegSizeError("ffmpeg encoding failed"); sizeErr != nil { + if sizeErr := c.checkFFmpegSizeError(errMsg); sizeErr != nil { c.sendLog(taskID, types.LogLevelError, sizeErr.Error(), "generate_video") c.sendStepUpdate(taskID, "generate_video", types.StepStatusFailed, sizeErr.Error()) return sizeErr } - - // Try alternative method with concat demuxer - c.sendLog(taskID, types.LogLevelWarn, "Primary ffmpeg encoding failed, trying concat method...", "generate_video") - err = c.generateMP4WithConcat(frameFiles, outputMP4, workDir, allocatedDevice, outputFormat, codec, pixFmt, useAlpha, useHardware, frameRate) - if err != nil { - // Check for size errors in concat method too - if sizeErr := c.checkFFmpegSizeError(err.Error()); sizeErr != nil { - c.sendLog(taskID, types.LogLevelError, sizeErr.Error(), "generate_video") - c.sendStepUpdate(taskID, "generate_video", types.StepStatusFailed, sizeErr.Error()) - return sizeErr - } - c.sendStepUpdate(taskID, "generate_video", types.StepStatusFailed, err.Error()) - return err - } + c.sendLog(taskID, types.LogLevelError, errMsg, "generate_video") + c.sendStepUpdate(taskID, "generate_video", types.StepStatusFailed, errMsg) + return errors.New(errMsg) } // Check if MP4 was created @@ -2771,7 +2616,7 @@ func (c *Client) testGenericEncoder(encoder string) bool { // generateMP4WithConcat uses ffmpeg concat demuxer as fallback // device parameter is optional - if provided, it will be used for VAAPI encoding -func (c *Client) generateMP4WithConcat(frameFiles []string, outputMP4, workDir string, device string, outputFormat string, codec string, pixFmt string, useAlpha bool, useHardware bool, frameRate float64) error { +func (c *Client) generateMP4WithConcat(taskID int, frameFiles []string, outputMP4, workDir string, device 
string, outputFormat string, codec string, pixFmt string, useAlpha bool, useHardware bool, frameRate float64) error { // Create file list for ffmpeg concat demuxer listFile := filepath.Join(workDir, "frames.txt") listFileHandle, err := os.Create(listFile) @@ -2907,11 +2752,23 @@ func (c *Client) generateMP4WithConcat(frameFiles []string, outputMP4, workDir s <-stderrDone if err != nil { + var errMsg string + if exitErr, ok := err.(*exec.ExitError); ok { + if exitErr.ExitCode() == 137 { + errMsg = "FFmpeg was killed due to excessive memory usage (OOM)" + } else { + errMsg = fmt.Sprintf("ffmpeg concat failed: %v", err) + } + } else { + errMsg = fmt.Sprintf("ffmpeg concat failed: %v", err) + } // Check for size-related errors - if sizeErr := c.checkFFmpegSizeError("ffmpeg concat failed"); sizeErr != nil { + if sizeErr := c.checkFFmpegSizeError(errMsg); sizeErr != nil { return sizeErr } - return fmt.Errorf("ffmpeg concat failed: %w", err) + c.sendLog(int64(taskID), types.LogLevelError, errMsg, "generate_video") + c.sendStepUpdate(int64(taskID), "generate_video", types.StepStatusFailed, errMsg) + return errors.New(errMsg) } if _, err := os.Stat(outputMP4); os.IsNotExist(err) { @@ -3695,8 +3552,8 @@ sys.stdout.flush() } // Register process for cleanup on shutdown - c.runningProcs.Store(taskID, cmd) - defer c.runningProcs.Delete(taskID) + c.processTracker.Track(taskID, cmd) + defer c.processTracker.Untrack(taskID) // Stream stdout line by line and collect for JSON parsing stdoutDone := make(chan bool) @@ -3743,7 +3600,16 @@ sys.stdout.flush() <-stdoutDone <-stderrDone if err != nil { - errMsg := fmt.Sprintf("blender metadata extraction failed: %v", err) + var errMsg string + if exitErr, ok := err.(*exec.ExitError); ok { + if exitErr.ExitCode() == 137 { + errMsg = "Blender metadata extraction was killed due to excessive memory usage (OOM)" + } else { + errMsg = fmt.Sprintf("blender metadata extraction failed: %v", err) + } + } else { + errMsg = fmt.Sprintf("blender 
metadata extraction failed: %v", err) + } c.sendLog(taskID, types.LogLevelError, errMsg, "extract_metadata") c.sendStepUpdate(taskID, "extract_metadata", types.StepStatusFailed, errMsg) return errors.New(errMsg) diff --git a/internal/storage/storage.go b/internal/storage/storage.go index 1b7fee7..464e3fe 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -378,43 +378,47 @@ func (s *Storage) CreateJobContext(jobID int64) (string, error) { return "", fmt.Errorf("failed to open file %s: %w", filePath, err) } - info, err := file.Stat() + // Use a function closure to ensure file is closed even on error + err = func() error { + defer file.Close() + + info, err := file.Stat() + if err != nil { + return fmt.Errorf("failed to stat file %s: %w", filePath, err) + } + + // Get relative path for tar header + relPath, err := filepath.Rel(jobPath, filePath) + if err != nil { + return fmt.Errorf("failed to get relative path for %s: %w", filePath, err) + } + + // Normalize path separators for tar (use forward slashes) + tarPath := filepath.ToSlash(relPath) + + // Create tar header + header, err := tar.FileInfoHeader(info, "") + if err != nil { + return fmt.Errorf("failed to create tar header for %s: %w", filePath, err) + } + header.Name = tarPath + + // Write header + if err := tarWriter.WriteHeader(header); err != nil { + return fmt.Errorf("failed to write tar header for %s: %w", filePath, err) + } + + // Copy file contents using streaming + if _, err := io.Copy(tarWriter, file); err != nil { + return fmt.Errorf("failed to write file %s to tar: %w", filePath, err) + } + + return nil + }() + if err != nil { - file.Close() - return "", fmt.Errorf("failed to stat file %s: %w", filePath, err) + return "", err } - - // Get relative path for tar header - relPath, err := filepath.Rel(jobPath, filePath) - if err != nil { - file.Close() - return "", fmt.Errorf("failed to get relative path for %s: %w", filePath, err) - } - - // Normalize path separators for tar 
(use forward slashes) - tarPath := filepath.ToSlash(relPath) - - // Create tar header - header, err := tar.FileInfoHeader(info, "") - if err != nil { - file.Close() - return "", fmt.Errorf("failed to create tar header for %s: %w", filePath, err) - } - header.Name = tarPath - - // Write header - if err := tarWriter.WriteHeader(header); err != nil { - file.Close() - return "", fmt.Errorf("failed to write tar header for %s: %w", filePath, err) - } - - // Copy file contents using streaming - if _, err := io.Copy(tarWriter, file); err != nil { - file.Close() - return "", fmt.Errorf("failed to write file %s to tar: %w", filePath, err) - } - - file.Close() } // Ensure all data is flushed @@ -550,42 +554,47 @@ func (s *Storage) CreateJobContextFromDir(sourceDir string, jobID int64, exclude return "", fmt.Errorf("failed to open file %s: %w", filePath, err) } - info, err := file.Stat() + // Use a function closure to ensure file is closed even on error + err = func() error { + defer file.Close() + + info, err := file.Stat() + if err != nil { + return fmt.Errorf("failed to stat file %s: %w", filePath, err) + } + + // Get relative path and strip common prefix if present + relPath := relPaths[i] + tarPath := filepath.ToSlash(relPath) + + // Strip common prefix if found + if commonPrefix != "" && strings.HasPrefix(tarPath, commonPrefix) { + tarPath = strings.TrimPrefix(tarPath, commonPrefix) + } + + // Create tar header + header, err := tar.FileInfoHeader(info, "") + if err != nil { + return fmt.Errorf("failed to create tar header for %s: %w", filePath, err) + } + header.Name = tarPath + + // Write header + if err := tarWriter.WriteHeader(header); err != nil { + return fmt.Errorf("failed to write tar header for %s: %w", filePath, err) + } + + // Copy file contents using streaming + if _, err := io.Copy(tarWriter, file); err != nil { + return fmt.Errorf("failed to write file %s to tar: %w", filePath, err) + } + + return nil + }() + if err != nil { - file.Close() - return "", 
fmt.Errorf("failed to stat file %s: %w", filePath, err) + return "", err } - - // Get relative path and strip common prefix if present - relPath := relPaths[i] - tarPath := filepath.ToSlash(relPath) - - // Strip common prefix if found - if commonPrefix != "" && strings.HasPrefix(tarPath, commonPrefix) { - tarPath = strings.TrimPrefix(tarPath, commonPrefix) - } - - // Create tar header - header, err := tar.FileInfoHeader(info, "") - if err != nil { - file.Close() - return "", fmt.Errorf("failed to create tar header for %s: %w", filePath, err) - } - header.Name = tarPath - - // Write header - if err := tarWriter.WriteHeader(header); err != nil { - file.Close() - return "", fmt.Errorf("failed to write tar header for %s: %w", filePath, err) - } - - // Copy file contents using streaming - if _, err := io.Copy(tarWriter, file); err != nil { - file.Close() - return "", fmt.Errorf("failed to write file %s to tar: %w", filePath, err) - } - - file.Close() } // Ensure all data is flushed diff --git a/jiggablend b/jiggablend new file mode 100755 index 0000000..1598af5 Binary files /dev/null and b/jiggablend differ diff --git a/pkg/executils/exec.go b/pkg/executils/exec.go new file mode 100644 index 0000000..3e4c18d --- /dev/null +++ b/pkg/executils/exec.go @@ -0,0 +1,366 @@ +package executils + +import ( + "bufio" + "errors" + "fmt" + "os" + "os/exec" + "sync" + "time" + + "jiggablend/pkg/types" +) + +// DefaultTracker is the global default process tracker +// Use this for processes that should be tracked globally and killed on shutdown +var DefaultTracker = NewProcessTracker() + +// ProcessTracker tracks running processes for cleanup +type ProcessTracker struct { + processes sync.Map // map[int64]*exec.Cmd - tracks running processes by task ID +} + +// NewProcessTracker creates a new process tracker +func NewProcessTracker() *ProcessTracker { + return &ProcessTracker{} +} + +// Track registers a process for tracking +func (pt *ProcessTracker) Track(taskID int64, cmd 
*exec.Cmd) { + pt.processes.Store(taskID, cmd) +} + +// Untrack removes a process from tracking +func (pt *ProcessTracker) Untrack(taskID int64) { + pt.processes.Delete(taskID) +} + +// Get returns the command for a task ID if it exists +func (pt *ProcessTracker) Get(taskID int64) (*exec.Cmd, bool) { + if val, ok := pt.processes.Load(taskID); ok { + return val.(*exec.Cmd), true + } + return nil, false +} + +// Kill kills a specific process by task ID +// Returns true if the process was found and killed +func (pt *ProcessTracker) Kill(taskID int64) bool { + cmd, ok := pt.Get(taskID) + if !ok || cmd.Process == nil { + return false + } + + // Try graceful kill first (SIGINT) + if err := cmd.Process.Signal(os.Interrupt); err != nil { + // If SIGINT fails, try SIGKILL + cmd.Process.Kill() + } else { + // Give it a moment to clean up gracefully + time.Sleep(100 * time.Millisecond) + // Force kill if still running + cmd.Process.Kill() + } + + pt.Untrack(taskID) + return true +} + +// KillAll kills all tracked processes +// Returns the number of processes killed +func (pt *ProcessTracker) KillAll() int { + var killedCount int + pt.processes.Range(func(key, value interface{}) bool { + taskID := key.(int64) + cmd := value.(*exec.Cmd) + if cmd.Process != nil { + // Try graceful kill first (SIGINT) + if err := cmd.Process.Signal(os.Interrupt); err == nil { + // Give it a moment to clean up + time.Sleep(100 * time.Millisecond) + } + // Force kill + cmd.Process.Kill() + killedCount++ + } + pt.processes.Delete(taskID) + return true + }) + return killedCount +} + +// Count returns the number of tracked processes +func (pt *ProcessTracker) Count() int { + count := 0 + pt.processes.Range(func(key, value interface{}) bool { + count++ + return true + }) + return count +} + +// CommandResult holds the output from a command execution +type CommandResult struct { + Stdout string + Stderr string + ExitCode int +} + +// RunCommand executes a command and returns the output +// If tracker is 
provided, the process will be registered for tracking +// This is useful for commands where you need to capture output (like metadata extraction) +func RunCommand( + cmdPath string, + args []string, + dir string, + env []string, + taskID int64, + tracker *ProcessTracker, +) (*CommandResult, error) { + cmd := exec.Command(cmdPath, args...) + cmd.Dir = dir + if env != nil { + cmd.Env = env + } + + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("failed to create stdout pipe: %w", err) + } + + stderrPipe, err := cmd.StderrPipe() + if err != nil { + return nil, fmt.Errorf("failed to create stderr pipe: %w", err) + } + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("failed to start command: %w", err) + } + + // Track the process if tracker is provided + if tracker != nil { + tracker.Track(taskID, cmd) + defer tracker.Untrack(taskID) + } + + // Collect stdout + var stdoutBuf, stderrBuf []byte + var stdoutErr, stderrErr error + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + stdoutBuf, stdoutErr = readAll(stdoutPipe) + }() + + go func() { + defer wg.Done() + stderrBuf, stderrErr = readAll(stderrPipe) + }() + + waitErr := cmd.Wait() + wg.Wait() + + // Check for read errors + if stdoutErr != nil { + return nil, fmt.Errorf("failed to read stdout: %w", stdoutErr) + } + if stderrErr != nil { + return nil, fmt.Errorf("failed to read stderr: %w", stderrErr) + } + + result := &CommandResult{ + Stdout: string(stdoutBuf), + Stderr: string(stderrBuf), + } + + if waitErr != nil { + if exitErr, ok := waitErr.(*exec.ExitError); ok { + result.ExitCode = exitErr.ExitCode() + } else { + result.ExitCode = -1 + } + return result, waitErr + } + + result.ExitCode = 0 + return result, nil +} + +// readAll reads all data from a reader +func readAll(r interface{ Read([]byte) (int, error) }) ([]byte, error) { + var buf []byte + tmp := make([]byte, 4096) + for { + n, err := r.Read(tmp) + if n > 0 { + buf = append(buf, tmp[:n]...) 
+ } + if err != nil { + if err.Error() == "EOF" { + break + } + return buf, err + } + } + return buf, nil +} + +// LogSender is a function type for sending logs +type LogSender func(taskID int, level types.LogLevel, message string, stepName string) + +// LineFilter is a function that processes a line and returns whether to filter it out and the log level +type LineFilter func(line string) (shouldFilter bool, level types.LogLevel) + +// RunCommandWithStreaming executes a command with streaming output and OOM detection +// If tracker is provided, the process will be registered for tracking +func RunCommandWithStreaming( + cmdPath string, + args []string, + dir string, + env []string, + taskID int, + stepName string, + logSender LogSender, + stdoutFilter LineFilter, + stderrFilter LineFilter, + oomMessage string, + tracker *ProcessTracker, +) error { + cmd := exec.Command(cmdPath, args...) + cmd.Dir = dir + cmd.Env = env + + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + errMsg := fmt.Sprintf("failed to create stdout pipe: %v", err) + logSender(taskID, types.LogLevelError, errMsg, stepName) + return errors.New(errMsg) + } + + stderrPipe, err := cmd.StderrPipe() + if err != nil { + errMsg := fmt.Sprintf("failed to create stderr pipe: %v", err) + logSender(taskID, types.LogLevelError, errMsg, stepName) + return errors.New(errMsg) + } + + if err := cmd.Start(); err != nil { + errMsg := fmt.Sprintf("failed to start command: %v", err) + logSender(taskID, types.LogLevelError, errMsg, stepName) + return errors.New(errMsg) + } + + // Track the process if tracker is provided + if tracker != nil { + tracker.Track(int64(taskID), cmd) + defer tracker.Untrack(int64(taskID)) + } + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + scanner := bufio.NewScanner(stdoutPipe) + for scanner.Scan() { + line := scanner.Text() + if line != "" { + shouldFilter, level := stdoutFilter(line) + if !shouldFilter { + logSender(taskID, level, line, stepName) + } + } + } 
+ }() + + go func() { + defer wg.Done() + scanner := bufio.NewScanner(stderrPipe) + for scanner.Scan() { + line := scanner.Text() + if line != "" { + shouldFilter, level := stderrFilter(line) + if !shouldFilter { + logSender(taskID, level, line, stepName) + } + } + } + }() + + err = cmd.Wait() + wg.Wait() + + if err != nil { + var errMsg string + if exitErr, ok := err.(*exec.ExitError); ok { + if exitErr.ExitCode() == 137 { + errMsg = oomMessage + } else { + errMsg = fmt.Sprintf("command failed: %v", err) + } + } else { + errMsg = fmt.Sprintf("command failed: %v", err) + } + logSender(taskID, types.LogLevelError, errMsg, stepName) + return errors.New(errMsg) + } + + return nil +} + +// ============================================================================ +// Helper functions using DefaultTracker +// ============================================================================ + +// Run executes a command using the default tracker and returns the output +// This is a convenience wrapper around RunCommand that uses DefaultTracker +func Run(cmdPath string, args []string, dir string, env []string, taskID int64) (*CommandResult, error) { + return RunCommand(cmdPath, args, dir, env, taskID, DefaultTracker) +} + +// RunStreaming executes a command with streaming output using the default tracker +// This is a convenience wrapper around RunCommandWithStreaming that uses DefaultTracker +func RunStreaming( + cmdPath string, + args []string, + dir string, + env []string, + taskID int, + stepName string, + logSender LogSender, + stdoutFilter LineFilter, + stderrFilter LineFilter, + oomMessage string, +) error { + return RunCommandWithStreaming(cmdPath, args, dir, env, taskID, stepName, logSender, stdoutFilter, stderrFilter, oomMessage, DefaultTracker) +} + +// KillAll kills all processes tracked by the default tracker +// Returns the number of processes killed +func KillAll() int { + return DefaultTracker.KillAll() +} + +// Kill kills a specific process by task ID using 
the default tracker +// Returns true if the process was found and killed +func Kill(taskID int64) bool { + return DefaultTracker.Kill(taskID) +} + +// Track registers a process with the default tracker +func Track(taskID int64, cmd *exec.Cmd) { + DefaultTracker.Track(taskID, cmd) +} + +// Untrack removes a process from the default tracker +func Untrack(taskID int64) { + DefaultTracker.Untrack(taskID) +} + +// GetTrackedCount returns the number of processes tracked by the default tracker +func GetTrackedCount() int { + return DefaultTracker.Count() +} diff --git a/pkg/scripts/scripts/extract_metadata.py b/pkg/scripts/scripts/extract_metadata.py index d04775b..c25723d 100644 --- a/pkg/scripts/scripts/extract_metadata.py +++ b/pkg/scripts/scripts/extract_metadata.py @@ -46,6 +46,10 @@ scene = bpy.context.scene frame_start = scene.frame_start frame_end = scene.frame_end +# Check for negative frames (not supported) +has_negative_start = frame_start < 0 +has_negative_end = frame_end < 0 + # Also check for actual animation range (keyframes) # Find the earliest and latest keyframes across all objects animation_start = None @@ -54,15 +58,21 @@ animation_end = None for obj in scene.objects: if obj.animation_data and obj.animation_data.action: action = obj.animation_data.action - if action.fcurves: - for fcurve in action.fcurves: - if fcurve.keyframe_points: - for keyframe in fcurve.keyframe_points: - frame = int(keyframe.co[0]) - if animation_start is None or frame < animation_start: - animation_start = frame - if animation_end is None or frame > animation_end: - animation_end = frame + # Check if action has fcurves attribute (varies by Blender version/context) + try: + fcurves = action.fcurves if hasattr(action, 'fcurves') else None + if fcurves: + for fcurve in fcurves: + if fcurve.keyframe_points: + for keyframe in fcurve.keyframe_points: + frame = int(keyframe.co[0]) + if animation_start is None or frame < animation_start: + animation_start = frame + if animation_end is 
None or frame > animation_end: + animation_end = frame + except (AttributeError, TypeError) as e: + # Action doesn't have fcurves or fcurves is not iterable - skip this object + pass # Use animation range if available, otherwise use scene frame range # If scene range seems wrong (start == end), prefer animation range @@ -72,6 +82,11 @@ if animation_start is not None and animation_end is not None: frame_start = animation_start frame_end = animation_end +# Check for negative frames (not supported) +has_negative_start = frame_start < 0 +has_negative_end = frame_end < 0 +has_negative_animation = (animation_start is not None and animation_start < 0) or (animation_end is not None and animation_end < 0) + # Extract render settings render = scene.render resolution_x = render.resolution_x @@ -87,56 +102,230 @@ engine_settings = {} if engine == 'CYCLES': cycles = scene.cycles + # Get denoiser settings - in Blender 3.0+ it's on the view layer + denoiser = 'OPENIMAGEDENOISE' # Default + denoising_use_gpu = False + denoising_input_passes = 'RGB_ALBEDO_NORMAL' # Default: Albedo and Normal + denoising_prefilter = 'ACCURATE' # Default + denoising_quality = 'HIGH' # Default (for OpenImageDenoise) + try: + view_layer = bpy.context.view_layer + if hasattr(view_layer, 'cycles'): + vl_cycles = view_layer.cycles + denoiser = getattr(vl_cycles, 'denoiser', 'OPENIMAGEDENOISE') + denoising_use_gpu = getattr(vl_cycles, 'denoising_use_gpu', False) + denoising_input_passes = getattr(vl_cycles, 'denoising_input_passes', 'RGB_ALBEDO_NORMAL') + denoising_prefilter = getattr(vl_cycles, 'denoising_prefilter', 'ACCURATE') + # Quality is only for OpenImageDenoise in Blender 4.0+ + denoising_quality = getattr(vl_cycles, 'denoising_quality', 'HIGH') + except: + pass + engine_settings = { - "samples": getattr(cycles, 'samples', 128), + # Sampling settings + "samples": getattr(cycles, 'samples', 4096), # Max Samples + "adaptive_min_samples": getattr(cycles, 'adaptive_min_samples', 0), # Min Samples + 
"use_adaptive_sampling": getattr(cycles, 'use_adaptive_sampling', True), # Noise Threshold enabled + "adaptive_threshold": getattr(cycles, 'adaptive_threshold', 0.01), # Noise Threshold value + "time_limit": getattr(cycles, 'time_limit', 0.0), # Time Limit (0 = disabled) + + # Denoising settings "use_denoising": getattr(cycles, 'use_denoising', False), - "denoising_radius": getattr(cycles, 'denoising_radius', 0), - "denoising_strength": getattr(cycles, 'denoising_strength', 0.0), + "denoiser": denoiser, + "denoising_use_gpu": denoising_use_gpu, + "denoising_input_passes": denoising_input_passes, + "denoising_prefilter": denoising_prefilter, + "denoising_quality": denoising_quality, + + # Path Guiding settings + "use_guiding": getattr(cycles, 'use_guiding', False), + "guiding_training_samples": getattr(cycles, 'guiding_training_samples', 128), + "use_surface_guiding": getattr(cycles, 'use_surface_guiding', True), + "use_volume_guiding": getattr(cycles, 'use_volume_guiding', True), + + # Lights settings + "use_light_tree": getattr(cycles, 'use_light_tree', True), + "light_sampling_threshold": getattr(cycles, 'light_sampling_threshold', 0.01), + + # Device "device": getattr(cycles, 'device', 'CPU'), - "use_adaptive_sampling": getattr(cycles, 'use_adaptive_sampling', False), - "adaptive_threshold": getattr(cycles, 'adaptive_threshold', 0.01) if getattr(cycles, 'use_adaptive_sampling', False) else 0.01, - "use_fast_gi": getattr(cycles, 'use_fast_gi', False), - "light_tree": getattr(cycles, 'use_light_tree', False), - "use_light_linking": getattr(cycles, 'use_light_linking', False), - "caustics_reflective": getattr(cycles, 'caustics_reflective', False), - "caustics_refractive": getattr(cycles, 'caustics_refractive', False), - "blur_glossy": getattr(cycles, 'blur_glossy', 0.0), + + # Advanced/Seed settings + "seed": getattr(cycles, 'seed', 0), + "use_animated_seed": getattr(cycles, 'use_animated_seed', False), + "sampling_pattern": getattr(cycles, 'sampling_pattern', 
'AUTOMATIC'), + "scrambling_distance": getattr(cycles, 'scrambling_distance', 1.0), + "auto_scrambling_distance_multiplier": getattr(cycles, 'auto_scrambling_distance_multiplier', 1.0), + "preview_scrambling_distance": getattr(cycles, 'preview_scrambling_distance', False), + "min_light_bounces": getattr(cycles, 'min_light_bounces', 0), + "min_transparent_bounces": getattr(cycles, 'min_transparent_bounces', 0), + + # Clamping + "sample_clamp_direct": getattr(cycles, 'sample_clamp_direct', 0.0), + "sample_clamp_indirect": getattr(cycles, 'sample_clamp_indirect', 0.0), + + # Light Paths / Bounces "max_bounces": getattr(cycles, 'max_bounces', 12), "diffuse_bounces": getattr(cycles, 'diffuse_bounces', 4), "glossy_bounces": getattr(cycles, 'glossy_bounces', 4), "transmission_bounces": getattr(cycles, 'transmission_bounces', 12), "volume_bounces": getattr(cycles, 'volume_bounces', 0), "transparent_max_bounces": getattr(cycles, 'transparent_max_bounces', 8), + + # Caustics + "caustics_reflective": getattr(cycles, 'caustics_reflective', False), + "caustics_refractive": getattr(cycles, 'caustics_refractive', False), + "blur_glossy": getattr(cycles, 'blur_glossy', 0.0), # Filter Glossy + + # Fast GI Approximation + "use_fast_gi": getattr(cycles, 'use_fast_gi', False), + "fast_gi_method": getattr(cycles, 'fast_gi_method', 'REPLACE'), # REPLACE or ADD + "ao_bounces": getattr(cycles, 'ao_bounces', 1), # Viewport bounces + "ao_bounces_render": getattr(cycles, 'ao_bounces_render', 1), # Render bounces + + # Volumes + "volume_step_rate": getattr(cycles, 'volume_step_rate', 1.0), + "volume_preview_step_rate": getattr(cycles, 'volume_preview_step_rate', 1.0), + "volume_max_steps": getattr(cycles, 'volume_max_steps', 1024), + + # Film + "film_exposure": getattr(cycles, 'film_exposure', 1.0), "film_transparent": getattr(cycles, 'film_transparent', False), + "film_transparent_glass": getattr(cycles, 'film_transparent_glass', False), + "film_transparent_roughness": getattr(cycles, 
'film_transparent_roughness', 0.1), + "filter_type": getattr(cycles, 'filter_type', 'BLACKMAN_HARRIS'), # BOX, GAUSSIAN, BLACKMAN_HARRIS + "filter_width": getattr(cycles, 'filter_width', 1.5), + "pixel_filter_type": getattr(cycles, 'pixel_filter_type', 'BLACKMAN_HARRIS'), + + # Performance + "use_auto_tile": getattr(cycles, 'use_auto_tile', True), + "tile_size": getattr(cycles, 'tile_size', 2048), + "use_persistent_data": getattr(cycles, 'use_persistent_data', False), + + # Hair/Curves + "use_hair": getattr(cycles, 'use_hair', True), + "hair_subdivisions": getattr(cycles, 'hair_subdivisions', 2), + "hair_shape": getattr(cycles, 'hair_shape', 'THICK'), # ROUND, RIBBONS, THICK + + # Simplify (from scene.render) + "use_simplify": getattr(scene.render, 'use_simplify', False), + "simplify_subdivision_render": getattr(scene.render, 'simplify_subdivision_render', 6), + "simplify_child_particles_render": getattr(scene.render, 'simplify_child_particles_render', 1.0), + + # Other + "use_light_linking": getattr(cycles, 'use_light_linking', False), "use_layer_samples": getattr(cycles, 'use_layer_samples', False), } elif engine == 'EEVEE' or engine == 'EEVEE_NEXT': + # Treat EEVEE_NEXT as EEVEE (modern Blender uses EEVEE for what was EEVEE_NEXT) eevee = scene.eevee engine_settings = { + # Sampling "taa_render_samples": getattr(eevee, 'taa_render_samples', 64), + "taa_samples": getattr(eevee, 'taa_samples', 16), # Viewport samples + "use_taa_reprojection": getattr(eevee, 'use_taa_reprojection', True), + + # Clamping + "clamp_surface_direct": getattr(eevee, 'clamp_surface_direct', 0.0), + "clamp_surface_indirect": getattr(eevee, 'clamp_surface_indirect', 0.0), + "clamp_volume_direct": getattr(eevee, 'clamp_volume_direct', 0.0), + "clamp_volume_indirect": getattr(eevee, 'clamp_volume_indirect', 0.0), + + # Shadows + "shadow_cube_size": getattr(eevee, 'shadow_cube_size', '512'), + "shadow_cascade_size": getattr(eevee, 'shadow_cascade_size', '1024'), + "use_shadow_high_bitdepth": 
getattr(eevee, 'use_shadow_high_bitdepth', False), + "use_soft_shadows": getattr(eevee, 'use_soft_shadows', True), + "light_threshold": getattr(eevee, 'light_threshold', 0.01), + + # Raytracing (EEVEE Next / modern EEVEE) + "use_raytracing": getattr(eevee, 'use_raytracing', False), + "ray_tracing_method": getattr(eevee, 'ray_tracing_method', 'SCREEN'), # SCREEN or PROBE + "ray_tracing_options_trace_max_roughness": getattr(eevee, 'ray_tracing_options', {}).get('trace_max_roughness', 0.5) if hasattr(getattr(eevee, 'ray_tracing_options', None), 'get') else 0.5, + + # Screen Space Reflections (legacy/fallback) + "use_ssr": getattr(eevee, 'use_ssr', False), + "use_ssr_refraction": getattr(eevee, 'use_ssr_refraction', False), + "use_ssr_halfres": getattr(eevee, 'use_ssr_halfres', True), + "ssr_quality": getattr(eevee, 'ssr_quality', 0.25), + "ssr_max_roughness": getattr(eevee, 'ssr_max_roughness', 0.5), + "ssr_thickness": getattr(eevee, 'ssr_thickness', 0.2), + "ssr_border_fade": getattr(eevee, 'ssr_border_fade', 0.075), + "ssr_firefly_fac": getattr(eevee, 'ssr_firefly_fac', 10.0), + + # Ambient Occlusion + "use_gtao": getattr(eevee, 'use_gtao', False), + "gtao_distance": getattr(eevee, 'gtao_distance', 0.2), + "gtao_factor": getattr(eevee, 'gtao_factor', 1.0), + "gtao_quality": getattr(eevee, 'gtao_quality', 0.25), + "use_gtao_bent_normals": getattr(eevee, 'use_gtao_bent_normals', True), + "use_gtao_bounce": getattr(eevee, 'use_gtao_bounce', True), + + # Bloom "use_bloom": getattr(eevee, 'use_bloom', False), "bloom_threshold": getattr(eevee, 'bloom_threshold', 0.8), - "bloom_intensity": getattr(eevee, 'bloom_intensity', 0.05), + "bloom_knee": getattr(eevee, 'bloom_knee', 0.5), "bloom_radius": getattr(eevee, 'bloom_radius', 6.5), - "use_ssr": getattr(eevee, 'use_ssr', True), - "use_ssr_refraction": getattr(eevee, 'use_ssr_refraction', False), - "ssr_quality": getattr(eevee, 'ssr_quality', 'MEDIUM'), - "use_ssao": getattr(eevee, 'use_ssao', True), - "ssao_quality": 
getattr(eevee, 'ssao_quality', 'MEDIUM'), - "ssao_distance": getattr(eevee, 'ssao_distance', 0.2), - "ssao_factor": getattr(eevee, 'ssao_factor', 1.0), - "use_soft_shadows": getattr(eevee, 'use_soft_shadows', True), - "use_shadow_high_bitdepth": getattr(eevee, 'use_shadow_high_bitdepth', True), - "use_volumetric": getattr(eevee, 'use_volumetric', False), + "bloom_color": list(getattr(eevee, 'bloom_color', (1.0, 1.0, 1.0))), + "bloom_intensity": getattr(eevee, 'bloom_intensity', 0.05), + "bloom_clamp": getattr(eevee, 'bloom_clamp', 0.0), + + # Depth of Field + "bokeh_max_size": getattr(eevee, 'bokeh_max_size', 100.0), + "bokeh_threshold": getattr(eevee, 'bokeh_threshold', 1.0), + "bokeh_neighbor_max": getattr(eevee, 'bokeh_neighbor_max', 10.0), + "bokeh_denoise_fac": getattr(eevee, 'bokeh_denoise_fac', 0.75), + "use_bokeh_high_quality_slight_defocus": getattr(eevee, 'use_bokeh_high_quality_slight_defocus', False), + "use_bokeh_jittered": getattr(eevee, 'use_bokeh_jittered', False), + "bokeh_overblur": getattr(eevee, 'bokeh_overblur', 5.0), + + # Subsurface Scattering + "sss_samples": getattr(eevee, 'sss_samples', 7), + "sss_jitter_threshold": getattr(eevee, 'sss_jitter_threshold', 0.3), + + # Volumetrics + "use_volumetric_lights": getattr(eevee, 'use_volumetric_lights', True), + "use_volumetric_shadows": getattr(eevee, 'use_volumetric_shadows', False), + "volumetric_start": getattr(eevee, 'volumetric_start', 0.1), + "volumetric_end": getattr(eevee, 'volumetric_end', 100.0), "volumetric_tile_size": getattr(eevee, 'volumetric_tile_size', '8'), "volumetric_samples": getattr(eevee, 'volumetric_samples', 64), - "volumetric_start": getattr(eevee, 'volumetric_start', 0.0), - "volumetric_end": getattr(eevee, 'volumetric_end', 100.0), - "use_volumetric_lights": getattr(eevee, 'use_volumetric_lights', True), - "use_volumetric_shadows": getattr(eevee, 'use_volumetric_shadows', True), - "use_gtao": getattr(eevee, 'use_gtao', False), - "gtao_quality": getattr(eevee, 
'gtao_quality', 'MEDIUM'), + "volumetric_sample_distribution": getattr(eevee, 'volumetric_sample_distribution', 0.8), + "volumetric_ray_depth": getattr(eevee, 'volumetric_ray_depth', 16), + + # Motion Blur + "use_motion_blur": getattr(eevee, 'use_motion_blur', False), + "motion_blur_position": getattr(eevee, 'motion_blur_position', 'CENTER'), + "motion_blur_shutter": getattr(eevee, 'motion_blur_shutter', 0.5), + "motion_blur_depth_scale": getattr(eevee, 'motion_blur_depth_scale', 100.0), + "motion_blur_max": getattr(eevee, 'motion_blur_max', 32), + "motion_blur_steps": getattr(eevee, 'motion_blur_steps', 1), + + # Film "use_overscan": getattr(eevee, 'use_overscan', False), + "overscan_size": getattr(eevee, 'overscan_size', 3.0), + + # Indirect Lighting + "gi_diffuse_bounces": getattr(eevee, 'gi_diffuse_bounces', 3), + "gi_cubemap_resolution": getattr(eevee, 'gi_cubemap_resolution', '512'), + "gi_visibility_resolution": getattr(eevee, 'gi_visibility_resolution', '32'), + "gi_irradiance_smoothing": getattr(eevee, 'gi_irradiance_smoothing', 0.1), + "gi_glossy_clamp": getattr(eevee, 'gi_glossy_clamp', 0.0), + "gi_filter_quality": getattr(eevee, 'gi_filter_quality', 3.0), + "gi_show_irradiance": getattr(eevee, 'gi_show_irradiance', False), + "gi_show_cubemaps": getattr(eevee, 'gi_show_cubemaps', False), + "gi_auto_bake": getattr(eevee, 'gi_auto_bake', False), + + # Hair/Curves + "hair_type": getattr(eevee, 'hair_type', 'STRIP'), # STRIP or STRAND + + # Performance + "use_shadow_jitter_viewport": getattr(eevee, 'use_shadow_jitter_viewport', True), + + # Simplify (from scene.render) + "use_simplify": getattr(scene.render, 'use_simplify', False), + "simplify_subdivision_render": getattr(scene.render, 'simplify_subdivision_render', 6), + "simplify_child_particles_render": getattr(scene.render, 'simplify_child_particles_render', 1.0), } else: # For other engines, extract basic samples if available @@ -149,10 +338,20 @@ camera_count = len([obj for obj in scene.objects if 
obj.type == 'CAMERA']) object_count = len(scene.objects) material_count = len(bpy.data.materials) +# Extract Blender version info +# bpy.app.version gives the current running Blender version +# For the file's saved version, we check bpy.data.version (version the file was saved with) +blender_version = { + "current": bpy.app.version_string, # Version of Blender running this script + "file_saved_with": ".".join(map(str, bpy.data.version)) if hasattr(bpy.data, 'version') else None, # Version file was saved with +} + # Build metadata dictionary metadata = { "frame_start": frame_start, "frame_end": frame_end, + "has_negative_frames": has_negative_start or has_negative_end or has_negative_animation, + "blender_version": blender_version, "render_settings": { "resolution_x": resolution_x, "resolution_y": resolution_y, diff --git a/pkg/scripts/scripts/render_blender.py.template b/pkg/scripts/scripts/render_blender.py.template index a5af2a8..ed13a3e 100644 --- a/pkg/scripts/scripts/render_blender.py.template +++ b/pkg/scripts/scripts/render_blender.py.template @@ -338,9 +338,27 @@ if current_engine == 'CYCLES': if gpu_available: scene.cycles.device = 'GPU' print(f"Using GPU for rendering (blend file had: {current_device})") + + # Auto-enable GPU denoising when using GPU (OpenImageDenoise supports all GPUs) + try: + view_layer = bpy.context.view_layer + if hasattr(view_layer, 'cycles') and hasattr(view_layer.cycles, 'denoising_use_gpu'): + view_layer.cycles.denoising_use_gpu = True + print("Auto-enabled GPU denoising (OpenImageDenoise)") + except Exception as e: + print(f"Could not auto-enable GPU denoising: {e}") else: scene.cycles.device = 'CPU' print(f"GPU not available, using CPU for rendering (blend file had: {current_device})") + + # Ensure GPU denoising is disabled when using CPU + try: + view_layer = bpy.context.view_layer + if hasattr(view_layer, 'cycles') and hasattr(view_layer.cycles, 'denoising_use_gpu'): + view_layer.cycles.denoising_use_gpu = False + 
print("Using CPU denoising") + except Exception as e: + pass # Verify device setting if current_engine == 'CYCLES': diff --git a/pkg/types/types.go b/pkg/types/types.go index d46e11c..301dda5 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -227,8 +227,9 @@ type TaskLogEntry struct { // BlendMetadata represents extracted metadata from a blend file type BlendMetadata struct { - FrameStart int `json:"frame_start"` - FrameEnd int `json:"frame_end"` + FrameStart int `json:"frame_start"` + FrameEnd int `json:"frame_end"` + HasNegativeFrames bool `json:"has_negative_frames"` // True if blend file has negative frame numbers (not supported) RenderSettings RenderSettings `json:"render_settings"` SceneInfo SceneInfo `json:"scene_info"` MissingFilesInfo *MissingFilesInfo `json:"missing_files_info,omitempty"` diff --git a/web/embed.go b/web/embed.go new file mode 100644 index 0000000..f1786ba --- /dev/null +++ b/web/embed.go @@ -0,0 +1,45 @@ +package web + +import ( + "embed" + "io/fs" + "net/http" + "strings" +) + +//go:embed dist/* +var distFS embed.FS + +// GetFileSystem returns an http.FileSystem for the embedded web UI files +func GetFileSystem() http.FileSystem { + subFS, err := fs.Sub(distFS, "dist") + if err != nil { + panic(err) + } + return http.FS(subFS) +} + +// SPAHandler returns an http.Handler that serves the embedded SPA +// It serves static files if they exist, otherwise falls back to index.html +func SPAHandler() http.Handler { + fsys := GetFileSystem() + fileServer := http.FileServer(fsys) + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path + + // Try to open the file + f, err := fsys.Open(strings.TrimPrefix(path, "/")) + if err != nil { + // File doesn't exist, serve index.html for SPA routing + r.URL.Path = "/" + fileServer.ServeHTTP(w, r) + return + } + f.Close() + + // File exists, serve it + fileServer.ServeHTTP(w, r) + }) +} + diff --git a/web/src/App.jsx b/web/src/App.jsx index fb95691..abd3d76 
100644 --- a/web/src/App.jsx +++ b/web/src/App.jsx @@ -5,6 +5,8 @@ import Layout from './components/Layout'; import JobList from './components/JobList'; import JobSubmission from './components/JobSubmission'; import AdminPanel from './components/AdminPanel'; +import ErrorBoundary from './components/ErrorBoundary'; +import LoadingSpinner from './components/LoadingSpinner'; import './styles/index.css'; function App() { @@ -17,7 +19,7 @@ function App() { if (loading) { return (
-
+
); } @@ -26,26 +28,20 @@ function App() { return loginComponent; } - // Wrapper to check auth before changing tabs - const handleTabChange = async (newTab) => { - // Check auth before allowing navigation - try { - await refresh(); - // If refresh succeeds, user is still authenticated - setActiveTab(newTab); - } catch (error) { - // Auth check failed, user will be set to null and login will show - console.error('Auth check failed on navigation:', error); - } + // Wrapper to change tabs - only check auth on mount, not on every navigation + const handleTabChange = (newTab) => { + setActiveTab(newTab); }; return ( - {activeTab === 'jobs' && } - {activeTab === 'submit' && ( - handleTabChange('jobs')} /> - )} - {activeTab === 'admin' && } + + {activeTab === 'jobs' && } + {activeTab === 'submit' && ( + handleTabChange('jobs')} /> + )} + {activeTab === 'admin' && } + ); } diff --git a/web/src/components/AdminPanel.jsx b/web/src/components/AdminPanel.jsx index 92603d5..11d9ec0 100644 --- a/web/src/components/AdminPanel.jsx +++ b/web/src/components/AdminPanel.jsx @@ -1,7 +1,9 @@ -import { useState, useEffect } from 'react'; -import { admin } from '../utils/api'; +import { useState, useEffect, useRef } from 'react'; +import { admin, jobs, normalizeArrayResponse } from '../utils/api'; +import { wsManager } from '../utils/websocket'; import UserJobs from './UserJobs'; import PasswordChange from './PasswordChange'; +import LoadingSpinner from './LoadingSpinner'; export default function AdminPanel() { const [activeSection, setActiveSection] = useState('api-keys'); @@ -16,16 +18,110 @@ export default function AdminPanel() { const [selectedUser, setSelectedUser] = useState(null); const [registrationEnabled, setRegistrationEnabled] = useState(true); const [passwordChangeUser, setPasswordChangeUser] = useState(null); + const listenerIdRef = useRef(null); // Listener ID for shared WebSocket + const subscribedChannelsRef = useRef(new Set()); // Track confirmed subscribed channels + 
const pendingSubscriptionsRef = useRef(new Set()); // Track pending subscriptions (waiting for confirmation) + + // Connect to shared WebSocket on mount + useEffect(() => { + listenerIdRef.current = wsManager.subscribe('adminpanel', { + open: () => { + console.log('AdminPanel: Shared WebSocket connected'); + // Subscribe to runners if already viewing runners section + if (activeSection === 'runners') { + subscribeToRunners(); + } + }, + message: (data) => { + // Handle subscription responses + if (data.type === 'subscribed' && data.channel) { + pendingSubscriptionsRef.current.delete(data.channel); + subscribedChannelsRef.current.add(data.channel); + console.log('Successfully subscribed to channel:', data.channel); + } else if (data.type === 'subscription_error' && data.channel) { + pendingSubscriptionsRef.current.delete(data.channel); + subscribedChannelsRef.current.delete(data.channel); + console.error('Subscription failed for channel:', data.channel, data.error); + } + + // Handle runners channel messages + if (data.channel === 'runners' && data.type === 'runner_status') { + // Update runner in list + setRunners(prev => { + const index = prev.findIndex(r => r.id === data.runner_id); + if (index >= 0 && data.data) { + const updated = [...prev]; + updated[index] = { ...updated[index], ...data.data }; + return updated; + } + return prev; + }); + } + }, + error: (error) => { + console.error('AdminPanel: Shared WebSocket error:', error); + }, + close: (event) => { + console.log('AdminPanel: Shared WebSocket closed:', event); + subscribedChannelsRef.current.clear(); + pendingSubscriptionsRef.current.clear(); + } + }); + + // Ensure connection is established + wsManager.connect(); + + return () => { + // Unsubscribe from all channels before unmounting + unsubscribeFromRunners(); + if (listenerIdRef.current) { + wsManager.unsubscribe(listenerIdRef.current); + listenerIdRef.current = null; + } + }; + }, []); + + const subscribeToRunners = () => { + const channel = 
'runners'; + if (wsManager.getReadyState() !== WebSocket.OPEN) { + return; + } + // Don't subscribe if already subscribed or pending + if (subscribedChannelsRef.current.has(channel) || pendingSubscriptionsRef.current.has(channel)) { + return; + } + wsManager.send({ type: 'subscribe', channel }); + pendingSubscriptionsRef.current.add(channel); + console.log('Subscribing to runners channel'); + }; + + const unsubscribeFromRunners = () => { + const channel = 'runners'; + if (wsManager.getReadyState() !== WebSocket.OPEN) { + return; + } + if (!subscribedChannelsRef.current.has(channel)) { + return; // Not subscribed + } + wsManager.send({ type: 'unsubscribe', channel }); + subscribedChannelsRef.current.delete(channel); + pendingSubscriptionsRef.current.delete(channel); + console.log('Unsubscribed from runners channel'); + }; useEffect(() => { if (activeSection === 'api-keys') { loadAPIKeys(); + unsubscribeFromRunners(); } else if (activeSection === 'runners') { loadRunners(); + subscribeToRunners(); } else if (activeSection === 'users') { loadUsers(); + unsubscribeFromRunners(); } else if (activeSection === 'settings') { loadSettings(); + unsubscribeFromRunners(); } }, [activeSection]); @@ -33,7 +129,7 @@ export default function AdminPanel() { setLoading(true); try { const data = await admin.listAPIKeys(); - setApiKeys(Array.isArray(data) ? data : []); + setApiKeys(normalizeArrayResponse(data)); } catch (error) { console.error('Failed to load API keys:', error); setApiKeys([]); @@ -47,7 +143,7 @@ export default function AdminPanel() { setLoading(true); try { const data = await admin.listRunners(); - setRunners(Array.isArray(data) ? data : []); + setRunners(normalizeArrayResponse(data)); } catch (error) { console.error('Failed to load runners:', error); setRunners([]); @@ -61,7 +157,7 @@ export default function AdminPanel() { setLoading(true); try { const data = await admin.listUsers(); - setUsers(Array.isArray(data) ? 
data : []); + setUsers(normalizeArrayResponse(data)); } catch (error) { console.error('Failed to load users:', error); setUsers([]); @@ -121,29 +217,22 @@ export default function AdminPanel() { } }; - const revokeAPIKey = async (keyId) => { - if (!confirm('Are you sure you want to revoke this API key? Revoked keys cannot be used for new runner registrations.')) { - return; - } - try { - await admin.revokeAPIKey(keyId); - await loadAPIKeys(); - } catch (error) { - console.error('Failed to revoke API key:', error); - alert('Failed to revoke API key'); - } - }; + const [deletingKeyId, setDeletingKeyId] = useState(null); + const [deletingRunnerId, setDeletingRunnerId] = useState(null); - const deleteAPIKey = async (keyId) => { - if (!confirm('Are you sure you want to permanently delete this API key? This action cannot be undone.')) { + const revokeAPIKey = async (keyId) => { + if (!confirm('Are you sure you want to delete this API key? This action cannot be undone.')) { return; } + setDeletingKeyId(keyId); try { await admin.deleteAPIKey(keyId); await loadAPIKeys(); } catch (error) { console.error('Failed to delete API key:', error); alert('Failed to delete API key'); + } finally { + setDeletingKeyId(null); } }; @@ -152,12 +241,15 @@ export default function AdminPanel() { if (!confirm('Are you sure you want to delete this runner?')) { return; } + setDeletingRunnerId(runnerId); try { await admin.deleteRunner(runnerId); await loadRunners(); } catch (error) { console.error('Failed to delete runner:', error); alert('Failed to delete runner'); + } finally { + setDeletingRunnerId(null); } }; @@ -313,9 +405,7 @@ export default function AdminPanel() {

API Keys

{loading ? ( -
-
-
+ ) : !apiKeys || apiKeys.length === 0 ? (

No API keys generated yet.

) : ( @@ -384,21 +474,13 @@ export default function AdminPanel() { {new Date(key.created_at).toLocaleString()} - {key.is_active && !expired && ( - - )} @@ -416,9 +498,7 @@ export default function AdminPanel() {

Runner Management

{loading ? ( -
-
-
+ ) : !runners || runners.length === 0 ? (

No runners registered.

) : ( @@ -506,9 +586,10 @@ export default function AdminPanel() { @@ -558,9 +639,7 @@ export default function AdminPanel() {

User Management

{loading ? ( -
-
-
+ ) : !users || users.length === 0 ? (

No users found.

) : ( diff --git a/web/src/components/ErrorBoundary.jsx b/web/src/components/ErrorBoundary.jsx new file mode 100644 index 0000000..f090aae --- /dev/null +++ b/web/src/components/ErrorBoundary.jsx @@ -0,0 +1,41 @@ +import React from 'react'; + +class ErrorBoundary extends React.Component { + constructor(props) { + super(props); + this.state = { hasError: false, error: null }; + } + + static getDerivedStateFromError(error) { + return { hasError: true, error }; + } + + componentDidCatch(error, errorInfo) { + console.error('ErrorBoundary caught an error:', error, errorInfo); + } + + render() { + if (this.state.hasError) { + return ( +
+

Something went wrong

+

{this.state.error?.message || 'An unexpected error occurred'}

+ +
+ ); + } + + return this.props.children; + } +} + +export default ErrorBoundary; + diff --git a/web/src/components/ErrorMessage.jsx b/web/src/components/ErrorMessage.jsx new file mode 100644 index 0000000..c7bc9c6 --- /dev/null +++ b/web/src/components/ErrorMessage.jsx @@ -0,0 +1,26 @@ +import React from 'react'; + +/** + * Shared ErrorMessage component for consistent error display + * Sanitizes error messages to prevent XSS + */ +export default function ErrorMessage({ error, className = '' }) { + if (!error) return null; + + // Sanitize error message - escape HTML entities + const sanitize = (text) => { + const div = document.createElement('div'); + div.textContent = text; + return div.innerHTML; + }; + + const sanitizedError = typeof error === 'string' ? sanitize(error) : sanitize(error.message || 'An error occurred'); + + return ( +
+

Error:

+

+

+ ); +} + diff --git a/web/src/components/JobDetails.jsx b/web/src/components/JobDetails.jsx index 73228fe..a8d850b 100644 --- a/web/src/components/JobDetails.jsx +++ b/web/src/components/JobDetails.jsx @@ -1,7 +1,10 @@ import { useState, useEffect, useRef } from 'react'; import { jobs, REQUEST_SUPERSEDED } from '../utils/api'; +import { wsManager } from '../utils/websocket'; import VideoPlayer from './VideoPlayer'; import FileExplorer from './FileExplorer'; +import ErrorMessage from './ErrorMessage'; +import LoadingSpinner from './LoadingSpinner'; export default function JobDetails({ job, onClose, onUpdate }) { const [jobDetails, setJobDetails] = useState(job); @@ -17,80 +20,66 @@ export default function JobDetails({ job, onClose, onUpdate }) { const [expandedSteps, setExpandedSteps] = useState(new Set()); const [streaming, setStreaming] = useState(false); const [previewImage, setPreviewImage] = useState(null); // { url, fileName } or null - const wsRef = useRef(null); - const jobWsRef = useRef(null); // Separate ref for job WebSocket + const listenerIdRef = useRef(null); // Listener ID for shared WebSocket + const subscribedChannelsRef = useRef(new Set()); // Track confirmed subscribed channels + const pendingSubscriptionsRef = useRef(new Set()); // Track pending subscriptions (waiting for confirmation) const logContainerRefs = useRef({}); // Refs for each step's log container const shouldAutoScrollRefs = useRef({}); // Auto-scroll state per step + const abortControllerRef = useRef(null); // AbortController for HTTP requests + + // Sync job prop to state when it changes + useEffect(() => { + setJobDetails(job); + }, [job.id, job.status, job.progress]); useEffect(() => { + // Create new AbortController for this effect + abortControllerRef.current = new AbortController(); + loadDetails(); - // Use WebSocket for real-time updates instead of polling - if (jobDetails.status === 'running' || jobDetails.status === 'pending' || !jobDetails.status) { - 
connectJobWebSocket(); - return () => { - if (jobWsRef.current) { - try { - jobWsRef.current.close(); - } catch (e) { - // Ignore errors when closing - } - jobWsRef.current = null; - } - if (wsRef.current) { - try { - wsRef.current.close(); - } catch (e) { - // Ignore errors when closing - } - wsRef.current = null; - } - }; - } else { - // Job is completed/failed/cancelled - close WebSocket - if (jobWsRef.current) { - try { - jobWsRef.current.close(); - } catch (e) { - // Ignore errors when closing - } - jobWsRef.current = null; + // Use shared WebSocket manager for real-time updates + listenerIdRef.current = wsManager.subscribe(`jobdetails_${job.id}`, { + open: () => { + console.log('JobDetails: Shared WebSocket connected for job', job.id); + // Subscribe to job channel + subscribe(`job:${job.id}`); + }, + message: (data) => { + handleWebSocketMessage(data); + }, + error: (error) => { + console.error('JobDetails: Shared WebSocket error:', error); + }, + close: (event) => { + console.log('JobDetails: Shared WebSocket closed:', event); + subscribedChannelsRef.current.clear(); + pendingSubscriptionsRef.current.clear(); } - if (wsRef.current) { - try { - wsRef.current.close(); - } catch (e) { - // Ignore errors when closing - } - wsRef.current = null; - } - } - }, [job.id, jobDetails.status]); + }); + + // Ensure connection is established + wsManager.connect(); - useEffect(() => { - // Load logs and steps for all running tasks - if (jobDetails.status === 'running' && tasks.length > 0) { - const runningTasks = tasks.filter(t => t.status === 'running' || t.status === 'pending'); - runningTasks.forEach(task => { - if (!taskData[task.id]) { - loadTaskData(task.id); - } - }); - // Start streaming for the first running task (WebSocket supports one at a time) - if (runningTasks.length > 0 && !streaming) { - startLogStream(runningTasks.map(t => t.id)); - } - } else if (wsRef.current) { - wsRef.current.close(); - wsRef.current = null; - setStreaming(false); - } return () => { 
- if (wsRef.current && jobDetails.status !== 'running') { - wsRef.current.close(); - wsRef.current = null; + // Cancel any pending HTTP requests + if (abortControllerRef.current) { + abortControllerRef.current.abort(); + abortControllerRef.current = null; + } + // Unsubscribe from all channels + unsubscribeAll(); + if (listenerIdRef.current) { + wsManager.unsubscribe(listenerIdRef.current); + listenerIdRef.current = null; } }; - }, [tasks, jobDetails.status]); + }, [job.id]); + + useEffect(() => { + // Update log subscriptions based on expanded tasks (not steps) + updateLogSubscriptions(); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [expandedTasks, tasks.length, jobDetails.status]); // Use tasks.length instead of tasks to avoid unnecessary re-runs // Auto-scroll logs to bottom when new logs arrive useEffect(() => { @@ -119,11 +108,17 @@ export default function JobDetails({ job, onClose, onUpdate }) { try { setLoading(true); // Use summary endpoint for tasks initially - much faster + const signal = abortControllerRef.current?.signal; const [details, fileList, taskListResult] = await Promise.all([ - jobs.get(job.id), - jobs.getFiles(job.id, { limit: 50 }), // Only load first page of files - jobs.getTasksSummary(job.id, { limit: 100, sort: 'frame_start:asc' }), // Use summary endpoint + jobs.get(job.id, { signal }), + jobs.getFiles(job.id, { limit: 50, signal }), // Only load first page of files + jobs.getTasksSummary(job.id, { sort: 'frame_start:asc', signal }), // Get all tasks ]); + + // Check if request was aborted + if (signal?.aborted) { + return; + } setJobDetails(details); // Handle paginated file response - check for superseded sentinel @@ -159,9 +154,11 @@ export default function JobDetails({ job, onClose, onUpdate }) { // Fetch context archive contents separately (may not exist for old jobs) try { - const contextList = await jobs.getContextArchive(job.id); + const contextList = await jobs.getContextArchive(job.id, { signal }); + if 
(signal?.aborted) return; setContextFiles(contextList || []); } catch (error) { + if (signal?.aborted) return; // Context archive may not exist for old jobs setContextFiles([]); } @@ -205,11 +202,17 @@ export default function JobDetails({ job, onClose, onUpdate }) { const loadTaskData = async (taskId) => { try { console.log(`Loading task data for task ${taskId}...`); + const signal = abortControllerRef.current?.signal; const [logsResult, steps] = await Promise.all([ - jobs.getTaskLogs(job.id, taskId, { limit: 1000 }), // Increased limit for completed tasks - jobs.getTaskSteps(job.id, taskId), + jobs.getTaskLogs(job.id, taskId, { limit: 1000, signal }), // Increased limit for completed tasks + jobs.getTaskSteps(job.id, taskId, { signal }), ]); + // Check if request was aborted + if (signal?.aborted) { + return; + } + // Check for superseded sentinel if (logsResult === REQUEST_SUPERSEDED || steps === REQUEST_SUPERSEDED) { return; // Request was superseded, skip this update @@ -247,7 +250,14 @@ export default function JobDetails({ job, onClose, onUpdate }) { const loadTaskStepsOnly = async (taskId) => { try { - const steps = await jobs.getTaskSteps(job.id, taskId); + const signal = abortControllerRef.current?.signal; + const steps = await jobs.getTaskSteps(job.id, taskId, { signal }); + + // Check if request was aborted + if (signal?.aborted) { + return; + } + // Check for superseded sentinel if (steps === REQUEST_SUPERSEDED) { return; // Request was superseded, skip this update @@ -267,356 +277,416 @@ export default function JobDetails({ job, onClose, onUpdate }) { } }; - const connectJobWebSocket = () => { - try { - // Close existing connection if any - if (jobWsRef.current) { - try { - jobWsRef.current.close(); - } catch (e) { - // Ignore errors when closing - } - jobWsRef.current = null; - } - - const ws = jobs.streamJobWebSocket(job.id); - jobWsRef.current = ws; // Store reference - - ws.onopen = () => { - console.log('Job WebSocket connected for job', job.id); - 
}; - - ws.onmessage = (event) => { - try { - const data = JSON.parse(event.data); - console.log('Job WebSocket message received:', data.type, data); - - if (data.type === 'job_update' && data.data) { - // Update job details - setJobDetails(prev => ({ ...prev, ...data.data })); - } else if (data.type === 'task_update' && data.data) { - // Update task in list - setTasks(prev => { - // Ensure prev is always an array - const prevArray = Array.isArray(prev) ? prev : []; - if (!data.task_id) { - console.warn('task_update message missing task_id:', data); - return prevArray; - } - const index = prevArray.findIndex(t => t.id === data.task_id); - if (index >= 0) { - const updated = [...prevArray]; - updated[index] = { ...updated[index], ...data.data }; - return updated; - } - // If task not found, it might be a new task - reload to be safe - if (data.data && (data.data.status === 'running' || data.data.status === 'pending')) { - setTimeout(() => { - const reloadTasks = async () => { - try { - const taskListResult = await jobs.getTasksSummary(job.id, { limit: 100, sort: 'frame_start:asc' }); - // Check for superseded sentinel - if (taskListResult === REQUEST_SUPERSEDED) { - return; // Request was superseded, skip this update - } - const taskData = taskListResult.data || taskListResult; - const taskSummaries = Array.isArray(taskData) ? taskData : []; - const tasksForDisplay = taskSummaries.map(summary => ({ - id: summary.id, - job_id: job.id, - frame_start: summary.frame_start, - frame_end: summary.frame_end, - status: summary.status, - task_type: summary.task_type, - runner_id: summary.runner_id, - current_step: null, - retry_count: 0, - max_retries: 3, - created_at: new Date().toISOString(), - })); - setTasks(Array.isArray(tasksForDisplay) ? 
tasksForDisplay : []); - } catch (error) { - console.error('Failed to reload tasks:', error); - } - }; - reloadTasks(); - }, 100); - } - return prevArray; - }); - } else if (data.type === 'task_added' && data.data) { - // New task was added - reload task summaries to get the new task - console.log('task_added message received, reloading tasks...', data); - const reloadTasks = async () => { - try { - const taskListResult = await jobs.getTasksSummary(job.id, { limit: 100, sort: 'frame_start:asc' }); - // Check for superseded sentinel - if (taskListResult === REQUEST_SUPERSEDED) { - return; // Request was superseded, skip this update - } - const taskData = taskListResult.data || taskListResult; - const taskSummaries = Array.isArray(taskData) ? taskData : []; - const tasksForDisplay = taskSummaries.map(summary => ({ - id: summary.id, - job_id: job.id, - frame_start: summary.frame_start, - frame_end: summary.frame_end, - status: summary.status, - task_type: summary.task_type, - runner_id: summary.runner_id, - current_step: null, - retry_count: 0, - max_retries: 3, - created_at: new Date().toISOString(), - })); - setTasks(Array.isArray(tasksForDisplay) ? tasksForDisplay : []); - } catch (error) { - console.error('Failed to reload tasks:', error); - // Fallback to full reload - loadDetails(); - } - }; - reloadTasks(); - } else if (data.type === 'tasks_added' && data.data) { - // Multiple new tasks were added - reload task summaries - console.log('tasks_added message received, reloading tasks...', data); - const reloadTasks = async () => { - try { - const taskListResult = await jobs.getTasksSummary(job.id, { limit: 100, sort: 'frame_start:asc' }); - // Check for superseded sentinel - if (taskListResult === REQUEST_SUPERSEDED) { - return; // Request was superseded, skip this update - } - const taskData = taskListResult.data || taskListResult; - const taskSummaries = Array.isArray(taskData) ? 
taskData : []; - const tasksForDisplay = taskSummaries.map(summary => ({ - id: summary.id, - job_id: job.id, - frame_start: summary.frame_start, - frame_end: summary.frame_end, - status: summary.status, - task_type: summary.task_type, - runner_id: summary.runner_id, - current_step: null, - retry_count: 0, - max_retries: 3, - created_at: new Date().toISOString(), - })); - setTasks(Array.isArray(tasksForDisplay) ? tasksForDisplay : []); - } catch (error) { - console.error('Failed to reload tasks:', error); - // Fallback to full reload - loadDetails(); - } - }; - reloadTasks(); - } else if (data.type === 'file_added' && data.data) { - // New file was added - reload file list - const reloadFiles = async () => { - try { - const fileList = await jobs.getFiles(job.id, { limit: 50 }); - // Check for superseded sentinel - if (fileList === REQUEST_SUPERSEDED) { - return; // Request was superseded, skip this update - } - const fileData = fileList.data || fileList; - setFiles(Array.isArray(fileData) ? 
fileData : []); - } catch (error) { - console.error('Failed to reload files:', error); - } - }; - reloadFiles(); - } else if (data.type === 'step_update' && data.data && data.task_id) { - // Step was created or updated - update task data - console.log('step_update message received:', data); - setTaskData(prev => { - const taskId = data.task_id; - const current = prev[taskId] || { steps: [], logs: [] }; - const stepData = data.data; - - // Find if step already exists - const existingSteps = current.steps || []; - const stepIndex = existingSteps.findIndex(s => s.step_name === stepData.step_name); - - let updatedSteps; - if (stepIndex >= 0) { - // Update existing step - updatedSteps = [...existingSteps]; - updatedSteps[stepIndex] = { - ...updatedSteps[stepIndex], - ...stepData, - id: stepData.step_id || updatedSteps[stepIndex].id, - }; - } else { - // Add new step - updatedSteps = [...existingSteps, { - id: stepData.step_id, - step_name: stepData.step_name, - status: stepData.status, - duration_ms: stepData.duration_ms, - error_message: stepData.error_message, - }]; - } - - return { - ...prev, - [taskId]: { - ...current, - steps: updatedSteps, - } - }; - }); - } else if (data.type === 'connected') { - // Connection established - } - } catch (error) { - console.error('Failed to parse WebSocket message:', error); - } - }; - - ws.onerror = (error) => { - console.error('Job WebSocket error:', { - error, - readyState: ws.readyState, - url: ws.url, - jobId: job.id, - status: jobDetails.status - }); - // WebSocket errors don't provide much detail, but we can check readyState - if (ws.readyState === WebSocket.CLOSED || ws.readyState === WebSocket.CLOSING) { - console.warn('Job WebSocket is closed or closing, will attempt reconnect'); - } - }; - - ws.onclose = (event) => { - console.log('Job WebSocket closed:', { - code: event.code, - reason: event.reason, - wasClean: event.wasClean, - jobId: job.id, - status: jobDetails.status - }); - jobWsRef.current = null; - - // Code 1006 
= Abnormal Closure (connection lost without close frame) - // Code 1000 = Normal Closure - // Code 1001 = Going Away (server restart, etc.) - // We should reconnect for abnormal closures (1006) or unexpected closes - const shouldReconnect = !event.wasClean || event.code === 1006 || event.code === 1001; - - // Get current status from state to avoid stale closure - const currentStatus = jobDetails.status; - const isActiveJob = currentStatus === 'running' || currentStatus === 'pending'; - - if (shouldReconnect && isActiveJob) { - console.log(`Attempting to reconnect job WebSocket in 2 seconds... (code: ${event.code})`); - setTimeout(() => { - // Check status again before reconnecting (might have changed) - // Use a ref or check the current state directly - if ((!jobWsRef.current || jobWsRef.current.readyState === WebSocket.CLOSED)) { - // Re-check if job is still active by reading current state - // We'll check this in connectJobWebSocket if needed - connectJobWebSocket(); - } - }, 2000); - } else if (!isActiveJob) { - console.log('Job is no longer active, not reconnecting WebSocket'); - } - }; - } catch (error) { - console.error('Failed to connect job WebSocket:', error); - } - }; - - const startLogStream = (taskIds) => { - if (taskIds.length === 0 || streaming) return; - - // Don't start streaming if job is no longer running - if (jobDetails.status !== 'running' && jobDetails.status !== 'pending') { - console.log('Job is not running, skipping log stream'); + const subscribe = (channel) => { + if (wsManager.getReadyState() !== WebSocket.OPEN) { return; } - - setStreaming(true); - // For now, stream the first task's logs (WebSocket supports one task at a time) - // In the future, we could have multiple WebSocket connections - const primaryTaskId = taskIds[0]; - const ws = jobs.streamTaskLogsWebSocket(job.id, primaryTaskId); - wsRef.current = ws; + // Don't subscribe if already subscribed or pending + if (subscribedChannelsRef.current.has(channel) || 
pendingSubscriptionsRef.current.has(channel)) { + return; // Already subscribed or subscription pending + } + wsManager.send({ type: 'subscribe', channel }); + pendingSubscriptionsRef.current.add(channel); // Mark as pending + }; - ws.onmessage = (event) => { - try { - const data = JSON.parse(event.data); + const unsubscribe = (channel) => { + if (wsManager.getReadyState() !== WebSocket.OPEN) { + return; + } + if (!subscribedChannelsRef.current.has(channel)) { + return; // Not subscribed + } + wsManager.send({ type: 'unsubscribe', channel }); + subscribedChannelsRef.current.delete(channel); + console.log('Unsubscribed from channel:', channel); + }; + + const unsubscribeAll = () => { + subscribedChannelsRef.current.forEach(channel => { + unsubscribe(channel); + }); + }; + + const updateLogSubscriptions = () => { + if (wsManager.getReadyState() !== WebSocket.OPEN) { + return; + } + + // Determine which log channels should be subscribed + const shouldSubscribe = new Set(); + const isRunning = jobDetails.status === 'running' || jobDetails.status === 'pending'; + + // Subscribe to logs when task is expanded (not when step is expanded) + if (isRunning) { + expandedTasks.forEach(taskId => { + const channel = `logs:${job.id}:${taskId}`; + shouldSubscribe.add(channel); + }); + } + + // Subscribe to new channels + shouldSubscribe.forEach(channel => { + subscribe(channel); + }); + + // Unsubscribe from channels that shouldn't be subscribed + subscribedChannelsRef.current.forEach(channel => { + if (channel.startsWith('logs:') && !shouldSubscribe.has(channel)) { + unsubscribe(channel); + } + }); + }; + + const handleWebSocketMessage = (data) => { + try { + console.log('JobDetails: Client WebSocket message received:', data.type, data.channel, data); + + // Handle subscription responses + if (data.type === 'subscribed' && data.channel) { + pendingSubscriptionsRef.current.delete(data.channel); // Remove from pending + subscribedChannelsRef.current.add(data.channel); // Add to 
confirmed + console.log('Successfully subscribed to channel:', data.channel, 'Total subscriptions:', subscribedChannelsRef.current.size); + } else if (data.type === 'subscription_error' && data.channel) { + pendingSubscriptionsRef.current.delete(data.channel); // Remove from pending + subscribedChannelsRef.current.delete(data.channel); // Remove from confirmed (if it was there) + console.error('Subscription failed for channel:', data.channel, data.error); + // If it's the job channel, this is a critical error + if (data.channel === `job:${job.id}`) { + console.error('Failed to subscribe to job channel - job may not exist or access denied'); + } + } + + // Handle job channel messages + // Check both explicit channel and job_id match (for backwards compatibility) + const isJobChannel = data.channel === `job:${job.id}` || + (data.job_id === job.id && !data.channel); + if (isJobChannel) { + console.log('Job channel message received:', data.type, data); + if (data.type === 'job_update' && data.data) { + // Update job details + console.log('Updating job details:', data.data); + setJobDetails(prev => { + const updated = { ...prev, ...data.data }; + console.log('Job details updated:', { + old_progress: prev.progress, + new_progress: updated.progress, + old_status: prev.status, + new_status: updated.status + }); + // Notify parent component of update + if (onUpdate) { + onUpdate(data.job_id || job.id, updated); + } + return updated; + }); + } else if (data.type === 'task_update') { + // Handle task_update - data.data contains the update fields + const taskId = data.task_id || (data.data && (data.data.id || data.data.task_id)); + console.log('Task update received:', { task_id: taskId, data: data.data, full_message: data }); + + if (!taskId) { + console.warn('task_update message missing task_id:', data); + return; + } + + if (!data.data) { + console.warn('task_update message missing data:', data); + return; + } + + // Update task in list + setTasks(prev => { + // Ensure prev 
is always an array + const prevArray = Array.isArray(prev) ? prev : []; + const index = prevArray.findIndex(t => t.id === taskId); + + if (index >= 0) { + // Task exists - update it + const updated = [...prevArray]; + const oldTask = updated[index]; + // Create a completely new task object to ensure React detects the change + const newTask = { + ...oldTask, + // Explicitly update each field from data.data to ensure changes are detected + status: data.data.status !== undefined ? data.data.status : oldTask.status, + runner_id: data.data.runner_id !== undefined ? data.data.runner_id : oldTask.runner_id, + started_at: data.data.started_at !== undefined ? data.data.started_at : oldTask.started_at, + completed_at: data.data.completed_at !== undefined ? data.data.completed_at : oldTask.completed_at, + error_message: data.data.error_message !== undefined ? data.data.error_message : oldTask.error_message, + output_path: data.data.output_path !== undefined ? data.data.output_path : oldTask.output_path, + current_step: data.data.current_step !== undefined ? 
data.data.current_step : oldTask.current_step, + // Merge any other fields + ...Object.keys(data.data).reduce((acc, key) => { + if (!['status', 'runner_id', 'started_at', 'completed_at', 'error_message', 'output_path', 'current_step'].includes(key)) { + acc[key] = data.data[key]; + } + return acc; + }, {}) + }; + updated[index] = newTask; + console.log('Updated task at index', index, { + task_id: taskId, + old_status: oldTask.status, + new_status: newTask.status, + old_runner_id: oldTask.runner_id, + new_runner_id: newTask.runner_id, + update_data: data.data, + full_new_task: newTask + }); + return updated; + } + // Task not found - check if data contains full task info (from initial state) + // Check both 'id' and 'task_id' fields + const taskIdFromData = data.data.id || data.data.task_id; + if (data.data && typeof data.data === 'object' && taskIdFromData && taskIdFromData === taskId) { + // This is a full task object from initial state - add it + console.log('Adding new task from initial state:', data.data); + return [...prevArray, { ...data.data, id: taskIdFromData }]; + } + // If task not found and it's a partial update, reload tasks to get the full list + console.log('Task not found in list, reloading tasks...'); + setTimeout(() => { + const reloadTasks = async () => { + try { + const signal = abortControllerRef.current?.signal; + const taskListResult = await jobs.getTasksSummary(job.id, { sort: 'frame_start:asc', signal }); + + // Check if request was aborted + if (signal?.aborted) { + return; + } + + // Check for superseded sentinel + if (taskListResult === REQUEST_SUPERSEDED) { + return; // Request was superseded, skip this update + } + const taskData = taskListResult.data || taskListResult; + const taskSummaries = Array.isArray(taskData) ? 
taskData : []; + const tasksForDisplay = taskSummaries.map(summary => ({ + id: summary.id, + job_id: job.id, + frame_start: summary.frame_start, + frame_end: summary.frame_end, + status: summary.status, + task_type: summary.task_type, + runner_id: summary.runner_id, + current_step: summary.current_step || null, + retry_count: summary.retry_count || 0, + max_retries: summary.max_retries || 3, + created_at: summary.created_at || new Date().toISOString(), + started_at: summary.started_at, + completed_at: summary.completed_at, + error_message: summary.error_message, + output_path: summary.output_path, + })); + setTasks(Array.isArray(tasksForDisplay) ? tasksForDisplay : []); + } catch (error) { + console.error('Failed to reload tasks:', error); + } + }; + reloadTasks(); + }, 100); + return prevArray; + }); + } else if (data.type === 'task_added' && data.data) { + // New task was added - reload task summaries to get the new task + console.log('task_added message received, reloading tasks...', data); + const reloadTasks = async () => { + try { + const signal = abortControllerRef.current?.signal; + const taskListResult = await jobs.getTasksSummary(job.id, { limit: 100, sort: 'frame_start:asc', signal }); + + // Check if request was aborted + if (signal?.aborted) { + return; + } + + // Check for superseded sentinel + if (taskListResult === REQUEST_SUPERSEDED) { + return; // Request was superseded, skip this update + } + const taskData = taskListResult.data || taskListResult; + const taskSummaries = Array.isArray(taskData) ? taskData : []; + const tasksForDisplay = taskSummaries.map(summary => ({ + id: summary.id, + job_id: job.id, + frame_start: summary.frame_start, + frame_end: summary.frame_end, + status: summary.status, + task_type: summary.task_type, + runner_id: summary.runner_id, + current_step: null, + retry_count: 0, + max_retries: 3, + created_at: new Date().toISOString(), + })); + setTasks(Array.isArray(tasksForDisplay) ? 
tasksForDisplay : []); + } catch (error) { + console.error('Failed to reload tasks:', error); + // Fallback to full reload + loadDetails(); + } + }; + reloadTasks(); + } else if (data.type === 'tasks_added' && data.data) { + // Multiple new tasks were added - reload task summaries + console.log('tasks_added message received, reloading tasks...', data); + const reloadTasks = async () => { + try { + const signal = abortControllerRef.current?.signal; + const taskListResult = await jobs.getTasksSummary(job.id, { limit: 100, sort: 'frame_start:asc', signal }); + + // Check if request was aborted + if (signal?.aborted) { + return; + } + + // Check for superseded sentinel + if (taskListResult === REQUEST_SUPERSEDED) { + return; // Request was superseded, skip this update + } + const taskData = taskListResult.data || taskListResult; + const taskSummaries = Array.isArray(taskData) ? taskData : []; + const tasksForDisplay = taskSummaries.map(summary => ({ + id: summary.id, + job_id: job.id, + frame_start: summary.frame_start, + frame_end: summary.frame_end, + status: summary.status, + task_type: summary.task_type, + runner_id: summary.runner_id, + current_step: null, + retry_count: 0, + max_retries: 3, + created_at: new Date().toISOString(), + })); + setTasks(Array.isArray(tasksForDisplay) ? tasksForDisplay : []); + } catch (error) { + console.error('Failed to reload tasks:', error); + // Fallback to full reload + loadDetails(); + } + }; + reloadTasks(); + } else if (data.type === 'file_added' && data.data) { + // New file was added - reload file list + const reloadFiles = async () => { + try { + const fileList = await jobs.getFiles(job.id, { limit: 50 }); + // Check for superseded sentinel + if (fileList === REQUEST_SUPERSEDED) { + return; // Request was superseded, skip this update + } + const fileData = fileList.data || fileList; + setFiles(Array.isArray(fileData) ? 
fileData : []); + } catch (error) { + console.error('Failed to reload files:', error); + } + }; + reloadFiles(); + } else if (data.type === 'step_update' && data.data && data.task_id) { + // Step was created or updated - update task data + console.log('step_update message received:', data); + setTaskData(prev => { + const taskId = data.task_id; + const current = prev[taskId] || { steps: [], logs: [] }; + const stepData = data.data; + + // Find if step already exists + const existingSteps = current.steps || []; + const stepIndex = existingSteps.findIndex(s => s.step_name === stepData.step_name); + + let updatedSteps; + if (stepIndex >= 0) { + // Update existing step + updatedSteps = [...existingSteps]; + updatedSteps[stepIndex] = { + ...updatedSteps[stepIndex], + ...stepData, + id: stepData.step_id || updatedSteps[stepIndex].id, + }; + } else { + // Add new step + updatedSteps = [...existingSteps, { + id: stepData.step_id, + step_name: stepData.step_name, + status: stepData.status, + duration_ms: stepData.duration_ms, + error_message: stepData.error_message, + }]; + } + + return { + ...prev, + [taskId]: { + ...current, + steps: updatedSteps, + } + }; + }); + } + } else if (data.channel && data.channel.startsWith('logs:')) { + // Handle log channel messages if (data.type === 'log' && data.data) { const log = data.data; + // Get task_id from log data or top-level message + const taskId = log.task_id || data.task_id; + if (!taskId) { + console.warn('Log message missing task_id:', data); + return; + } + console.log('Received log for task:', taskId, log); setTaskData(prev => { - const taskId = log.task_id; const current = prev[taskId] || { steps: [], logs: [] }; + + // If log has a step_name, ensure the step exists in the steps array + let updatedSteps = current.steps || []; + if (log.step_name) { + const stepExists = updatedSteps.some(s => s.step_name === log.step_name); + if (!stepExists) { + // Create placeholder step for logs that arrive before step_update + 
console.log('Creating placeholder step for:', log.step_name, 'in task:', taskId); + updatedSteps = [...updatedSteps, { + id: null, // Will be updated when step_update arrives + step_name: log.step_name, + status: 'running', // Default to running since we're receiving logs + duration_ms: null, + error_message: null, + }]; + } + } + // Check if log already exists (avoid duplicates) if (!current.logs.find(l => l.id === log.id)) { return { ...prev, [taskId]: { ...current, + steps: updatedSteps, logs: [...current.logs, log] } }; } - return prev; + // Even if log is duplicate, update steps if needed + return { + ...prev, + [taskId]: { + ...current, + steps: updatedSteps, + } + }; }); - } else if (data.type === 'connected') { - // Connection established } - } catch (error) { - console.error('Failed to parse log message:', error); + } else if (data.type === 'connected') { + // Connection established } - }; - - ws.onopen = () => { - console.log('Log WebSocket connected for task', primaryTaskId); - }; - - ws.onerror = (error) => { - console.error('Log WebSocket error:', { - error, - readyState: ws.readyState, - url: ws.url, - taskId: primaryTaskId, - jobId: job.id - }); - setStreaming(false); - }; - - ws.onclose = (event) => { - console.log('Log WebSocket closed:', { - code: event.code, - reason: event.reason, - wasClean: event.wasClean, - taskId: primaryTaskId, - jobId: job.id - }); - setStreaming(false); - wsRef.current = null; - - // Code 1006 = Abnormal Closure (connection lost without close frame) - // Code 1000 = Normal Closure - // Code 1001 = Going Away (server restart, etc.) - const shouldReconnect = !event.wasClean || event.code === 1006 || event.code === 1001; - - // Auto-reconnect if job is still running and close was unexpected - if (shouldReconnect && jobDetails.status === 'running' && taskIds.length > 0) { - console.log(`Attempting to reconnect log WebSocket in 2 seconds... 
(code: ${event.code})`); - setTimeout(() => { - // Check status again before reconnecting (might have changed) - // The startLogStream function will check if job is still running - if (jobDetails.status === 'running' && taskIds.length > 0) { - startLogStream(taskIds); - } - }, 2000); - } - }; + } catch (error) { + console.error('Failed to parse WebSocket message:', error); + } }; + // startLogStream is no longer needed - subscriptions are managed by updateLogSubscriptions + const toggleTask = async (taskId) => { const newExpanded = new Set(expandedTasks); if (newExpanded.has(taskId)) { @@ -629,10 +699,17 @@ export default function JobDetails({ job, onClose, onUpdate }) { if (currentTask && !currentTask.created_at) { // This is a summary - fetch full task details try { + const signal = abortControllerRef.current?.signal; const fullTasks = await jobs.getTasks(job.id, { limit: 1, + signal, // We can't filter by task ID, so we'll get all and find the one we need }); + + // Check if request was aborted + if (signal?.aborted) { + return; + } const taskData = fullTasks.data || fullTasks; const fullTask = Array.isArray(taskData) ? taskData.find(t => t.id === taskId) : null; if (fullTask) { @@ -834,11 +911,7 @@ export default function JobDetails({ job, onClose, onUpdate }) {
- {loading && ( -
-
-
- )} + {loading && } {!loading && ( <> @@ -850,7 +923,7 @@ export default function JobDetails({ job, onClose, onUpdate }) {

Progress

- {jobDetails.progress.toFixed(1)}% + {(jobDetails.progress || 0).toFixed(1)}%

@@ -911,12 +984,7 @@ export default function JobDetails({ job, onClose, onUpdate }) {
)} - {jobDetails.error_message && ( -
-

Error:

-

{jobDetails.error_message}

-
- )} +

diff --git a/web/src/components/JobList.jsx b/web/src/components/JobList.jsx index e3d4fae..6584b29 100644 --- a/web/src/components/JobList.jsx +++ b/web/src/components/JobList.jsx @@ -1,6 +1,8 @@ import { useState, useEffect, useRef } from 'react'; -import { jobs } from '../utils/api'; +import { jobs, normalizeArrayResponse } from '../utils/api'; +import { wsManager } from '../utils/websocket'; import JobDetails from './JobDetails'; +import LoadingSpinner from './LoadingSpinner'; export default function JobList() { const [jobList, setJobList] = useState([]); @@ -8,140 +10,75 @@ export default function JobList() { const [selectedJob, setSelectedJob] = useState(null); const [pagination, setPagination] = useState({ total: 0, limit: 50, offset: 0 }); const [hasMore, setHasMore] = useState(true); - const pollingIntervalRef = useRef(null); - const wsRef = useRef(null); + const listenerIdRef = useRef(null); useEffect(() => { loadJobs(); - // Use WebSocket for real-time updates instead of polling - connectWebSocket(); - return () => { - if (pollingIntervalRef.current) { - clearInterval(pollingIntervalRef.current); - } - if (wsRef.current) { - try { - wsRef.current.close(); - } catch (e) { - // Ignore errors when closing - } - wsRef.current = null; - } - }; - }, []); - - const connectWebSocket = () => { - try { - // Close existing connection if any - if (wsRef.current) { - try { - wsRef.current.close(); - } catch (e) { - // Ignore errors when closing - } - wsRef.current = null; - } - - const ws = jobs.streamJobsWebSocket(); - wsRef.current = ws; - - ws.onopen = () => { - console.log('Job list WebSocket connected'); - }; - - ws.onmessage = (event) => { - try { - const data = JSON.parse(event.data); + // Use shared WebSocket manager for real-time updates + listenerIdRef.current = wsManager.subscribe('joblist', { + open: () => { + console.log('JobList: Shared WebSocket connected'); + // Load initial job list via HTTP to get current state + loadJobs(); + }, + message: (data) 
=> { + console.log('JobList: Client WebSocket message received:', data.type, data.channel, data); + // Handle jobs channel messages (always broadcasted) + if (data.channel === 'jobs') { if (data.type === 'job_update' && data.data) { + console.log('JobList: Updating job:', data.job_id, data.data); // Update job in list setJobList(prev => { - const index = prev.findIndex(j => j.id === data.job_id); + const prevArray = Array.isArray(prev) ? prev : []; + const index = prevArray.findIndex(j => j.id === data.job_id); if (index >= 0) { - const updated = [...prev]; + const updated = [...prevArray]; updated[index] = { ...updated[index], ...data.data }; + console.log('JobList: Updated job at index', index, updated[index]); return updated; } // If job not in current page, reload to get updated list if (data.data.status === 'completed' || data.data.status === 'failed') { loadJobs(); } - return prev; + return prevArray; + }); + } else if (data.type === 'job_created' && data.data) { + console.log('JobList: New job created:', data.job_id, data.data); + // New job created - add to list + setJobList(prev => { + const prevArray = Array.isArray(prev) ? 
prev : []; + // Check if job already exists (avoid duplicates) + if (prevArray.findIndex(j => j.id === data.job_id) >= 0) { + return prevArray; + } + // Add new job at the beginning + return [data.data, ...prevArray]; }); - } else if (data.type === 'connected') { - // Connection established } - } catch (error) { - console.error('Failed to parse WebSocket message:', error); + } else if (data.type === 'connected') { + // Connection established + console.log('JobList: WebSocket connected'); } - }; - - ws.onerror = (error) => { - console.error('Job list WebSocket error:', { - error, - readyState: ws.readyState, - url: ws.url - }); - // WebSocket errors don't provide much detail, but we can check readyState - if (ws.readyState === WebSocket.CLOSED || ws.readyState === WebSocket.CLOSING) { - console.warn('Job list WebSocket is closed or closing, will fallback to polling'); - // Fallback to polling on error - startAdaptivePolling(); - } - }; - - ws.onclose = (event) => { - console.log('Job list WebSocket closed:', { - code: event.code, - reason: event.reason, - wasClean: event.wasClean - }); - wsRef.current = null; - - // Code 1006 = Abnormal Closure (connection lost without close frame) - // Code 1000 = Normal Closure - // Code 1001 = Going Away (server restart, etc.) - // We should reconnect for abnormal closures (1006) or unexpected closes - const shouldReconnect = !event.wasClean || event.code === 1006 || event.code === 1001; - - if (shouldReconnect) { - console.log(`Attempting to reconnect job list WebSocket in 2 seconds... 
(code: ${event.code})`); - setTimeout(() => { - if (wsRef.current === null || (wsRef.current && wsRef.current.readyState === WebSocket.CLOSED)) { - connectWebSocket(); - } - }, 2000); - } else { - // Clean close (code 1000) - fallback to polling - console.log('WebSocket closed cleanly, falling back to polling'); - startAdaptivePolling(); - } - }; - } catch (error) { - console.error('Failed to connect WebSocket:', error); - // Fallback to polling - startAdaptivePolling(); - } - }; - - const startAdaptivePolling = () => { - const checkAndPoll = () => { - const hasRunningJobs = jobList.some(job => job.status === 'running' || job.status === 'pending'); - const interval = hasRunningJobs ? 5000 : 10000; // 5s for running, 10s for completed - - if (pollingIntervalRef.current) { - clearInterval(pollingIntervalRef.current); + }, + error: (error) => { + console.error('JobList: Shared WebSocket error:', error); + }, + close: (event) => { + console.log('JobList: Shared WebSocket closed:', event); + } + }); + + // Ensure connection is established + wsManager.connect(); + + return () => { + if (listenerIdRef.current) { + wsManager.unsubscribe(listenerIdRef.current); + listenerIdRef.current = null; } - - pollingIntervalRef.current = setInterval(() => { - loadJobs(); - }, interval); }; - - checkAndPoll(); - // Re-check interval when job list changes - const checkInterval = setInterval(checkAndPoll, 5000); - return () => clearInterval(checkInterval); - }; + }, []); const loadJobs = async (append = false) => { try { @@ -153,20 +90,27 @@ export default function JobList() { }); // Handle both old format (array) and new format (object with data, total, etc.) - const jobsData = result.data || result; - const total = result.total !== undefined ? result.total : jobsData.length; + const jobsArray = normalizeArrayResponse(result); + const total = result.total !== undefined ? 
result.total : jobsArray.length; if (append) { - setJobList(prev => [...prev, ...jobsData]); + setJobList(prev => { + const prevArray = Array.isArray(prev) ? prev : []; + return [...prevArray, ...jobsArray]; + }); setPagination(prev => ({ ...prev, offset, total })); } else { - setJobList(jobsData); + setJobList(jobsArray); setPagination({ total, limit: result.limit || pagination.limit, offset: result.offset || 0 }); } - setHasMore(offset + jobsData.length < total); + setHasMore(offset + jobsArray.length < total); } catch (error) { console.error('Failed to load jobs:', error); + // Ensure jobList is always an array even on error + if (!append) { + setJobList([]); + } } finally { setLoading(false); } @@ -206,12 +150,21 @@ export default function JobList() { const handleDelete = async (jobId) => { if (!confirm('Are you sure you want to permanently delete this job? This action cannot be undone.')) return; try { - await jobs.delete(jobId); - loadJobs(); + // Optimistically update the list + setJobList(prev => { + const prevArray = Array.isArray(prev) ? prev : []; + return prevArray.filter(j => j.id !== jobId); + }); if (selectedJob && selectedJob.id === jobId) { setSelectedJob(null); } + // Then actually delete + await jobs.delete(jobId); + // Reload to ensure consistency + loadJobs(); } catch (error) { + // On error, reload to restore correct state + loadJobs(); alert('Failed to delete job: ' + error.message); } }; @@ -228,11 +181,7 @@ export default function JobList() { }; if (loading && jobList.length === 0) { - return ( -
-
-
- ); + return ; } if (jobList.length === 0) { diff --git a/web/src/components/JobSubmission.jsx b/web/src/components/JobSubmission.jsx index 76389fe..44fcbda 100644 --- a/web/src/components/JobSubmission.jsx +++ b/web/src/components/JobSubmission.jsx @@ -1,6 +1,9 @@ import { useState, useEffect, useRef } from 'react'; import { jobs } from '../utils/api'; +import { wsManager } from '../utils/websocket'; import JobDetails from './JobDetails'; +import ErrorMessage from './ErrorMessage'; +import LoadingSpinner from './LoadingSpinner'; export default function JobSubmission({ onSuccess }) { const [step, setStep] = useState(1); // 1 = upload & extract metadata, 2 = missing addons (if any), 3 = configure & submit @@ -28,6 +31,7 @@ export default function JobSubmission({ onSuccess }) { const [blendFiles, setBlendFiles] = useState([]); // For ZIP files with multiple blend files const [selectedMainBlend, setSelectedMainBlend] = useState(''); const [confirmedMissingFiles, setConfirmedMissingFiles] = useState(false); // Confirmation for missing files + const [uploadTimeRemaining, setUploadTimeRemaining] = useState(null); // Estimated time remaining in seconds // Use refs to track cancellation state across re-renders const isCancelledRef = useRef(false); @@ -36,12 +40,156 @@ export default function JobSubmission({ onSuccess }) { const cleanupRef = useRef(null); const formatManuallyChangedRef = useRef(false); // Track if user manually changed output format const stepRef = useRef(step); // Track current step to avoid stale closures - + const uploadStartTimeRef = useRef(null); // Track when upload started + const listenerIdRef = useRef(null); // Listener ID for shared WebSocket + const subscribedChannelsRef = useRef(new Set()); // Track confirmed subscribed channels + const pendingSubscriptionsRef = useRef(new Set()); // Track pending subscriptions (waiting for confirmation) + // Keep stepRef in sync with step state useEffect(() => { stepRef.current = step; }, [step]); + // Helper 
function to format time remaining + const formatTimeRemaining = (seconds) => { + if (!seconds || seconds < 0 || !isFinite(seconds)) return null; + + if (seconds < 60) { + return `${Math.round(seconds)}s`; + } else if (seconds < 3600) { + const mins = Math.floor(seconds / 60); + const secs = Math.round(seconds % 60); + return `${mins}m ${secs}s`; + } else if (seconds < 86400) { + const hours = Math.floor(seconds / 3600); + const mins = Math.floor((seconds % 3600) / 60); + return `${hours}h ${mins}m`; + } else { + const days = Math.floor(seconds / 86400); + const hours = Math.floor((seconds % 86400) / 3600); + const mins = Math.floor((seconds % 3600) / 60); + return `${days}d ${hours}h ${mins}m`; + } + }; + + // Connect to shared WebSocket on mount + useEffect(() => { + listenerIdRef.current = wsManager.subscribe('jobsubmission', { + open: () => { + console.log('JobSubmission: Shared WebSocket connected'); + }, + message: (data) => { + // Handle subscription responses + if (data.type === 'subscribed' && data.channel) { + pendingSubscriptionsRef.current.delete(data.channel); + subscribedChannelsRef.current.add(data.channel); + console.log('Successfully subscribed to channel:', data.channel); + } else if (data.type === 'subscription_error' && data.channel) { + pendingSubscriptionsRef.current.delete(data.channel); + subscribedChannelsRef.current.delete(data.channel); + console.error('Subscription failed for channel:', data.channel, data.error); + // If it's the upload channel we're trying to subscribe to, show error + if (data.channel.startsWith('upload:')) { + setError(`Failed to subscribe to upload progress: ${data.error}`); + } + } + + // Handle upload progress messages + if (data.channel && data.channel.startsWith('upload:') && subscribedChannelsRef.current.has(data.channel)) { + if (data.type === 'upload_progress' || data.type === 'processing_status') { + const progress = data.data?.progress || 0; + const status = data.data?.status || 'uploading'; + const message = 
data.data?.message || ''; + + setUploadProgress(progress); + + // Calculate time remaining for upload progress + if (status === 'uploading' && progress > 0 && progress < 100) { + if (!uploadStartTimeRef.current) { + uploadStartTimeRef.current = Date.now(); + } + const elapsed = (Date.now() - uploadStartTimeRef.current) / 1000; // seconds + const remaining = (elapsed / progress) * (100 - progress); + setUploadTimeRemaining(remaining); + } else if (status === 'completed' || status === 'error') { + setUploadTimeRemaining(null); + uploadStartTimeRef.current = null; + } + + if (status === 'uploading') { + setMetadataStatus('extracting'); + } else if (status === 'processing' || status === 'extracting_zip' || status === 'extracting_metadata' || status === 'creating_context') { + setMetadataStatus('processing'); + // Reset time remaining for processing phase + setUploadTimeRemaining(null); + } else if (status === 'completed') { + setMetadataStatus('completed'); + setIsUploading(false); + setUploadTimeRemaining(null); + uploadStartTimeRef.current = null; + // Unsubscribe from upload channel + unsubscribeFromUploadChannel(data.channel); + } else if (status === 'error') { + setMetadataStatus('error'); + setIsUploading(false); + setUploadTimeRemaining(null); + uploadStartTimeRef.current = null; + setError(message || 'Upload/processing failed'); + // Unsubscribe from upload channel + unsubscribeFromUploadChannel(data.channel); + } + } + } + }, + error: (error) => { + console.error('JobSubmission: Shared WebSocket error:', error); + }, + close: (event) => { + console.log('JobSubmission: Shared WebSocket closed:', event); + subscribedChannelsRef.current.clear(); + pendingSubscriptionsRef.current.clear(); + } + }); + + // Ensure connection is established + wsManager.connect(); + + return () => { + // Unsubscribe from all channels before unmounting + unsubscribeFromAllChannels(); + if (listenerIdRef.current) { + wsManager.unsubscribe(listenerIdRef.current); + listenerIdRef.current 
= null; + } + }; + }, []); + + // Helper function to unsubscribe from upload channel + const unsubscribeFromUploadChannel = (channel) => { + if (wsManager.getReadyState() !== WebSocket.OPEN) { + return; + } + if (!subscribedChannelsRef.current.has(channel)) { + return; // Not subscribed + } + wsManager.send({ type: 'unsubscribe', channel }); + subscribedChannelsRef.current.delete(channel); + pendingSubscriptionsRef.current.delete(channel); + console.log('Unsubscribed from upload channel:', channel); + }; + + // Helper function to unsubscribe from all channels + const unsubscribeFromAllChannels = () => { + if (wsManager.getReadyState() !== WebSocket.OPEN) { + return; + } + subscribedChannelsRef.current.forEach(channel => { + wsManager.send({ type: 'unsubscribe', channel }); + }); + subscribedChannelsRef.current.clear(); + pendingSubscriptionsRef.current.clear(); + }; + // No polling needed - metadata is extracted synchronously during upload const handleFileChange = async (e) => { @@ -57,6 +205,8 @@ export default function JobSubmission({ onSuccess }) { setCurrentJobId(null); setUploadSessionId(null); setUploadProgress(0); + setUploadTimeRemaining(null); + uploadStartTimeRef.current = null; setBlendFiles([]); setSelectedMainBlend(''); formatManuallyChangedRef.current = false; // Reset when new file is selected @@ -69,24 +219,33 @@ export default function JobSubmission({ onSuccess }) { try { setIsUploading(true); setUploadProgress(0); + setUploadTimeRemaining(null); + uploadStartTimeRef.current = Date.now(); setMetadataStatus('extracting'); // Upload file to new endpoint (no job required) const result = await jobs.uploadFileForJobCreation(selectedFile, (progress) => { + // XHR progress as fallback, but WebSocket is primary setUploadProgress(progress); - // After upload completes, show processing state - if (progress >= 100) { - setMetadataStatus('processing'); + // Calculate time remaining for XHR progress + if (progress > 0 && progress < 100 && 
uploadStartTimeRef.current) { + const elapsed = (Date.now() - uploadStartTimeRef.current) / 1000; // seconds + const remaining = (elapsed / progress) * (100 - progress); + setUploadTimeRemaining(remaining); } }, selectedMainBlend || undefined); - // Keep showing processing state until we have the result - setMetadataStatus('processing'); - setUploadProgress(100); - // Store session ID for later use when creating the job if (result.session_id) { setUploadSessionId(result.session_id); + + // Subscribe to upload progress channel + if (wsManager.getReadyState() === WebSocket.OPEN) { + const channel = `upload:${result.session_id}`; + wsManager.send({ type: 'subscribe', channel }); + // Don't set subscribedUploadChannelRef yet - wait for confirmation + console.log('Subscribing to upload channel:', channel); + } } // Check if ZIP extraction found multiple blend files @@ -141,6 +300,9 @@ export default function JobSubmission({ onSuccess }) { setMetadataStatus('error'); setIsUploading(false); setUploadProgress(0); + setUploadSessionId(null); + setUploadTimeRemaining(null); + uploadStartTimeRef.current = null; setError(err.message || 'Failed to upload file and extract metadata'); } } @@ -155,26 +317,39 @@ export default function JobSubmission({ onSuccess }) { try { setIsUploading(true); setUploadProgress(0); + setUploadTimeRemaining(null); + uploadStartTimeRef.current = Date.now(); setMetadataStatus('extracting'); // Re-upload with selected main blend file const result = await jobs.uploadFileForJobCreation(file, (progress) => { + // XHR progress as fallback, but WebSocket is primary setUploadProgress(progress); - // After upload completes, show processing state - if (progress >= 100) { - setMetadataStatus('processing'); + // Calculate time remaining for XHR progress + if (progress > 0 && progress < 100 && uploadStartTimeRef.current) { + const elapsed = (Date.now() - uploadStartTimeRef.current) / 1000; // seconds + const remaining = (elapsed / progress) * (100 - progress); + 
setUploadTimeRemaining(remaining); } }, selectedMainBlend); - // Keep showing processing state until we have the result - setMetadataStatus('processing'); - setUploadProgress(100); setBlendFiles([]); - // Store session ID - if (result.session_id) { - setUploadSessionId(result.session_id); - } + // Store session ID and subscribe to upload progress + if (result.session_id) { + setUploadSessionId(result.session_id); + + // Subscribe to upload progress channel + if (wsManager.getReadyState() === WebSocket.OPEN) { + const channel = `upload:${result.session_id}`; + // Don't subscribe if already subscribed or pending + if (!subscribedChannelsRef.current.has(channel) && !pendingSubscriptionsRef.current.has(channel)) { + wsManager.send({ type: 'subscribe', channel }); + pendingSubscriptionsRef.current.add(channel); + console.log('Subscribing to upload channel:', channel); + } + } + } // Upload and processing complete setIsUploading(false); @@ -216,6 +391,10 @@ export default function JobSubmission({ onSuccess }) { setError(err.message || 'Failed to upload with selected blend file'); setIsUploading(false); setMetadataStatus('error'); + setUploadProgress(0); + setUploadSessionId(null); + setUploadTimeRemaining(null); + uploadStartTimeRef.current = null; } }; @@ -269,7 +448,16 @@ export default function JobSubmission({ onSuccess }) { throw new Error('File upload session not found. Please upload the file again.'); } - if (parseInt(formData.frame_end) < parseInt(formData.frame_start)) { + const frameStart = parseInt(formData.frame_start); + const frameEnd = parseInt(formData.frame_end); + + if (frameStart < 0) { + throw new Error('Frame start must be 0 or greater. Negative starting frames are not supported.'); + } + if (frameEnd < 0) { + throw new Error('Frame end must be 0 or greater. 
Negative frame numbers are not supported.'); + } + if (frameEnd < frameStart) { throw new Error('Invalid frame range'); } @@ -302,7 +490,34 @@ export default function JobSubmission({ onSuccess }) { // Set created job to show details setCreatedJob(jobDetails); } catch (err) { - setError(err.message || 'Failed to submit job'); + const errorMessage = err.message || 'Failed to submit job'; + + // Check if this is a session expiry error + if (errorMessage.includes('upload session') || + errorMessage.includes('Context archive not found') || + errorMessage.includes('Please upload the file again')) { + // Reset the entire form - upload session has expired + setError('Your upload session has expired. Please upload your file again.'); + setFile(null); + setMetadata(null); + setMetadataStatus(null); + setUploadSessionId(null); + setStep(1); + setFormData({ + name: '', + frame_start: 1, + frame_end: 10, + output_format: 'PNG', + allow_parallel_runners: true, + render_settings: null, + unhide_objects: false, + enable_execution: false, + }); + setShowAdvancedSettings(false); + formatManuallyChangedRef.current = false; + } else { + setError(errorMessage); + } setSubmitting(false); } }; @@ -347,39 +562,42 @@ export default function JobSubmission({ onSuccess }) {

Submit New Job

-
+
+ {/* Step 1: Upload */}
= 1 ? 'text-orange-500 font-medium' : 'text-gray-500'}`}>
= 1 ? 'bg-orange-600 text-white' : 'bg-gray-700'}`}> {step > 1 ? '✓' : '1'}
- Upload & Extract Metadata + Upload
- {metadata?.missing_files_info?.missing_addons && metadata.missing_files_info.missing_addons.length > 0 && ( - <> -
-
= 2 ? 'text-orange-500 font-medium' : 'text-gray-500'}`}> -
= 2 ? 'bg-orange-600 text-white' : 'bg-gray-700'}`}> - {step > 2 ? '✓' : '2'} -
- Missing Addons -
- - )}
+ {/* Step 2: Missing Addons (always shown, skipped if no addons) */} + {(() => { + const hasMissingAddons = metadata?.missing_files_info?.missing_addons && metadata.missing_files_info.missing_addons.length > 0; + const step2Completed = step > 2 || (step === 3 && !hasMissingAddons); + const step2Active = step === 2 || (step > 1 && hasMissingAddons && step < 3); + const step2Skipped = step >= 3 && !hasMissingAddons; + return ( +
+
+ {step2Completed ? '✓' : step2Skipped ? '—' : '2'} +
+ Addons +
+ ); + })()} +
+ {/* Step 3: Configure & Submit */}
= 3 ? 'text-orange-500 font-medium' : 'text-gray-500'}`}>
= 3 ? 'bg-orange-600 text-white' : 'bg-gray-700'}`}> - {step > 3 ? '✓' : (metadata?.missing_files_info?.missing_addons && metadata.missing_files_info.missing_addons.length > 0 ? '3' : '2')} + {step > 3 ? '✓' : '3'}
- Configure & Submit + Configure
- {error && ( -
- {error} -
- )} + {step === 1 ? ( // Step 1: Upload file and extract metadata @@ -437,7 +655,14 @@ export default function JobSubmission({ onSuccess }) {
Uploading file... - {Math.round(uploadProgress)}% +
+ {uploadTimeRemaining && ( + + ~{formatTimeRemaining(uploadTimeRemaining)} remaining + + )} + {Math.round(uploadProgress)}% +
)} {metadataStatus === 'completed' && metadata && ( -
-
Metadata extracted successfully!
-
-
Frames: {metadata.frame_start} - {metadata.frame_end}
-
Resolution: {metadata.render_settings?.resolution_x} x {metadata.render_settings?.resolution_y}
-
Frame Rate: {metadata.render_settings?.frame_rate || 24} fps
-
Engine: {metadata.render_settings?.engine}
- {metadata.render_settings?.engine_settings?.samples && ( -
Cycles Samples: {metadata.render_settings.engine_settings.samples}
- )} - {metadata.render_settings?.engine_settings?.taa_render_samples && ( -
EEVEE Samples: {metadata.render_settings.engine_settings.taa_render_samples}
- )} +
+ {metadata.has_negative_frames && ( +
+
⚠️ Negative Frame Numbers Detected
+
+ Your Blender file contains negative frame numbers (frame_start: {metadata.frame_start}, frame_end: {metadata.frame_end}). + Negative starting frames are not supported and may not work exactly as you expect. + Please adjust your Blender file's frame range settings to start at 0 or higher before submitting. +
+
+ )} +
+
Metadata extracted successfully!
+
+ {metadata.blender_version?.file_saved_with && ( +
+ Blender Version: {metadata.blender_version.file_saved_with} + (saved with — render version may differ) +
+ )} +
Frames: {metadata.frame_start} - {metadata.frame_end}
+
Resolution: {metadata.render_settings?.resolution_x} x {metadata.render_settings?.resolution_y}
+
Frame Rate: {metadata.render_settings?.frame_rate || 24} fps
+
Engine: {metadata.render_settings?.engine}
+ {metadata.render_settings?.engine_settings?.samples && ( +
Cycles Samples: {metadata.render_settings.engine_settings.samples}
+ )} + {metadata.render_settings?.engine_settings?.taa_render_samples && ( +
EEVEE Samples: {metadata.render_settings.engine_settings.taa_render_samples}
+ )} +
+
-
)} {metadataStatus === 'error' && ( @@ -565,31 +808,46 @@ export default function JobSubmission({ onSuccess }) { />
-
-
- - setFormData({ ...formData, frame_start: e.target.value })} - required - className="w-full px-4 py-2 bg-gray-900 border border-gray-600 rounded-lg text-gray-100 focus:ring-2 focus:ring-orange-500 focus:border-transparent" - /> +
+
+ EXPERIMENTAL +

+ Frame range auto-detection may vary by Blender version or hardware. Verify these values match your blend file. +

-
- - setFormData({ ...formData, frame_end: e.target.value })} - required - min={formData.frame_start} - className="w-full px-4 py-2 bg-gray-900 border border-gray-600 rounded-lg text-gray-100 focus:ring-2 focus:ring-orange-500 focus:border-transparent" - /> +
+
+ + setFormData({ ...formData, frame_start: e.target.value })} + required + min="0" + className="w-full px-4 py-2 bg-gray-900 border border-yellow-400/50 rounded-lg text-gray-100 focus:ring-2 focus:ring-yellow-500 focus:border-transparent" + /> + {formData.frame_start < 0 && ( +

Frame start must be 0 or greater. Negative frames are not supported.

+ )} +
+
+ + setFormData({ ...formData, frame_end: e.target.value })} + required + min={Math.max(0, formData.frame_start)} + className="w-full px-4 py-2 bg-gray-900 border border-yellow-400/50 rounded-lg text-gray-100 focus:ring-2 focus:ring-yellow-500 focus:border-transparent" + /> + {formData.frame_end < 0 && ( +

Frame end must be 0 or greater. Negative frames are not supported.

+ )} +
@@ -667,6 +925,12 @@ export default function JobSubmission({ onSuccess }) {
Metadata from blend file:
+ {metadata.blender_version?.file_saved_with && ( +
+ Blender Version: {metadata.blender_version.file_saved_with} + (saved with — render version may differ) +
+ )}
Frames: {metadata.frame_start} - {metadata.frame_end}
Resolution: {metadata.render_settings?.resolution_x} x {metadata.render_settings?.resolution_y}
Frame Rate: {metadata.render_settings?.frame_rate || 24} fps
@@ -744,7 +1008,6 @@ export default function JobSubmission({ onSuccess }) { > -
@@ -796,7 +1059,7 @@ export default function JobSubmission({ onSuccess }) { setFormData({ ...formData, @@ -817,110 +1080,2391 @@ export default function JobSubmission({ onSuccess }) { {/* Cycles Settings */} {formData.render_settings.engine === 'cycles' && formData.render_settings.engine_settings && ( -
+
Cycles Settings
-
- - setFormData({ - ...formData, - render_settings: { - ...formData.render_settings, - engine_settings: { - ...formData.render_settings.engine_settings, - samples: parseInt(e.target.value) || 128, + {/* Sampling Section */} +
+
Sampling
+ + {/* Noise Threshold */} +
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_adaptive_sampling: e.target.checked, + } } - } - })} - min="1" - className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" - /> + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + adaptive_threshold: parseFloat(e.target.value) || 0.01, + } + } + })} + disabled={formData.render_settings.engine_settings.use_adaptive_sampling === false} + min="0" + max="1" + className="flex-1 px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent disabled:opacity-50" + /> +
+ + {/* Max Samples */} +
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + samples: parseInt(e.target.value) || 4096, + } + } + })} + min="1" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+ + {/* Min Samples */} +
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + adaptive_min_samples: parseInt(e.target.value) || 0, + } + } + })} + min="0" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+ + {/* Time Limit */} +
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + time_limit: parseFloat(e.target.value) || 0, + } + } + })} + min="0" + step="any" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
-
- setFormData({ - ...formData, - render_settings: { - ...formData.render_settings, - engine_settings: { - ...formData.render_settings.engine_settings, - use_denoising: e.target.checked, +
+
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_denoising: e.target.checked, + } } - } - })} - className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" - /> - + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+ + {formData.render_settings.engine_settings.use_denoising && ( +
+ {/* Denoiser Info */} +
+ Using OpenImageDenoise (GPU agnostic) +
+ + {/* Passes */} +
+ + +
+ + {/* Prefilter */} +
+ + +
+ + {/* Quality (OpenImageDenoise only) */} + {(formData.render_settings.engine_settings.denoiser || 'OPENIMAGEDENOISE') === 'OPENIMAGEDENOISE' && ( +
+ + +
+ )} + +
+ )}
-
- setFormData({ - ...formData, - render_settings: { - ...formData.render_settings, - engine_settings: { - ...formData.render_settings.engine_settings, - use_adaptive_sampling: e.target.checked, + {/* Sampling > Path Guiding */} +
+
Sampling › Path Guiding
+ +
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_guiding: e.target.checked, + } } - } - })} - className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" - /> - + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+ + {formData.render_settings.engine_settings.use_guiding && ( +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + guiding_training_samples: parseInt(e.target.value) || 128, + } + } + })} + min="1" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+ +
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_surface_guiding: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+ +
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_volume_guiding: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+
+ )}
-
- - setFormData({ - ...formData, - render_settings: { - ...formData.render_settings, - engine_settings: { - ...formData.render_settings.engine_settings, - max_bounces: parseInt(e.target.value) || 12, + {/* Sampling > Lights */} +
+
Sampling › Lights
+ +
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_light_tree: e.target.checked, + } } - } - })} - min="0" - className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" - /> + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+ +
+ + { + const val = parseFloat(e.target.value); + const clampedVal = isNaN(val) ? 0.01 : Math.max(0, Math.min(1, val)); + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + light_sampling_threshold: clampedVal, + } + } + }); + }} + min="0" + max="1" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
-
+ {/* Light Paths Section - matches Blender's panel */} +
+
Light Paths › Max Bounces
+ +
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + max_bounces: parseInt(e.target.value) || 12, + } + } + })} + min="0" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+ +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + diffuse_bounces: parseInt(e.target.value) || 4, + } + } + })} + min="0" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + glossy_bounces: parseInt(e.target.value) || 4, + } + } + })} + min="0" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + transmission_bounces: parseInt(e.target.value) || 12, + } + } + })} + min="0" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + volume_bounces: parseInt(e.target.value) || 0, + } + } + })} + min="0" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ +
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + transparent_max_bounces: parseInt(e.target.value) || 8, + } + } + })} + min="0" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ + {/* Light Paths > Clamping */} +
+
Light Paths › Clamping
+ +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + sample_clamp_direct: parseFloat(e.target.value) || 0, + } + } + })} + min="0" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + sample_clamp_indirect: parseFloat(e.target.value) || 0, + } + } + })} + min="0" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+
+ + {/* Light Paths > Caustics */} +
+
Light Paths › Caustics
+ +
+
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + caustics_reflective: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + caustics_refractive: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+
+ +
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + blur_glossy: parseFloat(e.target.value) || 0, + } + } + })} + min="0" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ + {/* Light Paths > Fast GI */} +
+
Light Paths › Fast GI Approximation
+ +
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_fast_gi: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+ + {formData.render_settings.engine_settings.use_fast_gi && ( +
+
+ + +
+ +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + ao_bounces: parseInt(e.target.value) || 1, + } + } + })} + min="0" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + ao_bounces_render: parseInt(e.target.value) || 1, + } + } + })} + min="0" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+
+ )} +
+ + {/* Curves (Hair) Section - matches Blender's panel */} +
+
Curves
+ +
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_hair: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+ + {formData.render_settings.engine_settings.use_hair !== false && ( +
+
+ + +
+ +
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + hair_subdivisions: parseInt(e.target.value) || 2, + } + } + })} + min="0" + max="24" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ )} +
+ + {/* Volumes Section */} +
+
Volumes
+ +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + volume_step_rate: parseFloat(e.target.value) || 1.0, + } + } + })} + min="0.01" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + volume_max_steps: parseInt(e.target.value) || 1024, + } + } + })} + min="1" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+
+ + {/* Film Section */} +
+
Film
+ +
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + film_exposure: parseFloat(e.target.value) || 1.0, + } + } + })} + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+ +
+
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + film_transparent: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + film_transparent_glass: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+
+ +
+ + +
+ +
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + filter_width: parseFloat(e.target.value) || 1.5, + } + } + })} + min="0.01" + max="10" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ + {/* Performance Section */} +
+
Performance
+ +
+
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_auto_tile: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_persistent_data: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+
+ +
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + tile_size: parseInt(e.target.value) || 2048, + } + } + })} + min="8" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ + {/* Simplify Section */} +
+
Simplify
+ +
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_simplify: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+ + {formData.render_settings.engine_settings.use_simplify && ( +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + simplify_subdivision_render: parseInt(e.target.value) || 6, + } + } + })} + min="0" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+ +
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + simplify_child_particles_render: parseFloat(e.target.value) || 1.0, + } + } + })} + min="0" + max="1" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ )} +
+
+ )} + + {/* EEVEE Settings */} + {formData.render_settings.engine === 'eevee' && formData.render_settings.engine_settings && ( +
+
EEVEE Settings
+ + {/* Sampling Section */} +
+
Sampling
+ +
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + taa_render_samples: parseInt(e.target.value) || 64, + } + } + })} + min="1" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+ +
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + taa_samples: parseInt(e.target.value) || 16, + } + } + })} + min="1" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+ +
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_taa_reprojection: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+
+ + {/* Sampling > Clamping */} +
+
Sampling › Clamping
+ +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + clamp_surface_direct: parseFloat(e.target.value) || 0, + } + } + })} + min="0" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + clamp_surface_indirect: parseFloat(e.target.value) || 0, + } + } + })} + min="0" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+
+ + {/* Shadows Section */} +
+
Shadows
+ +
+
+ + +
+
+ + +
+
+ +
+
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_soft_shadows: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_shadow_high_bitdepth: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+
+ +
+ + { + const val = parseFloat(e.target.value); + const clampedVal = isNaN(val) ? 0.01 : Math.max(0, Math.min(1, val)); + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + light_threshold: clampedVal, + } + } + }); + }} + min="0" + max="1" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ + {/* Raytracing Section */} +
+
Raytracing
+ +
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_raytracing: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+ + {formData.render_settings.engine_settings.use_raytracing && ( +
+ + +
+ )} +
+ + {/* Screen Space Reflections Section */} +
+
Screen Space Reflections
+ +
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_ssr: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+ + {formData.render_settings.engine_settings.use_ssr && ( +
+
+
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_ssr_refraction: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_ssr_halfres: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+
+ +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + ssr_quality: parseFloat(e.target.value) || 0.25, + } + } + })} + min="0" + max="1" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + ssr_max_roughness: parseFloat(e.target.value) || 0.5, + } + } + })} + min="0" + max="1" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + ssr_thickness: parseFloat(e.target.value) || 0.2, + } + } + })} + min="0" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + ssr_border_fade: parseFloat(e.target.value) || 0.075, + } + } + })} + min="0" + max="0.5" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+
+ )} +
+ + {/* Ambient Occlusion Section */} +
+
Ambient Occlusion
+ +
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_gtao: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+ + {formData.render_settings.engine_settings.use_gtao && ( +
+
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + gtao_distance: parseFloat(e.target.value) || 0.2, + } + } + })} + min="0" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + gtao_factor: parseFloat(e.target.value) || 1.0, + } + } + })} + min="0" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ +
+
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_gtao_bent_normals: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_gtao_bounce: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+
+
+ )} +
+ + {/* Bloom Section */} +
+
Bloom
+ +
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_bloom: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+ + {formData.render_settings.engine_settings.use_bloom && ( +
+
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + bloom_threshold: parseFloat(e.target.value) || 0.8, + } + } + })} + min="0" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + bloom_knee: parseFloat(e.target.value) || 0.5, + } + } + })} + min="0" + max="1" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + bloom_radius: parseFloat(e.target.value) || 6.5, + } + } + })} + min="0" + max="2048" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + bloom_intensity: parseFloat(e.target.value) || 0.05, + } + } + })} + min="0" + max="1" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ +
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + bloom_clamp: parseFloat(e.target.value) || 0, + } + } + })} + min="0" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ )} +
+ + {/* Depth of Field Section */} +
+
Depth of Field
+ +
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + bokeh_max_size: parseFloat(e.target.value) || 100, + } + } + })} + min="0" + max="2048" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+ +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + bokeh_threshold: parseFloat(e.target.value) || 1.0, + } + } + })} + min="0" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + bokeh_neighbor_max: parseFloat(e.target.value) || 10, + } + } + })} + min="0" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + bokeh_denoise_fac: parseFloat(e.target.value) || 0.75, + } + } + })} + min="0" + max="1" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + bokeh_overblur: parseFloat(e.target.value) || 5, + } + } + })} + min="0" + max="100" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ +
+
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_bokeh_high_quality_slight_defocus: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_bokeh_jittered: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+
+
+ + {/* Subsurface Scattering Section */} +
+
Subsurface Scattering
+ +
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + sss_samples: parseInt(e.target.value) || 7, + } + } + })} + min="1" + max="32" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+ +
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + sss_jitter_threshold: parseFloat(e.target.value) || 0.3, + } + } + })} + min="0" + max="1" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ + {/* Volumetrics Section */} +
+
Volumetrics
+ +
+
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_volumetric_lights: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_volumetric_shadows: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+
+ +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + volumetric_start: parseFloat(e.target.value) || 0.1, + } + } + })} + min="0" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + volumetric_end: parseFloat(e.target.value) || 100, + } + } + })} + min="0" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ +
+
+ + +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + volumetric_samples: parseInt(e.target.value) || 64, + } + } + })} + min="1" + max="256" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ +
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + volumetric_sample_distribution: parseFloat(e.target.value) || 0.8, + } + } + })} + min="0" + max="1" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ + {/* Motion Blur Section */} +
+
Motion Blur
+ +
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_motion_blur: e.target.checked, + } + } + })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
+ + {formData.render_settings.engine_settings.use_motion_blur && ( +
+
+ + +
+ +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + motion_blur_shutter: parseFloat(e.target.value) || 0.5, + } + } + })} + min="0" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + motion_blur_steps: parseInt(e.target.value) || 1, + } + } + })} + min="1" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+
+
+ )} +
+ + {/* Indirect Lighting Section */} +
+
Indirect Lighting
+
setFormData({ ...formData, render_settings: { ...formData.render_settings, engine_settings: { ...formData.render_settings.engine_settings, - diffuse_bounces: parseInt(e.target.value) || 4, + gi_diffuse_bounces: parseInt(e.target.value) || 3, } } })} @@ -928,140 +3472,132 @@ export default function JobSubmission({ onSuccess }) { className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" />
-
- + +
+
+ + +
+
+ + +
+
+ +
setFormData({ ...formData, render_settings: { ...formData.render_settings, engine_settings: { ...formData.render_settings.engine_settings, - glossy_bounces: parseInt(e.target.value) || 4, + gi_auto_bake: e.target.checked, } } })} - min="0" - className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" /> +
-
- )} - {/* EEVEE Settings */} - {(formData.render_settings.engine === 'eevee' || formData.render_settings.engine === 'eevee_next') && formData.render_settings.engine_settings && ( -
-
EEVEE Settings
- -
- - setFormData({ - ...formData, - render_settings: { - ...formData.render_settings, - engine_settings: { - ...formData.render_settings.engine_settings, - taa_render_samples: parseInt(e.target.value) || 64, + {/* Film Section */} +
+
Film
+ +
+ setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + use_overscan: e.target.checked, + } } - } - })} - min="1" - className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" - /> -
+ })} + className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" + /> + +
-
- setFormData({ - ...formData, - render_settings: { - ...formData.render_settings, - engine_settings: { - ...formData.render_settings.engine_settings, - use_bloom: e.target.checked, - } - } - })} - className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" - /> - -
- -
- setFormData({ - ...formData, - render_settings: { - ...formData.render_settings, - engine_settings: { - ...formData.render_settings.engine_settings, - use_ssr: e.target.checked, - } - } - })} - className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" - /> - -
- -
- setFormData({ - ...formData, - render_settings: { - ...formData.render_settings, - engine_settings: { - ...formData.render_settings.engine_settings, - use_ssao: e.target.checked, - } - } - })} - className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" - /> - -
- -
- setFormData({ - ...formData, - render_settings: { - ...formData.render_settings, - engine_settings: { - ...formData.render_settings.engine_settings, - use_volumetric: e.target.checked, - } - } - })} - className="h-4 w-4 text-orange-600 focus:ring-orange-500 border-gray-600 bg-gray-800 rounded" - /> - + {formData.render_settings.engine_settings.use_overscan && ( +
+ + setFormData({ + ...formData, + render_settings: { + ...formData.render_settings, + engine_settings: { + ...formData.render_settings.engine_settings, + overscan_size: parseFloat(e.target.value) || 3, + } + } + })} + min="0" + max="50" + className="w-full px-3 py-1.5 bg-gray-800 border border-gray-600 rounded text-gray-100 text-sm focus:ring-2 focus:ring-orange-500 focus:border-transparent" + /> +
+ )}
)} @@ -1078,7 +3614,16 @@ export default function JobSubmission({ onSuccess }) {
{isUploading ? 'Uploading file...' : 'Creating job...'} - {isUploading && {Math.round(uploadProgress)}%} + {isUploading && ( +
+ {uploadTimeRemaining && ( + + ~{formatTimeRemaining(uploadTimeRemaining)} remaining + + )} + {Math.round(uploadProgress)}% +
+ )}
{isUploading && (
diff --git a/web/src/components/LoadingSpinner.jsx b/web/src/components/LoadingSpinner.jsx new file mode 100644 index 0000000..3a3d155 --- /dev/null +++ b/web/src/components/LoadingSpinner.jsx @@ -0,0 +1,19 @@ +import React from 'react'; + +/** + * Shared LoadingSpinner component with size variants + */ +export default function LoadingSpinner({ size = 'md', className = '', borderColor = 'border-orange-500' }) { + const sizeClasses = { + sm: 'h-8 w-8', + md: 'h-12 w-12', + lg: 'h-16 w-16', + }; + + return ( +
+
+
+ ); +} + diff --git a/web/src/components/Login.jsx b/web/src/components/Login.jsx index 5b64989..0ae2d35 100644 --- a/web/src/components/Login.jsx +++ b/web/src/components/Login.jsx @@ -1,5 +1,6 @@ import { useState, useEffect } from 'react'; import { auth } from '../utils/api'; +import ErrorMessage from './ErrorMessage'; export default function Login() { const [providers, setProviders] = useState({ @@ -92,11 +93,7 @@ export default function Login() {
- {error && ( -
- {error} -
- )} + {providers.local && (
diff --git a/web/src/components/PasswordChange.jsx b/web/src/components/PasswordChange.jsx index 16a8fc9..1b76c93 100644 --- a/web/src/components/PasswordChange.jsx +++ b/web/src/components/PasswordChange.jsx @@ -1,5 +1,6 @@ import { useState } from 'react'; import { auth } from '../utils/api'; +import ErrorMessage from './ErrorMessage'; import { useAuth } from '../hooks/useAuth'; export default function PasswordChange({ targetUserId = null, targetUserName = null, onSuccess }) { @@ -64,11 +65,7 @@ export default function PasswordChange({ targetUserId = null, targetUserName = n {isChangingOtherUser ? `Change Password for ${targetUserName || 'User'}` : 'Change Password'}

- {error && ( -
- {error} -
- )} + {success && (
diff --git a/web/src/components/UserJobs.jsx b/web/src/components/UserJobs.jsx index 1d1e62e..8054e14 100644 --- a/web/src/components/UserJobs.jsx +++ b/web/src/components/UserJobs.jsx @@ -1,22 +1,71 @@ -import { useState, useEffect } from 'react'; -import { admin } from '../utils/api'; +import { useState, useEffect, useRef } from 'react'; +import { admin, normalizeArrayResponse } from '../utils/api'; +import { wsManager } from '../utils/websocket'; import JobDetails from './JobDetails'; +import LoadingSpinner from './LoadingSpinner'; export default function UserJobs({ userId, userName, onBack }) { const [jobList, setJobList] = useState([]); const [loading, setLoading] = useState(true); const [selectedJob, setSelectedJob] = useState(null); + const listenerIdRef = useRef(null); useEffect(() => { loadJobs(); - const interval = setInterval(loadJobs, 5000); - return () => clearInterval(interval); + // Use shared WebSocket manager for real-time updates instead of polling + listenerIdRef.current = wsManager.subscribe(`userjobs_${userId}`, { + open: () => { + console.log('UserJobs: Shared WebSocket connected'); + loadJobs(); + }, + message: (data) => { + // Handle jobs channel messages (always broadcasted) + if (data.channel === 'jobs') { + if (data.type === 'job_update' && data.data) { + // Update job in list if it belongs to this user + setJobList(prev => { + const prevArray = Array.isArray(prev) ? 
prev : []; + const index = prevArray.findIndex(j => j.id === data.job_id); + if (index >= 0) { + const updated = [...prevArray]; + updated[index] = { ...updated[index], ...data.data }; + return updated; + } + // If job not in current list, reload to get updated list + if (data.data.status === 'completed' || data.data.status === 'failed') { + loadJobs(); + } + return prevArray; + }); + } else if (data.type === 'job_created' && data.data) { + // New job created - reload to check if it belongs to this user + loadJobs(); + } + } + }, + error: (error) => { + console.error('UserJobs: Shared WebSocket error:', error); + }, + close: (event) => { + console.log('UserJobs: Shared WebSocket closed:', event); + } + }); + + // Ensure connection is established + wsManager.connect(); + + return () => { + if (listenerIdRef.current) { + wsManager.unsubscribe(listenerIdRef.current); + listenerIdRef.current = null; + } + }; }, [userId]); const loadJobs = async () => { try { const data = await admin.getUserJobs(userId); - setJobList(Array.isArray(data) ? data : []); + setJobList(normalizeArrayResponse(data)); } catch (error) { console.error('Failed to load jobs:', error); setJobList([]); @@ -47,11 +96,7 @@ export default function UserJobs({ userId, userName, onBack }) { } if (loading) { - return ( -
-
-
- ); + return ; } return ( diff --git a/web/src/components/VideoPlayer.jsx b/web/src/components/VideoPlayer.jsx index 09e565d..3c23f0c 100644 --- a/web/src/components/VideoPlayer.jsx +++ b/web/src/components/VideoPlayer.jsx @@ -1,4 +1,6 @@ import { useState, useRef, useEffect } from 'react'; +import ErrorMessage from './ErrorMessage'; +import LoadingSpinner from './LoadingSpinner'; export default function VideoPlayer({ videoUrl, onClose }) { const videoRef = useRef(null); @@ -55,10 +57,10 @@ export default function VideoPlayer({ videoUrl, onClose }) { if (error) { return ( -
- {error} -
- Download video instead + ); @@ -68,7 +70,7 @@ export default function VideoPlayer({ videoUrl, onClose }) {
{loading && (
-
+
)}