something

2025-11-27 00:46:48 -06:00
parent 11e7552b5b
commit edc8ea160c
43 changed files with 9990 additions and 3059 deletions

1
.gitignore vendored

@@ -69,6 +69,7 @@ lerna-debug.log*
# Logs
*.log
/logs/
# OS files
Thumbs.db

170
Makefile

@@ -1,114 +1,66 @@
.PHONY: build build-manager build-runner build-web run-manager run-runner run cleanup cleanup-manager cleanup-runner kill-all clean test help
.PHONY: build build-web run run-manager run-runner cleanup cleanup-manager cleanup-runner kill-all clean-bin clean-web test help install
# Build all components
build: clean-bin build-manager build-runner
# Build the jiggablend binary (includes embedded web UI)
build: clean-bin build-web
go build -o bin/jiggablend ./cmd/jiggablend
# Build manager
build-manager: clean-bin build-web
go build -o bin/manager ./cmd/manager
# Build runner
build-runner: clean-bin
GOOS=linux GOARCH=amd64 go build -o bin/runner ./cmd/runner
# Build for Linux (cross-compile)
build-linux: clean-bin build-web
GOOS=linux GOARCH=amd64 go build -o bin/jiggablend ./cmd/jiggablend
# Build web UI
build-web: clean-web
cd web && npm install && npm run build
# Cleanup manager (database and storage)
# Cleanup manager logs
cleanup-manager:
@echo "Cleaning up manager database and storage..."
@rm -f jiggablend.db 2>/dev/null || true
@rm -f jiggablend.db-shm 2>/dev/null || true
@rm -f jiggablend.db-wal 2>/dev/null || true
@rm -rf jiggablend-storage 2>/dev/null || true
@echo "Cleaning up manager logs..."
@rm -rf logs/manager.log 2>/dev/null || true
@echo "Manager cleanup complete"
# Cleanup runner workspaces
# Cleanup runner logs
cleanup-runner:
@echo "Cleaning up runner workspaces..."
@rm -rf jiggablend-workspaces jiggablend-workspace* *workspace* 2>/dev/null || true
@echo "Cleaning up runner logs..."
@rm -rf logs/runner*.log 2>/dev/null || true
@echo "Runner cleanup complete"
# Cleanup both manager and runner
# Cleanup both manager and runner logs
cleanup: cleanup-manager cleanup-runner
# Kill all manager and runner processes
# Kill all jiggablend processes
kill-all:
@echo "Killing all manager and runner processes..."
@# Kill manager processes (compiled binaries in bin/, root, and go run)
@-pkill -f "bin/manager" 2>/dev/null || true
@-pkill -f "\./manager" 2>/dev/null || true
@-pkill -f "manager" 2>/dev/null || true
@-pkill -f "main.*cmd/manager" 2>/dev/null || true
@-pkill -f "go run.*cmd/manager" 2>/dev/null || true
@# Kill runner processes (compiled binaries in bin/, root, and go run)
@-pkill -f "bin/runner" 2>/dev/null || true
@-pkill -f "\./runner" 2>/dev/null || true
@-pkill -f "runner" 2>/dev/null || true
@-pkill -f "main.*cmd/runner" 2>/dev/null || true
@-pkill -f "go run.*cmd/runner" 2>/dev/null || true
@# Wait a moment for graceful shutdown
@echo "Waiting for 5 seconds for graceful shutdown..."
@sleep 1
@echo "5"
@sleep 1
@echo "4"
@sleep 1
@echo "3"
@sleep 1
@echo "2"
@sleep 1
@echo "1"
@sleep 1
@echo "0"
@echo "Not implemented"
@# Check if any manager or runner processes are still running
@MANAGER_COUNT=$$(pgrep -f "bin/manager\|\./manager\|manager\|main.*cmd/manager\|go run.*cmd/manager" | wc -l); \
RUNNER_COUNT=$$(pgrep -f "bin/runner\|\./runner\|runner\|main.*cmd/runner\|go run.*cmd/runner" | wc -l); \
if [ $$MANAGER_COUNT -eq 0 ] && [ $$RUNNER_COUNT -eq 0 ]; then \
echo "All manager and runner processes have shut down gracefully"; \
exit 0; \
else \
echo "Some processes still running ($$MANAGER_COUNT managers, $$RUNNER_COUNT runners), proceeding with force kill..."; \
fi
@# Force kill any remaining processes
@-pkill -9 -f "bin/manager" 2>/dev/null || true
@-pkill -9 -f "\./manager" 2>/dev/null || true
@-pkill -9 -f "main.*cmd/manager" 2>/dev/null || true
@-pkill -9 -f "go run.*cmd/manager" 2>/dev/null || true
@-pkill -9 -f "bin/runner" 2>/dev/null || true
@-pkill -9 -f "\./runner" 2>/dev/null || true
@-pkill -9 -f "main.*cmd/runner" 2>/dev/null || true
@-pkill -9 -f "go run.*cmd/runner" 2>/dev/null || true
@echo "All manager and runner processes killed after 5 seconds"
# Run all parallel
run: cleanup-manager cleanup-runner build-manager build-runner
# Run manager and runner in parallel (for testing)
run: cleanup build init-test
@echo "Starting manager and runner in parallel..."
@echo "Press Ctrl+C to stop both..."
@echo "Note: This will create a test API key for the runner to use"
@trap 'kill $$MANAGER_PID $$RUNNER_PID 2>/dev/null; exit' INT TERM; \
FIXED_API_KEY=jk_r0_test_key_123456789012345678901234567890 ENABLE_LOCAL_AUTH=true LOCAL_TEST_EMAIL=test@example.com LOCAL_TEST_PASSWORD=testpassword bin/manager & \
bin/jiggablend manager & \
MANAGER_PID=$$!; \
sleep 2; \
API_KEY=jk_r0_test_key_123456789012345678901234567890 bin/runner & \
bin/jiggablend runner --api-key=jk_r0_test_key_123456789012345678901234567890 & \
RUNNER_PID=$$!; \
wait $$MANAGER_PID $$RUNNER_PID
# Run manager with test API key
# Note: ENABLE_LOCAL_AUTH enables local user registration/login
# LOCAL_TEST_EMAIL and LOCAL_TEST_PASSWORD create a test user on startup (if it doesn't exist)
# FIXED_API_KEY provides a pre-configured API key for testing (jk_r0_... format)
# The manager will accept this API key for runner registration
run-manager: cleanup-manager build-manager
FIXED_API_KEY=jk_r0_test_key_123456789012345678901234567890 ENABLE_LOCAL_AUTH=true LOCAL_TEST_EMAIL=test@example.com LOCAL_TEST_PASSWORD=testpassword bin/manager
# Run manager server
run-manager: cleanup-manager build init-test
bin/jiggablend manager
# Run runner with test API key
# Note: API_KEY must match what the manager accepts (see run-manager)
# The runner will use this API key for all authentication with the manager
run-runner: cleanup-runner build-runner
API_KEY=jk_r0_test_key_123456789012345678901234567890 bin/runner
# Run runner
run-runner: cleanup-runner build
bin/jiggablend runner --api-key=jk_r0_test_key_123456789012345678901234567890
# Initialize for testing (first run setup)
init-test: build
@echo "Initializing test configuration..."
bin/jiggablend manager config enable localauth
bin/jiggablend manager config set fixed-apikey jk_r0_test_key_123456789012345678901234567890 -f -y
bin/jiggablend manager config add user test@example.com testpassword --admin -f -y
@echo "Test configuration complete!"
@echo "fixed api key: jk_r0_test_key_123456789012345678901234567890"
@echo "test user: test@example.com"
@echo "test password: testpassword"
# Clean bin build artifacts
clean-bin:
@@ -122,39 +74,45 @@ clean-web:
test:
go test ./... -timeout 30s
# Install to /usr/local/bin
install: build
sudo cp bin/jiggablend /usr/local/bin/
# Show help
help:
@echo "Jiggablend Build and Run Makefile"
@echo ""
@echo "Build targets:"
@echo " build - Build manager, runner, and web UI"
@echo " build-manager - Build manager with web UI"
@echo " build-runner - Build runner binary"
@echo " build-web - Build web UI"
@echo " build - Build jiggablend binary with embedded web UI"
@echo " build-linux - Cross-compile for Linux amd64"
@echo " build-web - Build web UI only"
@echo ""
@echo "Run targets:"
@echo " run - Run manager and runner in parallel with test API key"
@echo " run-manager - Run manager with test API key enabled"
@echo " run - Run manager and runner in parallel (for testing)"
@echo " run-manager - Run manager server"
@echo " run-runner - Run runner with test API key"
@echo " init-test - Initialize test configuration (run once)"
@echo ""
@echo "Cleanup targets:"
@echo " cleanup - Clean manager and runner data"
@echo " cleanup-manager - Clean manager database and storage"
@echo " cleanup-runner - Clean runner workspaces and API keys"
@echo " cleanup - Clean all logs"
@echo " cleanup-manager - Clean manager logs"
@echo " cleanup-runner - Clean runner logs"
@echo " kill-all - Kill all running jiggablend processes"
@echo ""
@echo "Other targets:"
@echo " clean - Clean build artifacts"
@echo " kill-all - Kill all running manager and runner processes (binaries in bin/, root, or go run)"
@echo " clean-bin - Clean build artifacts"
@echo " clean-web - Clean web build artifacts"
@echo " test - Run Go tests"
@echo " install - Install to /usr/local/bin"
@echo " help - Show this help"
@echo ""
@echo "API Key System:"
@echo " - FIXED_API_KEY: Pre-configured API key for manager (optional)"
@echo " - API_KEY: API key for runner authentication"
@echo " - Format: jk_r{N}_{32-char-hex}"
@echo " - Generate via admin UI or set FIXED_API_KEY for testing"
@echo ""
@echo "Timeouts:"
@echo " - Use 'timeout <seconds> make run' to prevent hanging during testing"
@echo " - Example: timeout 30s make run"
@echo "CLI Usage:"
@echo " jiggablend manager serve - Start the manager server"
@echo " jiggablend runner - Start a runner"
@echo " jiggablend manager config show - Show configuration"
@echo " jiggablend manager config enable localauth"
@echo " jiggablend manager config add user --email=x --password=y"
@echo " jiggablend manager config add apikey --name=mykey"
@echo " jiggablend manager config set fixed-apikey <key>"
@echo " jiggablend manager config list users"
@echo " jiggablend manager config list apikeys"


@@ -0,0 +1,152 @@
package cmd
import (
"fmt"
"net/http"
"os/exec"
"strings"
"jiggablend/internal/api"
"jiggablend/internal/auth"
"jiggablend/internal/config"
"jiggablend/internal/database"
"jiggablend/internal/logger"
"jiggablend/internal/storage"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
var managerCmd = &cobra.Command{
Use: "manager",
Short: "Start the Jiggablend manager server",
Long: `Start the Jiggablend manager server to coordinate render jobs.`,
Run: runManager,
}
func init() {
rootCmd.AddCommand(managerCmd)
// Flags with env binding via viper
managerCmd.Flags().StringP("port", "p", "8080", "Server port")
managerCmd.Flags().String("db", "jiggablend.db", "Database path")
managerCmd.Flags().String("storage", "./jiggablend-storage", "Storage path")
managerCmd.Flags().StringP("log-file", "l", "", "Log file path (truncated on start, if not set logs only to stdout)")
managerCmd.Flags().String("log-level", "info", "Log level (debug, info, warn, error)")
managerCmd.Flags().BoolP("verbose", "v", false, "Enable verbose logging (same as --log-level=debug)")
// Bind flags to viper with JIGGABLEND_ prefix
viper.SetEnvPrefix("JIGGABLEND")
viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
viper.AutomaticEnv()
viper.BindPFlag("port", managerCmd.Flags().Lookup("port"))
viper.BindPFlag("db", managerCmd.Flags().Lookup("db"))
viper.BindPFlag("storage", managerCmd.Flags().Lookup("storage"))
viper.BindPFlag("log_file", managerCmd.Flags().Lookup("log-file"))
viper.BindPFlag("log_level", managerCmd.Flags().Lookup("log-level"))
viper.BindPFlag("verbose", managerCmd.Flags().Lookup("verbose"))
}
func runManager(cmd *cobra.Command, args []string) {
// Get config values (flags take precedence over env vars)
port := viper.GetString("port")
dbPath := viper.GetString("db")
storagePath := viper.GetString("storage")
logFile := viper.GetString("log_file")
logLevel := viper.GetString("log_level")
verbose := viper.GetBool("verbose")
// Initialize logger
if logFile != "" {
if err := logger.InitWithFile(logFile); err != nil {
logger.Fatalf("Failed to initialize logger: %v", err)
}
defer func() {
if l := logger.GetDefault(); l != nil {
l.Close()
}
}()
} else {
logger.InitStdout()
}
// Set log level
if verbose {
logger.SetLevel(logger.LevelDebug)
} else {
logger.SetLevel(logger.ParseLevel(logLevel))
}
if logFile != "" {
logger.Infof("Logging to file: %s", logFile)
}
logger.Debugf("Log level: %s", logLevel)
// Initialize database
db, err := database.NewDB(dbPath)
if err != nil {
logger.Fatalf("Failed to initialize database: %v", err)
}
defer db.Close()
// Initialize config from database
cfg := config.NewConfig(db)
if err := cfg.InitializeFromEnv(); err != nil {
logger.Fatalf("Failed to initialize config: %v", err)
}
logger.Info("Configuration loaded from database")
// Initialize auth
authHandler, err := auth.NewAuth(db, cfg)
if err != nil {
logger.Fatalf("Failed to initialize auth: %v", err)
}
// Initialize storage
storageHandler, err := storage.NewStorage(storagePath)
if err != nil {
logger.Fatalf("Failed to initialize storage: %v", err)
}
// Check if Blender is available
if err := checkBlenderAvailable(); err != nil {
logger.Fatalf("Blender is not available: %v\n"+
"The manager requires Blender to be installed and in PATH for metadata extraction.\n"+
"Please install Blender and ensure it's accessible via the 'blender' command.", err)
}
logger.Info("Blender is available")
// Create API server
server, err := api.NewServer(db, cfg, authHandler, storageHandler)
if err != nil {
logger.Fatalf("Failed to create server: %v", err)
}
// Start server
addr := fmt.Sprintf(":%s", port)
logger.Infof("Starting manager server on %s", addr)
logger.Infof("Database: %s", dbPath)
logger.Infof("Storage: %s", storagePath)
httpServer := &http.Server{
Addr: addr,
Handler: server,
MaxHeaderBytes: 1 << 20,
ReadTimeout: 0,
WriteTimeout: 0,
}
if err := httpServer.ListenAndServe(); err != nil {
logger.Fatalf("Server failed: %v", err)
}
}
func checkBlenderAvailable() error {
cmd := exec.Command("blender", "--version")
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("failed to run 'blender --version': %w (output: %s)", err, string(output))
}
return nil
}
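
Both the manager and runner commands bind their flags to JIGGABLEND_-prefixed environment variables through viper (SetEnvPrefix plus a dash-to-underscore key replacer), so JIGGABLEND_LOG_LEVEL fills the log_level key when --log-level is not given on the command line. A minimal standalone sketch of that lookup, using only the viper calls that appear above (not part of the commit itself):

package main

import (
	"fmt"
	"os"
	"strings"

	"github.com/spf13/viper"
)

func main() {
	// Simulate the environment the manager/runner commands read.
	os.Setenv("JIGGABLEND_LOG_LEVEL", "debug")

	v := viper.New()
	v.SetEnvPrefix("JIGGABLEND")
	v.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
	v.AutomaticEnv()
	v.SetDefault("log_level", "info")

	// The environment variable overrides the default; a flag that was
	// explicitly set and bound via BindPFlag would in turn override the
	// environment, which is the precedence the comments above describe.
	fmt.Println(v.GetString("log_level")) // prints "debug"
}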


@@ -0,0 +1,621 @@
package cmd
import (
"bufio"
"crypto/rand"
"crypto/sha256"
"database/sql"
"encoding/hex"
"fmt"
"os"
"strings"
"jiggablend/internal/config"
"jiggablend/internal/database"
"github.com/spf13/cobra"
"golang.org/x/crypto/bcrypt"
)
var (
configDBPath string
configYes bool // Auto-confirm prompts
configForce bool // Force override existing
)
var configCmd = &cobra.Command{
Use: "config",
Short: "Configure the manager",
Long: `Configure the Jiggablend manager settings stored in the database.`,
}
// --- Enable/Disable commands ---
var enableCmd = &cobra.Command{
Use: "enable",
Short: "Enable a feature",
}
var disableCmd = &cobra.Command{
Use: "disable",
Short: "Disable a feature",
}
var enableLocalAuthCmd = &cobra.Command{
Use: "localauth",
Short: "Enable local authentication",
Run: func(cmd *cobra.Command, args []string) {
withConfig(func(cfg *config.Config, db *database.DB) {
if err := cfg.SetBool(config.KeyEnableLocalAuth, true); err != nil {
exitWithError("Failed to enable local auth: %v", err)
}
fmt.Println("Local authentication enabled")
})
},
}
var disableLocalAuthCmd = &cobra.Command{
Use: "localauth",
Short: "Disable local authentication",
Run: func(cmd *cobra.Command, args []string) {
withConfig(func(cfg *config.Config, db *database.DB) {
if err := cfg.SetBool(config.KeyEnableLocalAuth, false); err != nil {
exitWithError("Failed to disable local auth: %v", err)
}
fmt.Println("Local authentication disabled")
})
},
}
var enableRegistrationCmd = &cobra.Command{
Use: "registration",
Short: "Enable user registration",
Run: func(cmd *cobra.Command, args []string) {
withConfig(func(cfg *config.Config, db *database.DB) {
if err := cfg.SetBool(config.KeyRegistrationEnabled, true); err != nil {
exitWithError("Failed to enable registration: %v", err)
}
fmt.Println("User registration enabled")
})
},
}
var disableRegistrationCmd = &cobra.Command{
Use: "registration",
Short: "Disable user registration",
Run: func(cmd *cobra.Command, args []string) {
withConfig(func(cfg *config.Config, db *database.DB) {
if err := cfg.SetBool(config.KeyRegistrationEnabled, false); err != nil {
exitWithError("Failed to disable registration: %v", err)
}
fmt.Println("User registration disabled")
})
},
}
var enableProductionCmd = &cobra.Command{
Use: "production",
Short: "Enable production mode",
Run: func(cmd *cobra.Command, args []string) {
withConfig(func(cfg *config.Config, db *database.DB) {
if err := cfg.SetBool(config.KeyProductionMode, true); err != nil {
exitWithError("Failed to enable production mode: %v", err)
}
fmt.Println("Production mode enabled")
})
},
}
var disableProductionCmd = &cobra.Command{
Use: "production",
Short: "Disable production mode",
Run: func(cmd *cobra.Command, args []string) {
withConfig(func(cfg *config.Config, db *database.DB) {
if err := cfg.SetBool(config.KeyProductionMode, false); err != nil {
exitWithError("Failed to disable production mode: %v", err)
}
fmt.Println("Production mode disabled")
})
},
}
// --- Add commands ---
var addCmd = &cobra.Command{
Use: "add",
Short: "Add a resource",
}
var (
addUserName string
addUserAdmin bool
)
var addUserCmd = &cobra.Command{
Use: "user <email> <password>",
Short: "Add a local user",
Long: `Add a new local user account to the database.`,
Args: cobra.ExactArgs(2),
Run: func(cmd *cobra.Command, args []string) {
email := args[0]
password := args[1]
name := addUserName
if name == "" {
// Use email prefix as name
if atIndex := strings.Index(email, "@"); atIndex > 0 {
name = email[:atIndex]
} else {
name = email
}
}
if len(password) < 8 {
exitWithError("Password must be at least 8 characters")
}
withConfig(func(cfg *config.Config, db *database.DB) {
// Check if user exists
var exists bool
err := db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ?)", email).Scan(&exists)
})
if err != nil {
exitWithError("Failed to check user: %v", err)
}
isAdmin := addUserAdmin
if exists {
if !configForce {
exitWithError("User with email %s already exists (use -f to override)", email)
}
// Confirm override
if !configYes && !confirm(fmt.Sprintf("User %s already exists. Override?", email)) {
fmt.Println("Aborted")
return
}
// Update existing user
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
exitWithError("Failed to hash password: %v", err)
}
err = db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
"UPDATE users SET name = ?, password_hash = ?, is_admin = ? WHERE email = ?",
name, string(hashedPassword), isAdmin, email,
)
return err
})
if err != nil {
exitWithError("Failed to update user: %v", err)
}
fmt.Printf("Updated user: %s (admin: %v)\n", email, isAdmin)
return
}
// Hash password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
exitWithError("Failed to hash password: %v", err)
}
// Check if first user (make admin)
var userCount int
db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
})
if userCount == 0 {
isAdmin = true
}
// Confirm creation
if !configYes && !confirm(fmt.Sprintf("Create user %s (admin: %v)?", email, isAdmin)) {
fmt.Println("Aborted")
return
}
// Create user
err = db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?)",
email, name, email, string(hashedPassword), isAdmin,
)
return err
})
if err != nil {
exitWithError("Failed to create user: %v", err)
}
fmt.Printf("Created user: %s (admin: %v)\n", email, isAdmin)
})
},
}
var addAPIKeyScope string
var addAPIKeyCmd = &cobra.Command{
Use: "apikey [name]",
Short: "Add a runner API key",
Long: `Generate a new API key for runner authentication.`,
Args: cobra.MaximumNArgs(1),
Run: func(cmd *cobra.Command, args []string) {
name := "cli-generated"
if len(args) > 0 {
name = args[0]
}
withConfig(func(cfg *config.Config, db *database.DB) {
// Check if API key with same name exists
var exists bool
err := db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM runner_api_keys WHERE name = ?)", name).Scan(&exists)
})
if err != nil {
exitWithError("Failed to check API key: %v", err)
}
if exists {
if !configForce {
exitWithError("API key with name %s already exists (use -f to create another)", name)
}
if !configYes && !confirm(fmt.Sprintf("API key named '%s' already exists. Create another?", name)) {
fmt.Println("Aborted")
return
}
}
// Confirm creation
if !configYes && !confirm(fmt.Sprintf("Generate new API key '%s' (scope: %s)?", name, addAPIKeyScope)) {
fmt.Println("Aborted")
return
}
// Generate API key
key, keyPrefix, keyHash, err := generateAPIKey()
if err != nil {
exitWithError("Failed to generate API key: %v", err)
}
// Get first user ID for created_by (or use 0 if no users)
var createdBy int64
db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT id FROM users ORDER BY id ASC LIMIT 1").Scan(&createdBy)
})
// Store in database
err = db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`INSERT INTO runner_api_keys (key_prefix, key_hash, name, scope, is_active, created_by)
VALUES (?, ?, ?, ?, true, ?)`,
keyPrefix, keyHash, name, addAPIKeyScope, createdBy,
)
return err
})
if err != nil {
exitWithError("Failed to store API key: %v", err)
}
fmt.Printf("Generated API key: %s\n", key)
fmt.Printf("Name: %s, Scope: %s\n", name, addAPIKeyScope)
fmt.Println("\nSave this key - it cannot be retrieved later!")
})
},
}
// --- Set commands ---
var setCmd = &cobra.Command{
Use: "set",
Short: "Set a configuration value",
}
var setFixedAPIKeyCmd = &cobra.Command{
Use: "fixed-apikey [key]",
Short: "Set a fixed API key for testing",
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
withConfig(func(cfg *config.Config, db *database.DB) {
// Check if already set
existing := cfg.FixedAPIKey()
if existing != "" && !configForce {
exitWithError("Fixed API key already set (use -f to override)")
}
if existing != "" && !configYes && !confirm("Fixed API key already set. Override?") {
fmt.Println("Aborted")
return
}
if err := cfg.Set(config.KeyFixedAPIKey, args[0]); err != nil {
exitWithError("Failed to set fixed API key: %v", err)
}
fmt.Println("Fixed API key set")
})
},
}
var setAllowedOriginsCmd = &cobra.Command{
Use: "allowed-origins [origins]",
Short: "Set allowed CORS origins (comma-separated)",
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
withConfig(func(cfg *config.Config, db *database.DB) {
if err := cfg.Set(config.KeyAllowedOrigins, args[0]); err != nil {
exitWithError("Failed to set allowed origins: %v", err)
}
fmt.Printf("Allowed origins set to: %s\n", args[0])
})
},
}
var setGoogleOAuthRedirectURL string
var setGoogleOAuthCmd = &cobra.Command{
Use: "google-oauth <client-id> <client-secret>",
Short: "Set Google OAuth credentials",
Args: cobra.ExactArgs(2),
Run: func(cmd *cobra.Command, args []string) {
clientID := args[0]
clientSecret := args[1]
withConfig(func(cfg *config.Config, db *database.DB) {
// Check if already configured
existing := cfg.GoogleClientID()
if existing != "" && !configForce {
exitWithError("Google OAuth already configured (use -f to override)")
}
if existing != "" && !configYes && !confirm("Google OAuth already configured. Override?") {
fmt.Println("Aborted")
return
}
if err := cfg.Set(config.KeyGoogleClientID, clientID); err != nil {
exitWithError("Failed to set Google client ID: %v", err)
}
if err := cfg.Set(config.KeyGoogleClientSecret, clientSecret); err != nil {
exitWithError("Failed to set Google client secret: %v", err)
}
if setGoogleOAuthRedirectURL != "" {
if err := cfg.Set(config.KeyGoogleRedirectURL, setGoogleOAuthRedirectURL); err != nil {
exitWithError("Failed to set Google redirect URL: %v", err)
}
}
fmt.Println("Google OAuth configured")
})
},
}
var setDiscordOAuthRedirectURL string
var setDiscordOAuthCmd = &cobra.Command{
Use: "discord-oauth <client-id> <client-secret>",
Short: "Set Discord OAuth credentials",
Args: cobra.ExactArgs(2),
Run: func(cmd *cobra.Command, args []string) {
clientID := args[0]
clientSecret := args[1]
withConfig(func(cfg *config.Config, db *database.DB) {
// Check if already configured
existing := cfg.DiscordClientID()
if existing != "" && !configForce {
exitWithError("Discord OAuth already configured (use -f to override)")
}
if existing != "" && !configYes && !confirm("Discord OAuth already configured. Override?") {
fmt.Println("Aborted")
return
}
if err := cfg.Set(config.KeyDiscordClientID, clientID); err != nil {
exitWithError("Failed to set Discord client ID: %v", err)
}
if err := cfg.Set(config.KeyDiscordClientSecret, clientSecret); err != nil {
exitWithError("Failed to set Discord client secret: %v", err)
}
if setDiscordOAuthRedirectURL != "" {
if err := cfg.Set(config.KeyDiscordRedirectURL, setDiscordOAuthRedirectURL); err != nil {
exitWithError("Failed to set Discord redirect URL: %v", err)
}
}
fmt.Println("Discord OAuth configured")
})
},
}
// --- Show command ---
var showCmd = &cobra.Command{
Use: "show",
Short: "Show current configuration",
Run: func(cmd *cobra.Command, args []string) {
withConfig(func(cfg *config.Config, db *database.DB) {
all, err := cfg.GetAll()
if err != nil {
exitWithError("Failed to get config: %v", err)
}
if len(all) == 0 {
fmt.Println("No configuration stored")
return
}
fmt.Println("Current configuration:")
fmt.Println("----------------------")
for key, value := range all {
// Redact sensitive values
if strings.Contains(key, "secret") || strings.Contains(key, "api_key") || strings.Contains(key, "password") {
fmt.Printf(" %s: [REDACTED]\n", key)
} else {
fmt.Printf(" %s: %s\n", key, value)
}
}
})
},
}
// --- List commands ---
var listCmd = &cobra.Command{
Use: "list",
Short: "List resources",
}
var listUsersCmd = &cobra.Command{
Use: "users",
Short: "List all users",
Run: func(cmd *cobra.Command, args []string) {
withConfig(func(cfg *config.Config, db *database.DB) {
var rows *sql.Rows
err := db.With(func(conn *sql.DB) error {
var err error
rows, err = conn.Query("SELECT id, email, name, oauth_provider, is_admin, created_at FROM users ORDER BY id")
return err
})
if err != nil {
exitWithError("Failed to list users: %v", err)
}
defer rows.Close()
fmt.Printf("%-6s %-30s %-20s %-10s %-6s %s\n", "ID", "Email", "Name", "Provider", "Admin", "Created")
fmt.Println(strings.Repeat("-", 100))
for rows.Next() {
var id int64
var email, name, provider string
var isAdmin bool
var createdAt string
if err := rows.Scan(&id, &email, &name, &provider, &isAdmin, &createdAt); err != nil {
continue
}
adminStr := "no"
if isAdmin {
adminStr = "yes"
}
fmt.Printf("%-6d %-30s %-20s %-10s %-6s %s\n", id, email, name, provider, adminStr, createdAt[:19])
}
})
},
}
var listAPIKeysCmd = &cobra.Command{
Use: "apikeys",
Short: "List all API keys",
Run: func(cmd *cobra.Command, args []string) {
withConfig(func(cfg *config.Config, db *database.DB) {
var rows *sql.Rows
err := db.With(func(conn *sql.DB) error {
var err error
rows, err = conn.Query("SELECT id, key_prefix, name, scope, is_active, created_at FROM runner_api_keys ORDER BY id")
return err
})
if err != nil {
exitWithError("Failed to list API keys: %v", err)
}
defer rows.Close()
fmt.Printf("%-6s %-12s %-20s %-10s %-8s %s\n", "ID", "Prefix", "Name", "Scope", "Active", "Created")
fmt.Println(strings.Repeat("-", 80))
for rows.Next() {
var id int64
var prefix, name, scope string
var isActive bool
var createdAt string
if err := rows.Scan(&id, &prefix, &name, &scope, &isActive, &createdAt); err != nil {
continue
}
activeStr := "no"
if isActive {
activeStr = "yes"
}
fmt.Printf("%-6d %-12s %-20s %-10s %-8s %s\n", id, prefix, name, scope, activeStr, createdAt[:19])
}
})
},
}
func init() {
managerCmd.AddCommand(configCmd)
// Global config flags
configCmd.PersistentFlags().StringVar(&configDBPath, "db", "jiggablend.db", "Database path")
configCmd.PersistentFlags().BoolVarP(&configYes, "yes", "y", false, "Auto-confirm prompts")
configCmd.PersistentFlags().BoolVarP(&configForce, "force", "f", false, "Force override existing")
// Enable/Disable
configCmd.AddCommand(enableCmd)
configCmd.AddCommand(disableCmd)
enableCmd.AddCommand(enableLocalAuthCmd)
enableCmd.AddCommand(enableRegistrationCmd)
enableCmd.AddCommand(enableProductionCmd)
disableCmd.AddCommand(disableLocalAuthCmd)
disableCmd.AddCommand(disableRegistrationCmd)
disableCmd.AddCommand(disableProductionCmd)
// Add
configCmd.AddCommand(addCmd)
addCmd.AddCommand(addUserCmd)
addUserCmd.Flags().StringVarP(&addUserName, "name", "n", "", "User display name")
addUserCmd.Flags().BoolVarP(&addUserAdmin, "admin", "a", false, "Make user an admin")
addCmd.AddCommand(addAPIKeyCmd)
addAPIKeyCmd.Flags().StringVarP(&addAPIKeyScope, "scope", "s", "manager", "API key scope (manager or user)")
// Set
configCmd.AddCommand(setCmd)
setCmd.AddCommand(setFixedAPIKeyCmd)
setCmd.AddCommand(setAllowedOriginsCmd)
setCmd.AddCommand(setGoogleOAuthCmd)
setCmd.AddCommand(setDiscordOAuthCmd)
setGoogleOAuthCmd.Flags().StringVarP(&setGoogleOAuthRedirectURL, "redirect-url", "r", "", "Google OAuth redirect URL")
setDiscordOAuthCmd.Flags().StringVarP(&setDiscordOAuthRedirectURL, "redirect-url", "r", "", "Discord OAuth redirect URL")
// Show
configCmd.AddCommand(showCmd)
// List
configCmd.AddCommand(listCmd)
listCmd.AddCommand(listUsersCmd)
listCmd.AddCommand(listAPIKeysCmd)
}
// withConfig opens the database and runs the callback with config access
func withConfig(fn func(cfg *config.Config, db *database.DB)) {
db, err := database.NewDB(configDBPath)
if err != nil {
exitWithError("Failed to open database: %v", err)
}
defer db.Close()
cfg := config.NewConfig(db)
fn(cfg, db)
}
// generateAPIKey generates a new API key
func generateAPIKey() (key, prefix, hash string, err error) {
randomBytes := make([]byte, 16)
if _, err := rand.Read(randomBytes); err != nil {
return "", "", "", err
}
randomStr := hex.EncodeToString(randomBytes)
prefixDigit := make([]byte, 1)
if _, err := rand.Read(prefixDigit); err != nil {
return "", "", "", err
}
prefix = fmt.Sprintf("jk_r%d", prefixDigit[0]%10)
key = fmt.Sprintf("%s_%s", prefix, randomStr)
keyHash := sha256.Sum256([]byte(key))
hash = hex.EncodeToString(keyHash[:])
return key, prefix, hash, nil
}
// confirm prompts the user for confirmation
func confirm(prompt string) bool {
fmt.Printf("%s [y/N]: ", prompt)
reader := bufio.NewReader(os.Stdin)
response, err := reader.ReadString('\n')
if err != nil {
return false
}
response = strings.TrimSpace(strings.ToLower(response))
return response == "y" || response == "yes"
}
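
generateAPIKey stores only the key prefix and a SHA-256 hash of the full key, which is why the CLI warns that the plaintext cannot be retrieved later; local-user passwords, by contrast, are written as bcrypt hashes and would be checked with bcrypt.CompareHashAndPassword. A hypothetical verification helper (not part of this commit) that mirrors the hashing done here might look like:

package main

import (
	"crypto/sha256"
	"crypto/subtle"
	"encoding/hex"
	"fmt"
)

// verifyAPIKey is a hypothetical helper, not part of the commit: it hashes
// the presented key with SHA-256, exactly as generateAPIKey does, and
// compares against the stored hex digest in constant time.
func verifyAPIKey(presented, storedHash string) bool {
	sum := sha256.Sum256([]byte(presented))
	return subtle.ConstantTimeCompare([]byte(hex.EncodeToString(sum[:])), []byte(storedHash)) == 1
}

func main() {
	key := "jk_r0_0123456789abcdef0123456789abcdef" // same jk_r{N}_{32 hex chars} shape as generateAPIKey
	sum := sha256.Sum256([]byte(key))
	stored := hex.EncodeToString(sum[:])
	fmt.Println(verifyAPIKey(key, stored)) // true
}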


@@ -0,0 +1,35 @@
package cmd
import (
"fmt"
"os"
"github.com/spf13/cobra"
)
var rootCmd = &cobra.Command{
Use: "jiggablend",
Short: "Jiggablend - Distributed Blender Render Farm",
Long: `Jiggablend is a distributed render farm for Blender.
Run 'jiggablend manager' to start the manager server.
Run 'jiggablend runner' to start a render runner.
Run 'jiggablend manager config' to configure the manager.`,
}
// Execute runs the root command
func Execute() error {
return rootCmd.Execute()
}
func init() {
// Global flags can be added here if needed
rootCmd.CompletionOptions.DisableDefaultCmd = true
}
// exitWithError prints an error and exits
func exitWithError(msg string, args ...interface{}) {
fmt.Fprintf(os.Stderr, "Error: "+msg+"\n", args...)
os.Exit(1)
}


@@ -0,0 +1,211 @@
package cmd
import (
"crypto/rand"
"encoding/hex"
"fmt"
"os"
"os/signal"
"strings"
"syscall"
"time"
"jiggablend/internal/logger"
"jiggablend/internal/runner"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
var runnerViper = viper.New()
var runnerCmd = &cobra.Command{
Use: "runner",
Short: "Start the Jiggablend render runner",
Long: `Start the Jiggablend render runner that connects to a manager and processes render tasks.`,
Run: runRunner,
}
func init() {
rootCmd.AddCommand(runnerCmd)
runnerCmd.Flags().StringP("manager", "m", "http://localhost:8080", "Manager URL")
runnerCmd.Flags().StringP("name", "n", "", "Runner name")
runnerCmd.Flags().String("hostname", "", "Runner hostname")
runnerCmd.Flags().StringP("api-key", "k", "", "API key for authentication")
runnerCmd.Flags().StringP("log-file", "l", "", "Log file path (truncated on start, if not set logs only to stdout)")
runnerCmd.Flags().String("log-level", "info", "Log level (debug, info, warn, error)")
runnerCmd.Flags().BoolP("verbose", "v", false, "Enable verbose logging (same as --log-level=debug)")
// Bind flags to viper with JIGGABLEND_ prefix
runnerViper.SetEnvPrefix("JIGGABLEND")
runnerViper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
runnerViper.AutomaticEnv()
runnerViper.BindPFlag("manager", runnerCmd.Flags().Lookup("manager"))
runnerViper.BindPFlag("name", runnerCmd.Flags().Lookup("name"))
runnerViper.BindPFlag("hostname", runnerCmd.Flags().Lookup("hostname"))
runnerViper.BindPFlag("api_key", runnerCmd.Flags().Lookup("api-key"))
runnerViper.BindPFlag("log_file", runnerCmd.Flags().Lookup("log-file"))
runnerViper.BindPFlag("log_level", runnerCmd.Flags().Lookup("log-level"))
runnerViper.BindPFlag("verbose", runnerCmd.Flags().Lookup("verbose"))
}
func runRunner(cmd *cobra.Command, args []string) {
// Get config values (flags take precedence over env vars)
managerURL := runnerViper.GetString("manager")
name := runnerViper.GetString("name")
hostname := runnerViper.GetString("hostname")
apiKey := runnerViper.GetString("api_key")
logFile := runnerViper.GetString("log_file")
logLevel := runnerViper.GetString("log_level")
verbose := runnerViper.GetBool("verbose")
var client *runner.Client
defer func() {
if r := recover(); r != nil {
logger.Errorf("Runner panicked: %v", r)
if client != nil {
client.CleanupWorkspace()
}
os.Exit(1)
}
}()
if hostname == "" {
hostname, _ = os.Hostname()
}
// Generate unique runner ID
runnerIDStr := generateShortID()
// Generate runner name with ID if not provided
if name == "" {
name = fmt.Sprintf("runner-%s-%s", hostname, runnerIDStr)
} else {
name = fmt.Sprintf("%s-%s", name, runnerIDStr)
}
// Initialize logger
if logFile != "" {
if err := logger.InitWithFile(logFile); err != nil {
logger.Fatalf("Failed to initialize logger: %v", err)
}
defer func() {
if l := logger.GetDefault(); l != nil {
l.Close()
}
}()
} else {
logger.InitStdout()
}
// Set log level
if verbose {
logger.SetLevel(logger.LevelDebug)
} else {
logger.SetLevel(logger.ParseLevel(logLevel))
}
logger.Info("Runner starting up...")
logger.Debugf("Generated runner ID suffix: %s", runnerIDStr)
if logFile != "" {
logger.Infof("Logging to file: %s", logFile)
}
client = runner.NewClient(managerURL, name, hostname)
// Clean up orphaned workspace directories
client.CleanupWorkspace()
// Probe capabilities
logger.Debug("Probing runner capabilities...")
client.ProbeCapabilities()
capabilities := client.GetCapabilities()
capList := []string{}
for cap, value := range capabilities {
if enabled, ok := value.(bool); ok && enabled {
capList = append(capList, cap)
} else if count, ok := value.(int); ok && count > 0 {
capList = append(capList, fmt.Sprintf("%s=%d", cap, count))
} else if count, ok := value.(float64); ok && count > 0 {
capList = append(capList, fmt.Sprintf("%s=%.0f", cap, count))
}
}
if len(capList) > 0 {
logger.Infof("Detected capabilities: %s", strings.Join(capList, ", "))
} else {
logger.Warn("No capabilities detected")
}
// Register with API key
if apiKey == "" {
logger.Fatal("API key required (use --api-key or set JIGGABLEND_API_KEY env var)")
}
// Retry registration with exponential backoff
backoff := 1 * time.Second
maxBackoff := 30 * time.Second
maxRetries := 10
retryCount := 0
var runnerID int64
for {
var err error
runnerID, _, _, err = client.Register(apiKey)
if err == nil {
logger.Infof("Registered runner with ID: %d", runnerID)
break
}
errMsg := err.Error()
if strings.Contains(errMsg, "token error:") {
logger.Fatalf("Registration failed (token error): %v", err)
}
retryCount++
if retryCount >= maxRetries {
logger.Fatalf("Failed to register runner after %d attempts: %v", maxRetries, err)
}
logger.Warnf("Registration failed (attempt %d/%d): %v, retrying in %v", retryCount, maxRetries, err, backoff)
time.Sleep(backoff)
backoff *= 2
if backoff > maxBackoff {
backoff = maxBackoff
}
}
// Start WebSocket connection
go client.ConnectWebSocketWithReconnect()
// Start heartbeat loop
go client.HeartbeatLoop()
logger.Info("Runner started, connecting to manager via WebSocket...")
// Signal handlers
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
go func() {
sig := <-sigChan
logger.Infof("Received signal: %v, killing all processes and cleaning up...", sig)
client.KillAllProcesses()
client.CleanupWorkspace()
os.Exit(0)
}()
// Block forever
select {}
}
func generateShortID() string {
bytes := make([]byte, 4)
if _, err := rand.Read(bytes); err != nil {
return fmt.Sprintf("%x", os.Getpid()^int(time.Now().Unix()))
}
return hex.EncodeToString(bytes)
}
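
The registration loop above retries with exponential backoff: the delay starts at 1 second, doubles after each failure, is capped at 30 seconds, and the runner gives up after 10 attempts unless the error is a non-retryable token error. A standalone sketch of that pattern, with hypothetical names and a fake operation standing in for client.Register:

package main

import (
	"errors"
	"fmt"
	"time"
)

// retryWithBackoff is a simplified, hypothetical distillation of the
// registration loop: retry fn up to maxRetries times, doubling the delay
// after each failure and capping it at maxBackoff.
func retryWithBackoff(fn func() error, maxRetries int, maxBackoff time.Duration) error {
	backoff := 1 * time.Second
	for attempt := 1; ; attempt++ {
		err := fn()
		if err == nil {
			return nil
		}
		if attempt >= maxRetries {
			return fmt.Errorf("giving up after %d attempts: %w", attempt, err)
		}
		fmt.Printf("attempt %d failed: %v, retrying in %v\n", attempt, err, backoff)
		time.Sleep(backoff)
		backoff *= 2
		if backoff > maxBackoff {
			backoff = maxBackoff
		}
	}
}

func main() {
	calls := 0
	err := retryWithBackoff(func() error {
		calls++
		if calls < 3 {
			return errors.New("manager not reachable yet")
		}
		return nil
	}, 10, 30*time.Second)
	fmt.Println(err) // <nil> once the third attempt succeeds
}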

14
cmd/jiggablend/main.go Normal file

@@ -0,0 +1,14 @@
package main
import (
"os"
"jiggablend/cmd/jiggablend/cmd"
)
func main() {
if err := cmd.Execute(); err != nil {
os.Exit(1)
}
}


@@ -1,125 +0,0 @@
package main
import (
"flag"
"fmt"
"log"
"net/http"
"os"
"os/exec"
"jiggablend/internal/api"
"jiggablend/internal/auth"
"jiggablend/internal/database"
"jiggablend/internal/logger"
"jiggablend/internal/storage"
)
func main() {
var (
port = flag.String("port", getEnv("PORT", "8080"), "Server port")
dbPath = flag.String("db", getEnv("DB_PATH", "jiggablend.db"), "Database path")
storagePath = flag.String("storage", getEnv("STORAGE_PATH", "./jiggablend-storage"), "Storage path")
logDir = flag.String("log-dir", getEnv("LOG_DIR", "./logs"), "Log directory")
logMaxSize = flag.Int("log-max-size", getEnvInt("LOG_MAX_SIZE", 100), "Maximum log file size in MB before rotation")
logMaxBackups = flag.Int("log-max-backups", getEnvInt("LOG_MAX_BACKUPS", 5), "Maximum number of rotated log files to keep")
logMaxAge = flag.Int("log-max-age", getEnvInt("LOG_MAX_AGE", 30), "Maximum age in days for rotated log files")
)
flag.Parse()
// Initialize logger (writes to both stdout and log file with rotation)
logDirPath := *logDir
if err := logger.Init(logDirPath, "manager.log", *logMaxSize, *logMaxBackups, *logMaxAge); err != nil {
log.Fatalf("Failed to initialize logger: %v", err)
}
defer func() {
if l := logger.GetDefault(); l != nil {
l.Close()
}
}()
log.Printf("Log rotation configured: max_size=%dMB, max_backups=%d, max_age=%d days", *logMaxSize, *logMaxBackups, *logMaxAge)
// Initialize database
db, err := database.NewDB(*dbPath)
if err != nil {
log.Fatalf("Failed to initialize database: %v", err)
}
defer db.Close()
// Initialize auth
authHandler, err := auth.NewAuth(db.DB)
if err != nil {
log.Fatalf("Failed to initialize auth: %v", err)
}
// Initialize storage
storageHandler, err := storage.NewStorage(*storagePath)
if err != nil {
log.Fatalf("Failed to initialize storage: %v", err)
}
// Check if Blender is available (required for metadata extraction)
if err := checkBlenderAvailable(); err != nil {
log.Fatalf("Blender is not available: %v\n"+
"The manager requires Blender to be installed and in PATH for metadata extraction.\n"+
"Please install Blender and ensure it's accessible via the 'blender' command.", err)
}
log.Printf("Blender is available")
// Create API server
server, err := api.NewServer(db, authHandler, storageHandler)
if err != nil {
log.Fatalf("Failed to create server: %v", err)
}
// Start server with increased request body size limit for large file uploads
addr := fmt.Sprintf(":%s", *port)
log.Printf("Starting manager server on %s", addr)
log.Printf("Database: %s", *dbPath)
log.Printf("Storage: %s", *storagePath)
httpServer := &http.Server{
Addr: addr,
Handler: server,
MaxHeaderBytes: 1 << 20, // 1 MB for headers
ReadTimeout: 0, // No read timeout (for large uploads)
WriteTimeout: 0, // No write timeout (for large uploads)
}
// Note: MaxRequestBodySize is not directly configurable in http.Server
// It's handled by ParseMultipartForm in handlers, which we've already configured
// But we need to ensure the server can handle large requests
// The default limit is 10MB, but we bypass it by using ParseMultipartForm with larger limit
if err := httpServer.ListenAndServe(); err != nil {
log.Fatalf("Server failed: %v", err)
}
}
func getEnv(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
func getEnvInt(key string, defaultValue int) int {
if value := os.Getenv(key); value != "" {
var result int
if _, err := fmt.Sscanf(value, "%d", &result); err == nil {
return result
}
}
return defaultValue
}
// checkBlenderAvailable checks if Blender is available by running `blender --version`
func checkBlenderAvailable() error {
cmd := exec.Command("blender", "--version")
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("failed to run 'blender --version': %w (output: %s)", err, string(output))
}
// If we got here, Blender is available
return nil
}


@@ -1,207 +0,0 @@
package main
import (
"crypto/rand"
"encoding/hex"
"flag"
"fmt"
"log"
"os"
"os/signal"
"strings"
"syscall"
"time"
"jiggablend/internal/logger"
"jiggablend/internal/runner"
)
// Removed SecretsFile - runners now generate ephemeral instance IDs
func main() {
log.Printf("Runner starting up...")
// Create client early so we can clean it up on panic
var client *runner.Client
defer func() {
if r := recover(); r != nil {
log.Printf("Runner panicked: %v", r)
// Clean up workspace even on panic
if client != nil {
client.CleanupWorkspace()
}
os.Exit(1)
}
}()
var (
managerURL = flag.String("manager", getEnv("MANAGER_URL", "http://localhost:8080"), "Manager URL")
name = flag.String("name", getEnv("RUNNER_NAME", ""), "Runner name")
hostname = flag.String("hostname", getEnv("RUNNER_HOSTNAME", ""), "Runner hostname")
apiKeyFlag = flag.String("api-key", getEnv("API_KEY", ""), "API key for authentication")
logDir = flag.String("log-dir", getEnv("LOG_DIR", "./logs"), "Log directory")
logMaxSize = flag.Int("log-max-size", getEnvInt("LOG_MAX_SIZE", 100), "Maximum log file size in MB before rotation")
logMaxBackups = flag.Int("log-max-backups", getEnvInt("LOG_MAX_BACKUPS", 5), "Maximum number of rotated log files to keep")
logMaxAge = flag.Int("log-max-age", getEnvInt("LOG_MAX_AGE", 30), "Maximum age in days for rotated log files")
)
flag.Parse()
log.Printf("Flags parsed, hostname: %s", *hostname)
if *hostname == "" {
*hostname, _ = os.Hostname()
}
// Always generate a random runner ID suffix on startup
// This ensures every runner has a unique local identifier
runnerIDStr := generateShortID()
log.Printf("Generated runner ID suffix: %s", runnerIDStr)
// Generate runner name with ID if not provided
if *name == "" {
*name = fmt.Sprintf("runner-%s-%s", *hostname, runnerIDStr)
} else {
// Append ID to provided name to ensure uniqueness
*name = fmt.Sprintf("%s-%s", *name, runnerIDStr)
}
// Initialize logger (writes to both stdout and log file with rotation)
// Use runner-specific log file name based on the final name
sanitizedName := strings.ReplaceAll(*name, "/", "_")
sanitizedName = strings.ReplaceAll(sanitizedName, "\\", "_")
logFileName := fmt.Sprintf("runner-%s.log", sanitizedName)
if err := logger.Init(*logDir, logFileName, *logMaxSize, *logMaxBackups, *logMaxAge); err != nil {
log.Fatalf("Failed to initialize logger: %v", err)
}
defer func() {
if l := logger.GetDefault(); l != nil {
l.Close()
}
}()
log.Printf("Logger initialized, continuing with startup...")
log.Printf("Log rotation configured: max_size=%dMB, max_backups=%d, max_age=%d days", *logMaxSize, *logMaxBackups, *logMaxAge)
log.Printf("About to create client...")
client = runner.NewClient(*managerURL, *name, *hostname)
log.Printf("Client created successfully")
// Clean up any orphaned workspace directories from previous runs
client.CleanupWorkspace()
// Probe capabilities once at startup (before any registration attempts)
log.Printf("Probing runner capabilities...")
client.ProbeCapabilities()
capabilities := client.GetCapabilities()
capList := []string{}
for cap, value := range capabilities {
// Only show boolean true capabilities and numeric GPU counts
if enabled, ok := value.(bool); ok && enabled {
capList = append(capList, cap)
} else if count, ok := value.(int); ok && count > 0 {
capList = append(capList, fmt.Sprintf("%s=%d", cap, count))
} else if count, ok := value.(float64); ok && count > 0 {
capList = append(capList, fmt.Sprintf("%s=%.0f", cap, count))
}
}
if len(capList) > 0 {
log.Printf("Detected capabilities: %s", strings.Join(capList, ", "))
} else {
log.Printf("Warning: No capabilities detected")
}
// Register with API key (with retry logic)
if *apiKeyFlag == "" {
log.Fatalf("API key required (use --api-key or set API_KEY env var)")
}
// Retry registration with exponential backoff
backoff := 1 * time.Second
maxBackoff := 30 * time.Second
maxRetries := 10
retryCount := 0
var runnerID int64
for {
var err error
runnerID, _, _, err = client.Register(*apiKeyFlag)
if err == nil {
log.Printf("Registered runner with ID: %d", runnerID)
break
}
// Check if it's a token error (invalid/expired/used token) - shutdown immediately
errMsg := err.Error()
if strings.Contains(errMsg, "token error:") {
log.Fatalf("Registration failed (token error): %v", err)
}
// Only retry on connection errors or other retryable errors
retryCount++
if retryCount >= maxRetries {
log.Fatalf("Failed to register runner after %d attempts: %v", maxRetries, err)
}
log.Printf("Registration failed (attempt %d/%d): %v, retrying in %v", retryCount, maxRetries, err, backoff)
time.Sleep(backoff)
backoff *= 2
if backoff > maxBackoff {
backoff = maxBackoff
}
}
// Start WebSocket connection with reconnection
go client.ConnectWebSocketWithReconnect()
// Start heartbeat loop (for WebSocket ping/pong and HTTP fallback)
go client.HeartbeatLoop()
// ProcessTasks is now handled via WebSocket, but kept for HTTP fallback
// WebSocket will handle task assignment automatically
log.Printf("Runner started, connecting to manager via WebSocket...")
// Set up signal handlers to kill processes on shutdown
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
go func() {
sig := <-sigChan
log.Printf("Received signal: %v, killing all processes and cleaning up...", sig)
client.KillAllProcesses()
// Cleanup happens in defer, but also do it here for good measure
client.CleanupWorkspace()
os.Exit(0)
}()
// Block forever
select {}
}
func getEnv(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
func getEnvInt(key string, defaultValue int) int {
if value := os.Getenv(key); value != "" {
var result int
if _, err := fmt.Sscanf(value, "%d", &result); err == nil {
return result
}
}
return defaultValue
}
// generateShortID generates a short random ID (8 hex characters)
func generateShortID() string {
bytes := make([]byte, 4)
if _, err := rand.Read(bytes); err != nil {
// Fallback to timestamp-based ID if crypto/rand fails
return fmt.Sprintf("%x", os.Getpid()^int(time.Now().Unix()))
}
return hex.EncodeToString(bytes)
}

38
go.mod

@@ -7,34 +7,28 @@ require (
github.com/go-chi/cors v1.2.2
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3
github.com/marcboeker/go-duckdb/v2 v2.4.3
github.com/mattn/go-sqlite3 v1.14.22
github.com/spf13/cobra v1.10.1
github.com/spf13/viper v1.21.0
golang.org/x/crypto v0.45.0
golang.org/x/oauth2 v0.33.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
)
require (
cloud.google.com/go/compute/metadata v0.3.0 // indirect
github.com/apache/arrow-go/v18 v18.4.1 // indirect
github.com/duckdb/duckdb-go-bindings v0.1.21 // indirect
github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.21 // indirect
github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.21 // indirect
github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.21 // indirect
github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.21 // indirect
github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.21 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/goccy/go-json v0.10.5 // indirect
github.com/google/flatbuffers v25.2.10+incompatible // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/marcboeker/go-duckdb/arrowmapping v0.0.21 // indirect
github.com/marcboeker/go-duckdb/mapping v0.0.21 // indirect
github.com/pierrec/lz4/v4 v4.1.22 // indirect
github.com/zeebo/xxh3 v1.0.2 // indirect
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 // indirect
golang.org/x/mod v0.27.0 // indirect
golang.org/x/sync v0.16.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/sagikazarmark/locafero v0.11.0 // indirect
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
github.com/spf13/afero v1.15.0 // indirect
github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/tools v0.36.0 // indirect
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect
golang.org/x/text v0.31.0 // indirect
)
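
With this change the module depends on mattn/go-sqlite3 (plus cobra and viper) instead of the DuckDB bindings, so the manager's database layer and the config CLI operate on a plain SQLite file, jiggablend.db by default. A small, hypothetical standalone check against that file, assuming the users table created by the manager already exists:

package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/mattn/go-sqlite3" // registers the "sqlite3" driver
)

func main() {
	// Hypothetical example, not part of the commit: open the default
	// database path used by the CLI and count the configured users.
	db, err := sql.Open("sqlite3", "jiggablend.db")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	var users int
	if err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&users); err != nil {
		log.Fatal(err)
	}
	fmt.Println("users:", users)
}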

106
go.sum

@@ -1,88 +1,70 @@
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/apache/arrow-go/v18 v18.4.1 h1:q/jVkBWCJOB9reDgaIZIdruLQUb1kbkvOnOFezVH1C4=
github.com/apache/arrow-go/v18 v18.4.1/go.mod h1:tLyFubsAl17bvFdUAy24bsSvA/6ww95Iqi67fTpGu3E=
github.com/apache/thrift v0.22.0 h1:r7mTJdj51TMDe6RtcmNdQxgn9XcyfGDOzegMDRg47uc=
github.com/apache/thrift v0.22.0/go.mod h1:1e7J/O1Ae6ZQMTYdy9xa3w9k+XHWPfRvdPyJeynQ+/g=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/duckdb/duckdb-go-bindings v0.1.21 h1:bOb/MXNT4PN5JBZ7wpNg6hrj9+cuDjWDa4ee9UdbVyI=
github.com/duckdb/duckdb-go-bindings v0.1.21/go.mod h1:pBnfviMzANT/9hi4bg+zW4ykRZZPCXlVuvBWEcZofkc=
github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.21 h1:Sjjhf2F/zCjPF53c2VXOSKk0PzieMriSoyr5wfvr9d8=
github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.21/go.mod h1:Ezo7IbAfB8NP7CqPIN8XEHKUg5xdRRQhcPPlCXImXYA=
github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.21 h1:IUk0FFUB6dpWLhlN9hY1mmdPX7Hkn3QpyrAmn8pmS8g=
github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.21/go.mod h1:eS7m/mLnPQgVF4za1+xTyorKRBuK0/BA44Oy6DgrGXI=
github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.21 h1:Qpc7ZE3n6Nwz30KTvaAwI6nGkXjXmMxBTdFpC8zDEYI=
github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.21/go.mod h1:1GOuk1PixiESxLaCGFhag+oFi7aP+9W8byymRAvunBk=
github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.21 h1:eX2DhobAZOgjXkh8lPnKAyrxj8gXd2nm+K71f6KV/mo=
github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.21/go.mod h1:o7crKMpT2eOIi5/FY6HPqaXcvieeLSqdXXaXbruGX7w=
github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.21 h1:hhziFnGV7mpA+v5J5G2JnYQ+UWCCP3NQ+OTvxFX10D8=
github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.21/go.mod h1:IlOhJdVKUJCAPj3QsDszUo8DVdvp1nBFp4TUJVdw99s=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE=
github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
github.com/go-chi/cors v1.2.2 h1:Jmey33TE+b+rB7fT8MUy1u0I4L+NARQlK6LhzKPSyQE=
github.com/go-chi/cors v1.2.2/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58=
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs=
github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/flatbuffers v25.2.10+incompatible h1:F3vclr7C3HpB1k9mxCGRMXq6FdUalZ6H/pNX4FP1v0Q=
github.com/google/flatbuffers v25.2.10+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/klauspost/asmfmt v1.3.2 h1:4Ri7ox3EwapiOjCki+hw14RyKk201CN4rzyCJRFLpK4=
github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j0HLHbNSE=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/marcboeker/go-duckdb/arrowmapping v0.0.21 h1:geHnVjlsAJGczSWEqYigy/7ARuD+eBtjd0kLN80SPJQ=
github.com/marcboeker/go-duckdb/arrowmapping v0.0.21/go.mod h1:flFTc9MSqQCh2Xm62RYvG3Kyj29h7OtsTb6zUx1CdK8=
github.com/marcboeker/go-duckdb/mapping v0.0.21 h1:6woNXZn8EfYdc9Vbv0qR6acnt0TM1s1eFqnrJZVrqEs=
github.com/marcboeker/go-duckdb/mapping v0.0.21/go.mod h1:q3smhpLyv2yfgkQd7gGHMd+H/Z905y+WYIUjrl29vT4=
github.com/marcboeker/go-duckdb/v2 v2.4.3 h1:bHUkphPsAp2Bh/VFEdiprGpUekxBNZiWWtK+Bv/ljRk=
github.com/marcboeker/go-duckdb/v2 v2.4.3/go.mod h1:taim9Hktg2igHdNBmg5vgTfHAlV26z3gBI0QXQOcuyI=
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs=
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcsrzbxHt8iiaC+zU4b1ylILSosueou12R++wfY=
github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8DFdX7uMikMLXX4oubIzJF4kv/wI=
github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3/go.mod h1:RagcQ7I8IeTMnF8JTXieKnO4Z6JCsikNEzj0DwauVzE=
github.com/pierrec/lz4/v4 v4.1.22 h1:cKFw6uJDK+/gfw5BcDL0JL5aBsAFdsIT18eRtLj7VIU=
github.com/pierrec/lz4/v4 v4.1.22/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.11.0 h1:ib4sjIrwZKxE5u/Japgo/7SJV3PvgjGiRNAvTVGqQl8=
github.com/stretchr/testify v1.11.0/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ=
github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0=
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s=
github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0=
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM=
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8=
golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ=
golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc=
golang.org/x/oauth2 v0.33.0 h1:4Q+qn+E5z8gPRJfmRy7C2gGG3T4jIprK6aSYgTXGRpo=
golang.org/x/oauth2 v0.33.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg=
golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s=
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da h1:noIWHXmPHxILtqtCOPIhSt0ABwskkZKjD3bXGnZGpNY=
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -131,14 +131,19 @@ func (s *Server) handleVerifyRunner(w http.ResponseWriter, r *http.Request) {
// Check if runner exists
var exists bool
err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM runners WHERE id = ?)", runnerID).Scan(&exists)
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM runners WHERE id = ?)", runnerID).Scan(&exists)
})
if err != nil || !exists {
s.respondError(w, http.StatusNotFound, "Runner not found")
return
}
// Mark runner as verified
_, err = s.db.Exec("UPDATE runners SET verified = 1 WHERE id = ?", runnerID)
err = s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec("UPDATE runners SET verified = 1 WHERE id = ?", runnerID)
return err
})
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify runner: %v", err))
return
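
Every query in this commit is now routed through s.db.With(func(conn *sql.DB) error { ... }). The internal/database package itself is not part of this excerpt, so the following is only a sketch of the kind of wrapper the call sites imply — the type name, field names, and mutex-based serialization are assumptions, not taken from the commit:

package database

import (
	"database/sql"
	"sync"
)

// DB wraps a *sql.DB and serializes access through With.
// Sketch only: the real internal/database type is not shown in this diff
// and likely exposes more (Ping, migrations, etc.).
type DB struct {
	conn *sql.DB
	mu   sync.Mutex
}

// With runs fn while holding the wrapper's lock so that concurrent callers
// cannot interleave statements on the underlying SQLite handle.
func (d *DB) With(fn func(conn *sql.DB) error) error {
	d.mu.Lock()
	defer d.mu.Unlock()
	return fn(d.conn)
}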
@@ -157,14 +162,19 @@ func (s *Server) handleDeleteRunner(w http.ResponseWriter, r *http.Request) {
// Check if runner exists
var exists bool
err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM runners WHERE id = ?)", runnerID).Scan(&exists)
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM runners WHERE id = ?)", runnerID).Scan(&exists)
})
if err != nil || !exists {
s.respondError(w, http.StatusNotFound, "Runner not found")
return
}
// Delete runner
_, err = s.db.Exec("DELETE FROM runners WHERE id = ?", runnerID)
err = s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec("DELETE FROM runners WHERE id = ?", runnerID)
return err
})
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to delete runner: %v", err))
return
@@ -175,17 +185,31 @@ func (s *Server) handleDeleteRunner(w http.ResponseWriter, r *http.Request) {
// handleListRunnersAdmin lists all runners with admin details
func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request) {
rows, err := s.db.Query(
var rows *sql.Rows
err := s.db.With(func(conn *sql.DB) error {
var err error
rows, err = conn.Query(
`SELECT id, name, hostname, status, last_heartbeat, capabilities,
api_key_id, api_key_scope, priority, created_at
FROM runners ORDER BY created_at DESC`,
)
return err
})
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query runners: %v", err))
return
}
defer rows.Close()
// Get the set of currently connected runners via WebSocket
// This is the source of truth for online status
s.runnerConnsMu.RLock()
connectedRunners := make(map[int64]bool)
for runnerID := range s.runnerConns {
connectedRunners[runnerID] = true
}
s.runnerConnsMu.RUnlock()
runners := []map[string]interface{}{}
for rows.Next() {
var runner types.Runner
@@ -202,11 +226,21 @@ func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request)
return
}
// Override status based on actual WebSocket connection state
// The WebSocket connection is the source of truth for runner status
actualStatus := runner.Status
if connectedRunners[runner.ID] {
actualStatus = types.RunnerStatusOnline
} else if runner.Status == types.RunnerStatusOnline {
// Database says online but not connected via WebSocket - mark as offline
actualStatus = types.RunnerStatusOffline
}
runners = append(runners, map[string]interface{}{
"id": runner.ID,
"name": runner.Name,
"hostname": runner.Hostname,
"status": runner.Status,
"status": actualStatus,
"last_heartbeat": runner.LastHeartbeat,
"capabilities": runner.Capabilities,
"api_key_id": apiKeyID.Int64,
@@ -228,10 +262,15 @@ func (s *Server) handleListUsers(w http.ResponseWriter, r *http.Request) {
firstUserID = 0
}
rows, err := s.db.Query(
var rows *sql.Rows
err = s.db.With(func(conn *sql.DB) error {
var err error
rows, err = conn.Query(
`SELECT id, email, name, oauth_provider, is_admin, created_at
FROM users ORDER BY created_at DESC`,
)
return err
})
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query users: %v", err))
return
@@ -253,7 +292,9 @@ func (s *Server) handleListUsers(w http.ResponseWriter, r *http.Request) {
// Get job count for this user
var jobCount int
err = s.db.QueryRow("SELECT COUNT(*) FROM jobs WHERE user_id = ?", userID).Scan(&jobCount)
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT COUNT(*) FROM jobs WHERE user_id = ?", userID).Scan(&jobCount)
})
if err != nil {
jobCount = 0 // Default to 0 if query fails
}
@@ -283,18 +324,25 @@ func (s *Server) handleGetUserJobs(w http.ResponseWriter, r *http.Request) {
// Verify user exists
var exists bool
err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", userID).Scan(&exists)
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", userID).Scan(&exists)
})
if err != nil || !exists {
s.respondError(w, http.StatusNotFound, "User not found")
return
}
rows, err := s.db.Query(
var rows *sql.Rows
err = s.db.With(func(conn *sql.DB) error {
var err error
rows, err = conn.Query(
`SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format,
allow_parallel_runners, timeout_seconds, blend_metadata, created_at, started_at, completed_at, error_message
FROM jobs WHERE user_id = ? ORDER BY created_at DESC`,
userID,
)
return err
})
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query jobs: %v", err))
return

File diff suppressed because it is too large

View File

@@ -2,8 +2,6 @@ package api
import (
"archive/tar"
"bufio"
"bytes"
"database/sql"
_ "embed"
"encoding/json"
@@ -13,10 +11,10 @@ import (
"log"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"
"jiggablend/pkg/executils"
"jiggablend/pkg/scripts"
"jiggablend/pkg/types"
)
@@ -44,7 +42,9 @@ func (s *Server) handleSubmitMetadata(w http.ResponseWriter, r *http.Request) {
// Verify job exists
var jobUserID int64
err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID)
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID)
})
if err == sql.ErrNoRows {
s.respondError(w, http.StatusNotFound, "Job not found")
return
@@ -57,30 +57,36 @@ func (s *Server) handleSubmitMetadata(w http.ResponseWriter, r *http.Request) {
// Find the metadata extraction task for this job
// First try to find task assigned to this runner, then fall back to any metadata task for this job
var taskID int64
err = s.db.QueryRow(
err = s.db.With(func(conn *sql.DB) error {
err := conn.QueryRow(
`SELECT id FROM tasks WHERE job_id = ? AND task_type = ? AND runner_id = ?`,
jobID, types.TaskTypeMetadata, runnerID,
).Scan(&taskID)
if err == sql.ErrNoRows {
// Fall back to any metadata task for this job (in case assignment changed)
err = s.db.QueryRow(
err = conn.QueryRow(
`SELECT id FROM tasks WHERE job_id = ? AND task_type = ? ORDER BY created_at DESC LIMIT 1`,
jobID, types.TaskTypeMetadata,
).Scan(&taskID)
if err == sql.ErrNoRows {
return fmt.Errorf("metadata extraction task not found")
}
if err != nil {
return err
}
// Update the task to be assigned to this runner if it wasn't already
conn.Exec(
`UPDATE tasks SET runner_id = ? WHERE id = ? AND runner_id IS NULL`,
runnerID, taskID,
)
}
return err
})
if err != nil {
if err.Error() == "metadata extraction task not found" {
s.respondError(w, http.StatusNotFound, "Metadata extraction task not found")
return
}
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to find task: %v", err))
return
}
// Update the task to be assigned to this runner if it wasn't already
s.db.Exec(
`UPDATE tasks SET runner_id = ? WHERE id = ? AND runner_id IS NULL`,
runnerID, taskID,
)
} else if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to find task: %v", err))
return
}
@@ -93,20 +99,27 @@ func (s *Server) handleSubmitMetadata(w http.ResponseWriter, r *http.Request) {
}
// Update job with metadata
_, err = s.db.Exec(
err = s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`UPDATE jobs SET blend_metadata = ? WHERE id = ?`,
string(metadataJSON), jobID,
)
return err
})
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to update job metadata: %v", err))
return
}
// Mark task as completed
_, err = s.db.Exec(
`UPDATE tasks SET status = ?, completed_at = CURRENT_TIMESTAMP WHERE id = ?`,
types.TaskStatusCompleted, taskID,
)
err = s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(`UPDATE tasks SET status = ? WHERE id = ?`, types.TaskStatusCompleted, taskID)
if err != nil {
return err
}
_, err = conn.Exec(`UPDATE tasks SET completed_at = CURRENT_TIMESTAMP WHERE id = ?`, taskID)
return err
})
if err != nil {
log.Printf("Failed to mark metadata task as completed: %v", err)
} else {
@@ -136,10 +149,12 @@ func (s *Server) handleGetJobMetadata(w http.ResponseWriter, r *http.Request) {
// Verify job belongs to user
var jobUserID int64
var blendMetadataJSON sql.NullString
err = s.db.QueryRow(
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
`SELECT user_id, blend_metadata FROM jobs WHERE id = ?`,
jobID,
).Scan(&jobUserID, &blendMetadataJSON)
})
if err == sql.ErrNoRows {
s.respondError(w, http.StatusNotFound, "Job not found")
return
@@ -245,64 +260,23 @@ func (s *Server) extractMetadataFromContext(jobID int64) (*types.BlendMetadata,
return nil, fmt.Errorf("failed to get relative path for blend file: %w", err)
}
// Execute Blender with Python script
cmd := exec.Command("blender", "-b", blendFileRel, "--python", "extract_metadata.py")
cmd.Dir = tmpDir
// Execute Blender with Python script using executils
result, err := executils.RunCommand(
"blender",
[]string{"-b", blendFileRel, "--python", "extract_metadata.py"},
tmpDir,
nil, // inherit environment
jobID,
nil, // no process tracker needed for metadata extraction
)
// Capture stdout and stderr
stdoutPipe, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("failed to create stdout pipe: %w", err)
}
stderrPipe, err := cmd.StderrPipe()
if err != nil {
return nil, fmt.Errorf("failed to create stderr pipe: %w", err)
}
// Buffer to collect stdout for JSON parsing
var stdoutBuffer bytes.Buffer
// Start the command
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("failed to start blender: %w", err)
}
// Stream stdout and collect for JSON parsing
stdoutDone := make(chan bool)
go func() {
defer close(stdoutDone)
scanner := bufio.NewScanner(stdoutPipe)
for scanner.Scan() {
line := scanner.Text()
stdoutBuffer.WriteString(line)
stdoutBuffer.WriteString("\n")
stderrOutput := ""
stdoutOutput := ""
if result != nil {
stderrOutput = strings.TrimSpace(result.Stderr)
stdoutOutput = strings.TrimSpace(result.Stdout)
}
}()
// Capture stderr for error reporting
var stderrBuffer bytes.Buffer
stderrDone := make(chan bool)
go func() {
defer close(stderrDone)
scanner := bufio.NewScanner(stderrPipe)
for scanner.Scan() {
line := scanner.Text()
stderrBuffer.WriteString(line)
stderrBuffer.WriteString("\n")
}
}()
// Wait for command to complete
err = cmd.Wait()
// Wait for streaming goroutines to finish
<-stdoutDone
<-stderrDone
if err != nil {
stderrOutput := strings.TrimSpace(stderrBuffer.String())
stdoutOutput := strings.TrimSpace(stdoutBuffer.String())
log.Printf("Blender metadata extraction failed for job %d:", jobID)
if stderrOutput != "" {
log.Printf("Blender stderr: %s", stderrOutput)
@@ -317,7 +291,7 @@ func (s *Server) extractMetadataFromContext(jobID int64) (*types.BlendMetadata,
}
// Parse output (metadata is printed to stdout)
metadataJSON := strings.TrimSpace(stdoutBuffer.String())
metadataJSON := strings.TrimSpace(result.Stdout)
// Extract JSON from output (Blender may print other stuff)
jsonStart := strings.Index(metadataJSON, "{")
jsonEnd := strings.LastIndex(metadataJSON, "}")
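
The hand-rolled pipes and goroutines above are replaced by a single executils.RunCommand call that hands back captured stdout/stderr. pkg/executils is not shown in this excerpt; a minimal sketch compatible with the call site — the jobID and process-tracker parameters are accepted but unused here, and the real helper presumably also streams output and registers the process with the tracker — would be:

package executils

import (
	"bytes"
	"os/exec"
)

// Result holds the captured output of a finished command.
type Result struct {
	Stdout   string
	Stderr   string
	ExitCode int
}

// RunCommand runs name with args in dir. A nil env inherits the parent
// environment. jobID and tracker are accepted for parity with the call sites
// in this commit; this sketch does not use them.
func RunCommand(name string, args []string, dir string, env []string, jobID int64, tracker interface{}) (*Result, error) {
	cmd := exec.Command(name, args...)
	cmd.Dir = dir
	if env != nil {
		cmd.Env = env
	}
	var stdout, stderr bytes.Buffer
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr
	err := cmd.Run()
	res := &Result{Stdout: stdout.String(), Stderr: stderr.String()}
	if cmd.ProcessState != nil {
		res.ExitCode = cmd.ProcessState.ExitCode()
	}
	// Return the partial result even on failure so callers can log stderr,
	// as extractMetadataFromContext does above.
	return res, err
}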

File diff suppressed because it is too large

View File

@@ -10,15 +10,18 @@ import (
"net/http"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"time"
authpkg "jiggablend/internal/auth"
"jiggablend/internal/config"
"jiggablend/internal/database"
"jiggablend/internal/storage"
"jiggablend/pkg/types"
"jiggablend/web"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
@@ -26,79 +29,299 @@ import (
"github.com/gorilla/websocket"
)
// Configuration constants
const (
// WebSocket timeouts
WSReadDeadline = 90 * time.Second
WSPingInterval = 30 * time.Second
WSWriteDeadline = 10 * time.Second
// Task timeouts
DefaultTaskTimeout = 300 // 5 minutes for frame rendering
VideoGenerationTimeout = 86400 // 24 hours for video generation
DefaultJobTimeout = 86400 // 24 hours
// Limits
MaxFrameRange = 10000
MaxUploadSize = 50 << 30 // 50 GB
RunnerHeartbeatTimeout = 90 * time.Second
TaskDistributionInterval = 10 * time.Second
ProgressUpdateThrottle = 2 * time.Second
// Cookie settings
SessionCookieMaxAge = 86400 // 24 hours
)
// Server represents the API server
type Server struct {
db *database.DB
cfg *config.Config
auth *authpkg.Auth
secrets *authpkg.Secrets
storage *storage.Storage
router *chi.Mux
// WebSocket connections
wsUpgrader websocket.Upgrader
runnerConns map[int64]*websocket.Conn
runnerConnsMu sync.RWMutex
wsUpgrader websocket.Upgrader
runnerConns map[int64]*websocket.Conn
runnerConnsMu sync.RWMutex
// Mutexes for each runner connection to serialize writes
runnerConnsWriteMu map[int64]*sync.Mutex
runnerConnsWriteMuMu sync.RWMutex
frontendConns map[string]*websocket.Conn // key: "jobId:taskId"
frontendConnsMu sync.RWMutex
// Mutexes for each frontend connection to serialize writes
frontendConnsWriteMu map[string]*sync.Mutex // key: "jobId:taskId"
// DEPRECATED: Old WebSocket connection maps (kept for backwards compatibility)
// These will be removed in a future release. Use clientConns instead.
frontendConns map[string]*websocket.Conn // key: "jobId:taskId"
frontendConnsMu sync.RWMutex
frontendConnsWriteMu map[string]*sync.Mutex
frontendConnsWriteMuMu sync.RWMutex
// Job list WebSocket connections (key: userID)
jobListConns map[int64]*websocket.Conn
jobListConnsMu sync.RWMutex
// Single job WebSocket connections (key: "userId:jobId")
jobConns map[string]*websocket.Conn
jobConnsMu sync.RWMutex
// Mutexes for job WebSocket connections
jobConnsWriteMu map[string]*sync.Mutex
jobConnsWriteMuMu sync.RWMutex
jobListConns map[int64]*websocket.Conn
jobListConnsMu sync.RWMutex
jobConns map[string]*websocket.Conn
jobConnsMu sync.RWMutex
jobConnsWriteMu map[string]*sync.Mutex
jobConnsWriteMuMu sync.RWMutex
// Throttling for progress updates (per job)
progressUpdateTimes map[int64]time.Time // key: jobID
progressUpdateTimesMu sync.RWMutex
// Throttling for task status updates (per task)
taskUpdateTimes map[int64]time.Time // key: taskID
taskUpdateTimesMu sync.RWMutex
// Task distribution serialization
taskDistMu sync.Mutex // Mutex to prevent concurrent distribution
// Client WebSocket connections (new unified WebSocket)
clientConns map[int64]*ClientConnection // key: userID
clientConnsMu sync.RWMutex
// Upload session tracking
uploadSessions map[string]*UploadSession // sessionId -> session info
uploadSessionsMu sync.RWMutex
// Verbose WebSocket logging (set to true to enable detailed WebSocket logs)
verboseWSLogging bool
// Server start time for health checks
startTime time.Time
}
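
The paired write-mutex maps exist because gorilla/websocket permits at most one concurrent writer per connection. The actual send helpers live elsewhere in this file, so the function below is only a sketch of the pattern these fields support (the name sendToRunner is hypothetical):

// sendToRunner is a hypothetical helper illustrating how the maps above are
// meant to be used: look up the connection, take that connection's write
// mutex, then write, so concurrent sends never interleave frames.
func (s *Server) sendToRunner(runnerID int64, payload interface{}) error {
	s.runnerConnsMu.RLock()
	conn, ok := s.runnerConns[runnerID]
	s.runnerConnsMu.RUnlock()
	if !ok {
		return fmt.Errorf("runner %d not connected", runnerID)
	}

	s.runnerConnsWriteMuMu.RLock()
	mu := s.runnerConnsWriteMu[runnerID]
	s.runnerConnsWriteMuMu.RUnlock()
	if mu == nil {
		return fmt.Errorf("runner %d has no write mutex", runnerID)
	}

	mu.Lock()
	defer mu.Unlock()
	_ = conn.SetWriteDeadline(time.Now().Add(WSWriteDeadline))
	return conn.WriteJSON(payload)
}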
// ClientConnection represents a client WebSocket connection with subscriptions
type ClientConnection struct {
Conn *websocket.Conn
UserID int64
IsAdmin bool
Subscriptions map[string]bool // channel -> subscribed
SubsMu sync.RWMutex // Protects Subscriptions map
WriteMu *sync.Mutex
}
// UploadSession tracks upload and processing progress
type UploadSession struct {
SessionID string
UserID int64
Progress float64
Status string // "uploading", "processing", "extracting_metadata", "creating_context", "completed", "error"
Message string
CreatedAt time.Time
}
// NewServer creates a new API server
func NewServer(db *database.DB, auth *authpkg.Auth, storage *storage.Storage) (*Server, error) {
secrets, err := authpkg.NewSecrets(db.DB)
func NewServer(db *database.DB, cfg *config.Config, auth *authpkg.Auth, storage *storage.Storage) (*Server, error) {
secrets, err := authpkg.NewSecrets(db, cfg)
if err != nil {
return nil, fmt.Errorf("failed to initialize secrets: %w", err)
}
s := &Server{
db: db,
auth: auth,
secrets: secrets,
storage: storage,
router: chi.NewRouter(),
db: db,
cfg: cfg,
auth: auth,
secrets: secrets,
storage: storage,
router: chi.NewRouter(),
startTime: time.Now(),
wsUpgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true // Allow all origins for now
},
CheckOrigin: checkWebSocketOrigin,
ReadBufferSize: 1024,
WriteBufferSize: 1024,
},
runnerConns: make(map[int64]*websocket.Conn),
runnerConnsWriteMu: make(map[int64]*sync.Mutex),
runnerConns: make(map[int64]*websocket.Conn),
runnerConnsWriteMu: make(map[int64]*sync.Mutex),
// DEPRECATED: Initialize old WebSocket maps for backward compatibility
frontendConns: make(map[string]*websocket.Conn),
frontendConnsWriteMu: make(map[string]*sync.Mutex),
jobListConns: make(map[int64]*websocket.Conn),
jobConns: make(map[string]*websocket.Conn),
jobConnsWriteMu: make(map[string]*sync.Mutex),
progressUpdateTimes: make(map[int64]time.Time),
taskUpdateTimes: make(map[int64]time.Time),
clientConns: make(map[int64]*ClientConnection),
uploadSessions: make(map[string]*UploadSession),
}
s.setupMiddleware()
s.setupRoutes()
s.StartBackgroundTasks()
// On startup, check for runners that are marked online but not actually connected
// This handles the case where the manager restarted and lost track of connections
go s.recoverRunnersOnStartup()
return s, nil
}
// checkWebSocketOrigin validates WebSocket connection origins
// In production mode, only allows same-origin connections or configured allowed origins
func checkWebSocketOrigin(r *http.Request) bool {
origin := r.Header.Get("Origin")
if origin == "" {
// No origin header - allow (could be non-browser client like runner)
return true
}
// In development mode, allow all origins
// Note: this function has no access to the Server, so it uses authpkg.IsProductionMode(),
// which reads the environment variable; the server setup itself uses s.cfg.IsProductionMode() for consistency.
if !authpkg.IsProductionMode() {
return true
}
// In production, check against allowed origins
allowedOrigins := os.Getenv("ALLOWED_ORIGINS")
if allowedOrigins == "" {
// Default to same-origin only
host := r.Host
return strings.HasSuffix(origin, "://"+host) || strings.HasSuffix(origin, "://"+strings.Split(host, ":")[0])
}
// Check against configured allowed origins
for _, allowed := range strings.Split(allowedOrigins, ",") {
allowed = strings.TrimSpace(allowed)
if allowed == "*" {
return true
}
if origin == allowed {
return true
}
}
log.Printf("WebSocket origin rejected: %s (allowed: %s)", origin, allowedOrigins)
return false
}
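
In production this means cross-origin WebSocket clients must be listed explicitly in ALLOWED_ORIGINS as full origins, comma-separated, for example (hostnames here are placeholders):

ALLOWED_ORIGINS=https://render.example.com,https://admin.example.com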
// RateLimiter provides simple in-memory rate limiting per IP
type RateLimiter struct {
requests map[string][]time.Time
mu sync.RWMutex
limit int // max requests
window time.Duration // time window
}
// NewRateLimiter creates a new rate limiter
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
rl := &RateLimiter{
requests: make(map[string][]time.Time),
limit: limit,
window: window,
}
// Start cleanup goroutine
go rl.cleanup()
return rl
}
// Allow checks if a request from the given IP is allowed
func (rl *RateLimiter) Allow(ip string) bool {
rl.mu.Lock()
defer rl.mu.Unlock()
now := time.Now()
cutoff := now.Add(-rl.window)
// Get existing requests and filter old ones
reqs := rl.requests[ip]
validReqs := make([]time.Time, 0, len(reqs))
for _, t := range reqs {
if t.After(cutoff) {
validReqs = append(validReqs, t)
}
}
// Check if under limit
if len(validReqs) >= rl.limit {
rl.requests[ip] = validReqs
return false
}
// Add this request
validReqs = append(validReqs, now)
rl.requests[ip] = validReqs
return true
}
// cleanup periodically removes old entries
func (rl *RateLimiter) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
for range ticker.C {
rl.mu.Lock()
cutoff := time.Now().Add(-rl.window)
for ip, reqs := range rl.requests {
validReqs := make([]time.Time, 0, len(reqs))
for _, t := range reqs {
if t.After(cutoff) {
validReqs = append(validReqs, t)
}
}
if len(validReqs) == 0 {
delete(rl.requests, ip)
} else {
rl.requests[ip] = validReqs
}
}
rl.mu.Unlock()
}
}
// Global rate limiters for different endpoint types
var (
// General API rate limiter: 100 requests per minute per IP
apiRateLimiter = NewRateLimiter(100, time.Minute)
// Auth rate limiter: 10 requests per minute per IP (stricter for login attempts)
authRateLimiter = NewRateLimiter(10, time.Minute)
)
// rateLimitMiddleware applies rate limiting based on client IP
func rateLimitMiddleware(limiter *RateLimiter) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get client IP (handle proxied requests)
ip := r.RemoteAddr
if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
// Take the first IP in the chain
if idx := strings.Index(forwarded, ","); idx != -1 {
ip = strings.TrimSpace(forwarded[:idx])
} else {
ip = strings.TrimSpace(forwarded)
}
} else if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
ip = strings.TrimSpace(realIP)
}
if !limiter.Allow(ip) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Retry-After", "60")
w.WriteHeader(http.StatusTooManyRequests)
json.NewEncoder(w).Encode(map[string]string{
"error": "Rate limit exceeded. Please try again later.",
})
return
}
next.ServeHTTP(w, r)
})
}
}
// setupMiddleware configures middleware
func (s *Server) setupMiddleware() {
s.router.Use(middleware.Logger)
@@ -106,17 +329,47 @@ func (s *Server) setupMiddleware() {
// Note: Timeout middleware is NOT applied globally to avoid conflicts with WebSocket connections
// WebSocket connections are long-lived and should not have HTTP timeouts
// Check production mode from config
isProduction := s.cfg.IsProductionMode()
// Add rate limiting (applied in production mode only, or when explicitly enabled)
if isProduction || os.Getenv("ENABLE_RATE_LIMITING") == "true" {
s.router.Use(rateLimitMiddleware(apiRateLimiter))
log.Printf("Rate limiting enabled: 100 requests/minute per IP")
}
// Add gzip compression for JSON responses
s.router.Use(gzipMiddleware)
s.router.Use(cors.Handler(cors.Options{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
// Configure CORS based on environment
corsOptions := cors.Options{
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"},
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "Range", "If-None-Match"},
ExposedHeaders: []string{"Link", "Content-Range", "Accept-Ranges", "Content-Length", "ETag"},
AllowCredentials: true,
MaxAge: 300,
}))
}
// In production, restrict CORS origins
if isProduction {
allowedOrigins := s.cfg.AllowedOrigins()
if allowedOrigins != "" {
corsOptions.AllowedOrigins = strings.Split(allowedOrigins, ",")
for i := range corsOptions.AllowedOrigins {
corsOptions.AllowedOrigins[i] = strings.TrimSpace(corsOptions.AllowedOrigins[i])
}
} else {
// Default to no origins in production if not configured
// This effectively disables cross-origin requests
corsOptions.AllowedOrigins = []string{}
}
log.Printf("Production mode: CORS restricted to origins: %v", corsOptions.AllowedOrigins)
} else {
// Development mode: allow all origins
corsOptions.AllowedOrigins = []string{"*"}
}
s.router.Use(cors.Handler(corsOptions))
}
// gzipMiddleware compresses responses with gzip if client supports it
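
Only the WriteHeader method of gzipResponseWriter survives in the next hunk, so for context, a conventional shape for this middleware is sketched below — assuming it skips WebSocket upgrade requests, which must not be wrapped, and using compress/gzip and strings from the imports:

// Sketch of gzipMiddleware: compress the body only when the client advertises
// gzip support, and leave upgrade requests (WebSockets) untouched.
func gzipMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Header.Get("Upgrade") != "" ||
			!strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
			next.ServeHTTP(w, r)
			return
		}
		gz := gzip.NewWriter(w)
		defer gz.Close()
		w.Header().Set("Content-Encoding", "gzip")
		w.Header().Del("Content-Length")
		next.ServeHTTP(&gzipResponseWriter{ResponseWriter: w, writer: gz}, r)
	})
}

// gzipResponseWriter routes writes through the gzip writer; the committed
// version (partially shown below) also overrides WriteHeader.
type gzipResponseWriter struct {
	http.ResponseWriter
	writer *gzip.Writer
}

func (w *gzipResponseWriter) Write(b []byte) (int, error) {
	return w.writer.Write(b)
}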
@@ -164,8 +417,15 @@ func (w *gzipResponseWriter) WriteHeader(statusCode int) {
// setupRoutes configures routes
func (s *Server) setupRoutes() {
// Public routes
// Health check endpoint (unauthenticated)
s.router.Get("/api/health", s.handleHealthCheck)
// Public routes (with stricter rate limiting for auth endpoints)
s.router.Route("/api/auth", func(r chi.Router) {
// Apply stricter rate limiting to auth endpoints in production
if s.cfg.IsProductionMode() || os.Getenv("ENABLE_RATE_LIMITING") == "true" {
r.Use(rateLimitMiddleware(authRateLimiter))
}
r.Get("/providers", s.handleGetAuthProviders)
r.Get("/google/login", s.handleGoogleLogin)
r.Get("/google/callback", s.handleGoogleCallback)
@@ -203,16 +463,10 @@ func (s *Server) setupRoutes() {
r.Get("/{id}/tasks/summary", s.handleListJobTasksSummary)
r.Post("/{id}/tasks/batch", s.handleBatchGetTasks)
r.Get("/{id}/tasks/{taskId}/logs", s.handleGetTaskLogs)
// WebSocket route - no timeout middleware (long-lived connection)
r.With(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Remove timeout middleware for WebSocket
next.ServeHTTP(w, r)
})
}).Get("/{id}/tasks/{taskId}/logs/ws", s.handleStreamTaskLogsWebSocket)
// Old WebSocket route removed - use client WebSocket with subscriptions instead
r.Get("/{id}/tasks/{taskId}/steps", s.handleGetTaskSteps)
r.Post("/{id}/tasks/{taskId}/retry", s.handleRetryTask)
// WebSocket routes for real-time updates
// WebSocket route for unified client WebSocket
r.With(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Apply authentication middleware first
@@ -221,16 +475,7 @@ func (s *Server) setupRoutes() {
next.ServeHTTP(w, r)
})(w, r)
})
}).Get("/ws", s.handleJobsWebSocket)
r.With(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Apply authentication middleware first
s.auth.Middleware(func(w http.ResponseWriter, r *http.Request) {
// Remove timeout middleware for WebSocket
next.ServeHTTP(w, r)
})(w, r)
})
}).Get("/{id}/ws", s.handleJobWebSocket)
}).Get("/ws", s.handleClientWebSocket)
})
// Admin routes
@@ -286,8 +531,8 @@ func (s *Server) setupRoutes() {
})
})
// Serve static files (built React app)
s.router.Handle("/*", http.FileServer(http.Dir("./web/dist")))
// Serve static files (embedded React app with SPA fallback)
s.router.Handle("/*", web.SPAHandler())
}
// ServeHTTP implements http.Handler
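
Switching from http.Dir("./web/dist") to web.SPAHandler() serves the frontend from assets embedded in the binary rather than from disk. The web package is not part of this excerpt; a typical shape for such a handler, assuming the build output is embedded under dist, is:

package web

import (
	"embed"
	"io/fs"
	"net/http"
	"path"
	"strings"
)

//go:embed dist
var distFS embed.FS

// SPAHandler serves the embedded React build and falls back to index.html for
// client-side routes. Sketch only; the embedded directory name is assumed.
func SPAHandler() http.Handler {
	dist, err := fs.Sub(distFS, "dist")
	if err != nil {
		panic(err) // cannot happen with a valid go:embed directive
	}
	fileServer := http.FileServer(http.FS(dist))

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		name := strings.TrimPrefix(path.Clean(r.URL.Path), "/")
		if name == "" {
			name = "index.html"
		}
		if _, err := fs.Stat(dist, name); err != nil {
			// Unknown path: hand back index.html so the SPA router resolves it.
			r.URL.Path = "/"
		}
		fileServer.ServeHTTP(w, r)
	})
}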
@@ -308,6 +553,76 @@ func (s *Server) respondError(w http.ResponseWriter, status int, message string)
s.respondJSON(w, status, map[string]string{"error": message})
}
// createSessionCookie creates a secure session cookie with appropriate flags for the environment
func createSessionCookie(sessionID string) *http.Cookie {
cookie := &http.Cookie{
Name: "session_id",
Value: sessionID,
Path: "/",
MaxAge: SessionCookieMaxAge,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
// In production mode, set Secure flag to require HTTPS
if authpkg.IsProductionMode() {
cookie.Secure = true
}
return cookie
}
// handleHealthCheck returns server health status
func (s *Server) handleHealthCheck(w http.ResponseWriter, r *http.Request) {
// Check database connectivity
dbHealthy := true
if err := s.db.Ping(); err != nil {
dbHealthy = false
log.Printf("Health check: database ping failed: %v", err)
}
// Count connected runners
s.runnerConnsMu.RLock()
runnerCount := len(s.runnerConns)
s.runnerConnsMu.RUnlock()
// Count connected clients
s.clientConnsMu.RLock()
clientCount := len(s.clientConns)
s.clientConnsMu.RUnlock()
// Calculate uptime
uptime := time.Since(s.startTime)
// Get memory stats
var memStats runtime.MemStats
runtime.ReadMemStats(&memStats)
status := "healthy"
statusCode := http.StatusOK
if !dbHealthy {
status = "degraded"
statusCode = http.StatusServiceUnavailable
}
response := map[string]interface{}{
"status": status,
"uptime_seconds": int64(uptime.Seconds()),
"database": dbHealthy,
"connected_runners": runnerCount,
"connected_clients": clientCount,
"memory": map[string]interface{}{
"alloc_mb": memStats.Alloc / 1024 / 1024,
"total_alloc_mb": memStats.TotalAlloc / 1024 / 1024,
"sys_mb": memStats.Sys / 1024 / 1024,
"num_gc": memStats.NumGC,
},
"timestamp": time.Now().Unix(),
}
s.respondJSON(w, statusCode, response)
}
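
For reference, a healthy instance answers GET /api/health with a body shaped like this (values are illustrative):

{
  "status": "healthy",
  "uptime_seconds": 3600,
  "database": true,
  "connected_runners": 2,
  "connected_clients": 1,
  "memory": {
    "alloc_mb": 42,
    "total_alloc_mb": 512,
    "sys_mb": 96,
    "num_gc": 17
  },
  "timestamp": 1732600000
}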
// Auth handlers
func (s *Server) handleGoogleLogin(w http.ResponseWriter, r *http.Request) {
url, err := s.auth.GoogleLoginURL()
@@ -337,14 +652,7 @@ func (s *Server) handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
}
sessionID := s.auth.CreateSession(session)
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: sessionID,
Path: "/",
MaxAge: 86400,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
})
http.SetCookie(w, createSessionCookie(sessionID))
http.Redirect(w, r, "/", http.StatusFound)
}
@@ -377,14 +685,7 @@ func (s *Server) handleDiscordCallback(w http.ResponseWriter, r *http.Request) {
}
sessionID := s.auth.CreateSession(session)
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: sessionID,
Path: "/",
MaxAge: 86400,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
})
http.SetCookie(w, createSessionCookie(sessionID))
http.Redirect(w, r, "/", http.StatusFound)
}
@@ -394,13 +695,20 @@ func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
if err == nil {
s.auth.DeleteSession(cookie.Value)
}
http.SetCookie(w, &http.Cookie{
// Create an expired cookie to clear the session
expiredCookie := &http.Cookie{
Name: "session_id",
Value: "",
Path: "/",
MaxAge: -1,
HttpOnly: true,
})
SameSite: http.SameSiteLaxMode,
}
// Use s.cfg.IsProductionMode() for consistency with other server methods
if s.cfg.IsProductionMode() {
expiredCookie.Secure = true
}
http.SetCookie(w, expiredCookie)
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Logged out"})
}
@@ -470,14 +778,7 @@ func (s *Server) handleLocalRegister(w http.ResponseWriter, r *http.Request) {
}
sessionID := s.auth.CreateSession(session)
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: sessionID,
Path: "/",
MaxAge: 86400,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
})
http.SetCookie(w, createSessionCookie(sessionID))
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
"message": "Registration successful",
@@ -514,14 +815,7 @@ func (s *Server) handleLocalLogin(w http.ResponseWriter, r *http.Request) {
}
sessionID := s.auth.CreateSession(session)
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: sessionID,
Path: "/",
MaxAge: 86400,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
})
http.SetCookie(w, createSessionCookie(sessionID))
s.respondJSON(w, http.StatusOK, map[string]interface{}{
"message": "Login successful",
@@ -612,6 +906,87 @@ func (s *Server) StartBackgroundTasks() {
go s.recoverStuckTasks()
go s.cleanupOldRenderJobs()
go s.cleanupOldTempDirectories()
go s.cleanupOldOfflineRunners()
go s.cleanupOldUploadSessions()
}
// recoverRunnersOnStartup checks for runners marked as online but not actually connected
// This runs once on startup to handle manager restarts where we lose track of connections
func (s *Server) recoverRunnersOnStartup() {
// Wait a short time for runners to reconnect after manager restart
// This gives runners a chance to reconnect before we mark them as dead
time.Sleep(5 * time.Second)
log.Printf("Recovering runners on startup: checking for disconnected runners...")
var onlineRunnerIDs []int64
err := s.db.With(func(conn *sql.DB) error {
rows, err := conn.Query(
`SELECT id FROM runners WHERE status = ?`,
types.RunnerStatusOnline,
)
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var runnerID int64
if err := rows.Scan(&runnerID); err == nil {
onlineRunnerIDs = append(onlineRunnerIDs, runnerID)
}
}
return nil
})
if err != nil {
log.Printf("Failed to query online runners on startup: %v", err)
return
}
if len(onlineRunnerIDs) == 0 {
log.Printf("No runners marked as online on startup")
return
}
log.Printf("Found %d runners marked as online, checking actual connections...", len(onlineRunnerIDs))
// Check which runners are actually connected
s.runnerConnsMu.RLock()
deadRunnerIDs := make([]int64, 0)
for _, runnerID := range onlineRunnerIDs {
if _, connected := s.runnerConns[runnerID]; !connected {
deadRunnerIDs = append(deadRunnerIDs, runnerID)
}
}
s.runnerConnsMu.RUnlock()
if len(deadRunnerIDs) == 0 {
log.Printf("All runners marked as online are actually connected")
return
}
log.Printf("Found %d runners marked as online but not connected, redistributing their tasks...", len(deadRunnerIDs))
// Redistribute tasks for disconnected runners
for _, runnerID := range deadRunnerIDs {
log.Printf("Recovering runner %d: redistributing tasks and marking as offline", runnerID)
s.redistributeRunnerTasks(runnerID)
// Mark runner as offline
s.db.With(func(conn *sql.DB) error {
_, _ = conn.Exec(
`UPDATE runners SET status = ?, last_heartbeat = ? WHERE id = ?`,
types.RunnerStatusOffline, time.Now(), runnerID,
)
return nil
})
}
log.Printf("Startup recovery complete: redistributed tasks from %d disconnected runners", len(deadRunnerIDs))
// Trigger task distribution to assign recovered tasks to available runners
s.triggerTaskDistribution()
}
// recoverStuckTasks periodically checks for dead runners and stuck tasks
@@ -639,34 +1014,53 @@ func (s *Server) recoverStuckTasks() {
// Find dead runners (no heartbeat for 90 seconds)
// But only mark as dead if they're not actually connected via WebSocket
rows, err := s.db.Query(
`SELECT id FROM runners
WHERE last_heartbeat < CURRENT_TIMESTAMP - INTERVAL '90 seconds'
var deadRunnerIDs []int64
var stillConnectedIDs []int64
err := s.db.With(func(conn *sql.DB) error {
rows, err := conn.Query(
`SELECT id FROM runners
WHERE last_heartbeat < datetime('now', '-90 seconds')
AND status = ?`,
types.RunnerStatusOnline,
)
types.RunnerStatusOnline,
)
if err != nil {
return err
}
defer rows.Close()
s.runnerConnsMu.RLock()
for rows.Next() {
var runnerID int64
if err := rows.Scan(&runnerID); err == nil {
// Only mark as dead if not actually connected via WebSocket
// The WebSocket connection is the source of truth
if _, stillConnected := s.runnerConns[runnerID]; !stillConnected {
deadRunnerIDs = append(deadRunnerIDs, runnerID)
} else {
// Runner is still connected but heartbeat is stale - update it
stillConnectedIDs = append(stillConnectedIDs, runnerID)
}
}
}
s.runnerConnsMu.RUnlock()
return nil
})
if err != nil {
log.Printf("Failed to query dead runners: %v", err)
return
}
defer rows.Close()
var deadRunnerIDs []int64
s.runnerConnsMu.RLock()
for rows.Next() {
var runnerID int64
if err := rows.Scan(&runnerID); err == nil {
// Only mark as dead if not actually connected via WebSocket
// The WebSocket connection is the source of truth
if _, stillConnected := s.runnerConns[runnerID]; !stillConnected {
deadRunnerIDs = append(deadRunnerIDs, runnerID)
}
// If still connected, heartbeat should be updated by pong handler or heartbeat message
// No need to manually update here - if it's stale, the pong handler isn't working
}
// Update heartbeat for runners that are still connected but have stale heartbeats
// This ensures the database stays in sync with actual connection state
for _, runnerID := range stillConnectedIDs {
s.db.With(func(conn *sql.DB) error {
_, _ = conn.Exec(
`UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`,
time.Now(), types.RunnerStatusOnline, runnerID,
)
return nil
})
}
s.runnerConnsMu.RUnlock()
rows.Close()
if len(deadRunnerIDs) == 0 {
// Check for task timeouts
@@ -679,10 +1073,13 @@ func (s *Server) recoverStuckTasks() {
s.redistributeRunnerTasks(runnerID)
// Mark runner as offline
_, _ = s.db.Exec(
`UPDATE runners SET status = ? WHERE id = ?`,
types.RunnerStatusOffline, runnerID,
)
s.db.With(func(conn *sql.DB) error {
_, _ = conn.Exec(
`UPDATE runners SET status = ? WHERE id = ?`,
types.RunnerStatusOffline, runnerID,
)
return nil
})
}
// Check for task timeouts
@@ -697,33 +1094,59 @@ func (s *Server) recoverStuckTasks() {
// recoverTaskTimeouts handles tasks that have exceeded their timeout
func (s *Server) recoverTaskTimeouts() {
// Find tasks running longer than their timeout
rows, err := s.db.Query(
`SELECT t.id, t.runner_id, t.retry_count, t.max_retries, t.timeout_seconds, t.started_at
var tasks []struct {
taskID int64
runnerID sql.NullInt64
retryCount int
maxRetries int
timeoutSeconds sql.NullInt64
startedAt time.Time
}
err := s.db.With(func(conn *sql.DB) error {
rows, err := conn.Query(
`SELECT t.id, t.runner_id, t.retry_count, t.max_retries, t.timeout_seconds, t.started_at
FROM tasks t
WHERE t.status = ?
AND t.started_at IS NOT NULL
AND (t.timeout_seconds IS NULL OR
t.started_at + INTERVAL (t.timeout_seconds || ' seconds') < CURRENT_TIMESTAMP)`,
types.TaskStatusRunning,
)
(julianday('now') - julianday(t.started_at)) * 86400 > t.timeout_seconds)`,
types.TaskStatusRunning,
)
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var task struct {
taskID int64
runnerID sql.NullInt64
retryCount int
maxRetries int
timeoutSeconds sql.NullInt64
startedAt time.Time
}
err := rows.Scan(&task.taskID, &task.runnerID, &task.retryCount, &task.maxRetries, &task.timeoutSeconds, &task.startedAt)
if err != nil {
log.Printf("Failed to scan task row in recoverTaskTimeouts: %v", err)
continue
}
tasks = append(tasks, task)
}
return nil
})
if err != nil {
log.Printf("Failed to query timed out tasks: %v", err)
return
}
defer rows.Close()
for rows.Next() {
var taskID int64
var runnerID sql.NullInt64
var retryCount, maxRetries int
var timeoutSeconds sql.NullInt64
var startedAt time.Time
err := rows.Scan(&taskID, &runnerID, &retryCount, &maxRetries, &timeoutSeconds, &startedAt)
if err != nil {
log.Printf("Failed to scan task row in recoverTaskTimeouts: %v", err)
continue
}
for _, task := range tasks {
taskID := task.taskID
retryCount := task.retryCount
maxRetries := task.maxRetries
timeoutSeconds := task.timeoutSeconds
startedAt := task.startedAt
// Use default timeout if not set (5 minutes for frame tasks, 24 hours for FFmpeg)
timeout := 300 // 5 minutes default
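
The timeout predicate above was also ported from PostgreSQL-style INTERVAL arithmetic to SQLite: julianday() returns fractional days, so (julianday('now') - julianday(t.started_at)) * 86400 is the elapsed time in seconds. The equivalent Go-side check, for comparison, is simply:

// elapsedExceedsTimeout mirrors the SQL predicate above:
// (julianday('now') - julianday(started_at)) * 86400 > timeout_seconds.
func elapsedExceedsTimeout(startedAt time.Time, timeoutSeconds int64) bool {
	return time.Since(startedAt) > time.Duration(timeoutSeconds)*time.Second
}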
@@ -738,21 +1161,39 @@ func (s *Server) recoverTaskTimeouts() {
if retryCount >= maxRetries {
// Mark as failed
_, err = s.db.Exec(
`UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL
WHERE id = ?`,
types.TaskStatusFailed, "Task timeout exceeded, max retries reached", taskID,
)
err = s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(`UPDATE tasks SET status = ? WHERE id = ?`, types.TaskStatusFailed, taskID)
if err != nil {
return err
}
_, err = conn.Exec(`UPDATE tasks SET error_message = ? WHERE id = ?`, "Task timeout exceeded, max retries reached", taskID)
if err != nil {
return err
}
_, err = conn.Exec(`UPDATE tasks SET runner_id = NULL WHERE id = ?`, taskID)
return err
})
if err != nil {
log.Printf("Failed to mark task %d as failed: %v", taskID, err)
}
} else {
// Reset to pending
_, err = s.db.Exec(
`UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL,
retry_count = retry_count + 1 WHERE id = ?`,
types.TaskStatusPending, taskID,
)
err = s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(`UPDATE tasks SET status = ? WHERE id = ?`, types.TaskStatusPending, taskID)
if err != nil {
return err
}
_, err = conn.Exec(`UPDATE tasks SET runner_id = NULL WHERE id = ?`, taskID)
if err != nil {
return err
}
_, err = conn.Exec(`UPDATE tasks SET current_step = NULL WHERE id = ?`, taskID)
if err != nil {
return err
}
_, err = conn.Exec(`UPDATE tasks SET retry_count = retry_count + 1 WHERE id = ?`, taskID)
return err
})
if err == nil {
// Add log entry using the helper function
s.logTaskEvent(taskID, nil, types.LogLevelWarn, fmt.Sprintf("Task timeout exceeded, resetting (retry %d/%d)", retryCount+1, maxRetries), "")
@@ -800,6 +1241,14 @@ func (s *Server) cleanupOldTempDirectoriesOnce() {
now := time.Now()
cleanedCount := 0
// Check upload sessions to avoid deleting active uploads
s.uploadSessionsMu.RLock()
activeSessions := make(map[string]bool)
for sessionID := range s.uploadSessions {
activeSessions[sessionID] = true
}
s.uploadSessionsMu.RUnlock()
for _, entry := range entries {
if !entry.IsDir() {
continue
@@ -807,13 +1256,18 @@ func (s *Server) cleanupOldTempDirectoriesOnce() {
entryPath := filepath.Join(tempPath, entry.Name())
// Skip if this directory has an active upload session
if activeSessions[entryPath] {
continue
}
// Get directory info to check modification time
info, err := entry.Info()
if err != nil {
continue
}
// Remove directories older than 1 hour
// Remove directories older than 1 hour (only if no active session)
age := now.Sub(info.ModTime())
if age > 1*time.Hour {
if err := os.RemoveAll(entryPath); err != nil {
@@ -829,3 +1283,47 @@ func (s *Server) cleanupOldTempDirectoriesOnce() {
log.Printf("Cleaned up %d old temp directories", cleanedCount)
}
}
// cleanupOldUploadSessions periodically cleans up abandoned upload sessions
func (s *Server) cleanupOldUploadSessions() {
// Run cleanup every 10 minutes
ticker := time.NewTicker(10 * time.Minute)
defer ticker.Stop()
// Run once immediately on startup
s.cleanupOldUploadSessionsOnce()
for range ticker.C {
s.cleanupOldUploadSessionsOnce()
}
}
// cleanupOldUploadSessionsOnce removes upload sessions older than 1 hour
func (s *Server) cleanupOldUploadSessionsOnce() {
defer func() {
if r := recover(); r != nil {
log.Printf("Panic in cleanupOldUploadSessions: %v", r)
}
}()
s.uploadSessionsMu.Lock()
defer s.uploadSessionsMu.Unlock()
now := time.Now()
cleanedCount := 0
for sessionID, session := range s.uploadSessions {
// Remove sessions older than 1 hour
age := now.Sub(session.CreatedAt)
if age > 1*time.Hour {
delete(s.uploadSessions, sessionID)
cleanedCount++
log.Printf("Cleaned up abandoned upload session: %s (user: %d, status: %s, age: %v)",
sessionID, session.UserID, session.Status, age)
}
}
if cleanedCount > 0 {
log.Printf("Cleaned up %d abandoned upload sessions", cleanedCount)
}
}

View File

@@ -5,10 +5,13 @@ import (
"database/sql"
"encoding/json"
"fmt"
"jiggablend/internal/config"
"jiggablend/internal/database"
"log"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/google/uuid"
@@ -17,12 +20,31 @@ import (
"golang.org/x/oauth2/google"
)
// Context key types to avoid collisions (typed keys are safer than string keys)
type contextKey int
const (
contextKeyUserID contextKey = iota
contextKeyUserEmail
contextKeyUserName
contextKeyIsAdmin
)
// Configuration constants
const (
SessionDuration = 24 * time.Hour
SessionCleanupInterval = 1 * time.Hour
)
// Auth handles authentication
type Auth struct {
db *sql.DB
db *database.DB
cfg *config.Config
googleConfig *oauth2.Config
discordConfig *oauth2.Config
sessionStore map[string]*Session
sessionCache map[string]*Session // In-memory cache for performance
cacheMu sync.RWMutex
stopCleanup chan struct{}
}
// Session represents a user session
@@ -35,41 +57,53 @@ type Session struct {
}
// NewAuth creates a new auth instance
func NewAuth(db *sql.DB) (*Auth, error) {
func NewAuth(db *database.DB, cfg *config.Config) (*Auth, error) {
auth := &Auth{
db: db,
sessionStore: make(map[string]*Session),
cfg: cfg,
sessionCache: make(map[string]*Session),
stopCleanup: make(chan struct{}),
}
// Initialize Google OAuth
googleClientID := os.Getenv("GOOGLE_CLIENT_ID")
googleClientSecret := os.Getenv("GOOGLE_CLIENT_SECRET")
// Initialize Google OAuth from database config
googleClientID := cfg.GoogleClientID()
googleClientSecret := cfg.GoogleClientSecret()
if googleClientID != "" && googleClientSecret != "" {
auth.googleConfig = &oauth2.Config{
ClientID: googleClientID,
ClientSecret: googleClientSecret,
RedirectURL: os.Getenv("GOOGLE_REDIRECT_URL"),
RedirectURL: cfg.GoogleRedirectURL(),
Scopes: []string{"openid", "profile", "email"},
Endpoint: google.Endpoint,
}
log.Printf("Google OAuth configured")
}
// Initialize Discord OAuth
discordClientID := os.Getenv("DISCORD_CLIENT_ID")
discordClientSecret := os.Getenv("DISCORD_CLIENT_SECRET")
// Initialize Discord OAuth from database config
discordClientID := cfg.DiscordClientID()
discordClientSecret := cfg.DiscordClientSecret()
if discordClientID != "" && discordClientSecret != "" {
auth.discordConfig = &oauth2.Config{
ClientID: discordClientID,
ClientSecret: discordClientSecret,
RedirectURL: os.Getenv("DISCORD_REDIRECT_URL"),
RedirectURL: cfg.DiscordRedirectURL(),
Scopes: []string{"identify", "email"},
Endpoint: oauth2.Endpoint{
AuthURL: "https://discord.com/api/oauth2/authorize",
TokenURL: "https://discord.com/api/oauth2/token",
},
}
log.Printf("Discord OAuth configured")
}
// Load existing sessions from database into cache
if err := auth.loadSessionsFromDB(); err != nil {
log.Printf("Warning: Failed to load sessions from database: %v", err)
}
// Start background cleanup goroutine
go auth.cleanupExpiredSessions()
// Initialize admin settings on startup to ensure they persist between boots
if err := auth.initializeSettings(); err != nil {
log.Printf("Warning: Failed to initialize admin settings: %v", err)
@@ -85,19 +119,119 @@ func NewAuth(db *sql.DB) (*Auth, error) {
return auth, nil
}
// Close stops background goroutines
func (a *Auth) Close() {
close(a.stopCleanup)
}
// loadSessionsFromDB loads all valid sessions from database into cache
func (a *Auth) loadSessionsFromDB() error {
var sessions []struct {
sessionID string
session Session
}
err := a.db.With(func(conn *sql.DB) error {
rows, err := conn.Query(
`SELECT session_id, user_id, email, name, is_admin, expires_at
FROM sessions WHERE expires_at > CURRENT_TIMESTAMP`,
)
if err != nil {
return fmt.Errorf("failed to query sessions: %w", err)
}
defer rows.Close()
for rows.Next() {
var s struct {
sessionID string
session Session
}
err := rows.Scan(&s.sessionID, &s.session.UserID, &s.session.Email, &s.session.Name, &s.session.IsAdmin, &s.session.ExpiresAt)
if err != nil {
log.Printf("Warning: Failed to scan session row: %v", err)
continue
}
sessions = append(sessions, s)
}
return nil
})
if err != nil {
return err
}
a.cacheMu.Lock()
defer a.cacheMu.Unlock()
for _, s := range sessions {
a.sessionCache[s.sessionID] = &s.session
}
if len(sessions) > 0 {
log.Printf("Loaded %d active sessions from database", len(sessions))
}
return nil
}
// cleanupExpiredSessions periodically removes expired sessions from database and cache
func (a *Auth) cleanupExpiredSessions() {
ticker := time.NewTicker(SessionCleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
// Delete expired sessions from database
var deleted int64
err := a.db.With(func(conn *sql.DB) error {
result, err := conn.Exec(`DELETE FROM sessions WHERE expires_at < CURRENT_TIMESTAMP`)
if err != nil {
return err
}
deleted, _ = result.RowsAffected()
return nil
})
if err != nil {
log.Printf("Warning: Failed to cleanup expired sessions: %v", err)
continue
}
// Clean up cache
a.cacheMu.Lock()
now := time.Now()
for sessionID, session := range a.sessionCache {
if now.After(session.ExpiresAt) {
delete(a.sessionCache, sessionID)
}
}
a.cacheMu.Unlock()
if deleted > 0 {
log.Printf("Cleaned up %d expired sessions", deleted)
}
case <-a.stopCleanup:
return
}
}
}
// initializeSettings ensures all admin settings are initialized with defaults if they don't exist
func (a *Auth) initializeSettings() error {
// Initialize registration_enabled setting (default: true) if it doesn't exist
var settingCount int
err := a.db.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount)
err := a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount)
})
if err != nil {
return fmt.Errorf("failed to check registration_enabled setting: %w", err)
}
if settingCount == 0 {
_, err = a.db.Exec(
err = a.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)`,
"registration_enabled", "true",
)
return err
})
if err != nil {
return fmt.Errorf("failed to initialize registration_enabled setting: %w", err)
}
@@ -118,7 +252,9 @@ func (a *Auth) initializeTestUser() error {
// Check if user already exists
var exists bool
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND oauth_provider = 'local')", testEmail).Scan(&exists)
err := a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND oauth_provider = 'local')", testEmail).Scan(&exists)
})
if err != nil {
return fmt.Errorf("failed to check if test user exists: %w", err)
}
@@ -137,7 +273,12 @@ func (a *Auth) initializeTestUser() error {
// Check if this is the first user (make them admin)
var userCount int
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
err = a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
})
if err != nil {
return fmt.Errorf("failed to check user count: %w", err)
}
isAdmin := userCount == 0
// Create test user (use email as name if no name is provided)
@@ -147,10 +288,13 @@ func (a *Auth) initializeTestUser() error {
}
// Create test user
_, err = a.db.Exec(
err = a.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?)",
testEmail, testName, testEmail, string(hashedPassword), isAdmin,
)
return err
})
if err != nil {
return fmt.Errorf("failed to create test user: %w", err)
}
@@ -242,7 +386,9 @@ func (a *Auth) DiscordCallback(ctx context.Context, code string) (*Session, erro
// IsRegistrationEnabled checks if new user registration is enabled
func (a *Auth) IsRegistrationEnabled() (bool, error) {
var value string
err := a.db.QueryRow("SELECT value FROM settings WHERE key = ?", "registration_enabled").Scan(&value)
err := a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT value FROM settings WHERE key = ?", "registration_enabled").Scan(&value)
})
if err == sql.ErrNoRows {
// Default to enabled if setting doesn't exist
return true, nil
@@ -262,29 +408,31 @@ func (a *Auth) SetRegistrationEnabled(enabled bool) error {
// Check if setting exists
var exists bool
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM settings WHERE key = ?)", "registration_enabled").Scan(&exists)
err := a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM settings WHERE key = ?)", "registration_enabled").Scan(&exists)
})
if err != nil {
return fmt.Errorf("failed to check if setting exists: %w", err)
}
err = a.db.With(func(conn *sql.DB) error {
if exists {
// Update existing setting
_, err = a.db.Exec(
_, err = conn.Exec(
"UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?",
value, "registration_enabled",
)
if err != nil {
return fmt.Errorf("failed to update setting: %w", err)
}
} else {
// Insert new setting
_, err = a.db.Exec(
_, err = conn.Exec(
"INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)",
"registration_enabled", value,
)
if err != nil {
return fmt.Errorf("failed to insert setting: %w", err)
}
return err
})
if err != nil {
return fmt.Errorf("failed to set registration_enabled: %w", err)
}
return nil
@@ -299,17 +447,21 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session,
var dbProvider, dbOAuthID string
// First, try to find by provider + oauth_id
err := a.db.QueryRow(
err := a.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
"SELECT id, email, name, is_admin, oauth_provider, oauth_id FROM users WHERE oauth_provider = ? AND oauth_id = ?",
provider, oauthID,
).Scan(&userID, &dbEmail, &dbName, &isAdmin, &dbProvider, &dbOAuthID)
})
if err == sql.ErrNoRows {
// Not found by provider+oauth_id, check by email for account linking
err = a.db.QueryRow(
err = a.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
"SELECT id, email, name, is_admin, oauth_provider, oauth_id FROM users WHERE email = ?",
email,
).Scan(&userID, &dbEmail, &dbName, &isAdmin, &dbProvider, &dbOAuthID)
})
if err == sql.ErrNoRows {
// User doesn't exist, check if registration is enabled
@@ -323,14 +475,26 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session,
// Check if this is the first user
var userCount int
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
err = a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
})
if err != nil {
return nil, fmt.Errorf("failed to check user count: %w", err)
}
isAdmin = userCount == 0
// Create new user
err = a.db.QueryRow(
"INSERT INTO users (email, name, oauth_provider, oauth_id, is_admin) VALUES (?, ?, ?, ?, ?) RETURNING id",
err = a.db.With(func(conn *sql.DB) error {
result, err := conn.Exec(
"INSERT INTO users (email, name, oauth_provider, oauth_id, is_admin) VALUES (?, ?, ?, ?, ?)",
email, name, provider, oauthID, isAdmin,
).Scan(&userID)
)
if err != nil {
return err
}
userID, err = result.LastInsertId()
return err
})
if err != nil {
return nil, fmt.Errorf("failed to create user: %w", err)
}
@@ -339,10 +503,13 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session,
} else {
// User exists with same email but different provider - link accounts by updating provider info
// This allows the user to log in with any provider that has the same email
_, err = a.db.Exec(
err = a.db.With(func(conn *sql.DB) error {
_, err = conn.Exec(
"UPDATE users SET oauth_provider = ?, oauth_id = ?, name = ? WHERE id = ?",
provider, oauthID, name, userID,
)
return err
})
if err != nil {
return nil, fmt.Errorf("failed to link account: %w", err)
}
@@ -353,10 +520,13 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session,
} else {
// User found by provider+oauth_id, update info if changed
if dbEmail != email || dbName != name {
_, err = a.db.Exec(
err = a.db.With(func(conn *sql.DB) error {
_, err = conn.Exec(
"UPDATE users SET email = ?, name = ? WHERE id = ?",
email, name, userID,
)
return err
})
if err != nil {
return nil, fmt.Errorf("failed to update user: %w", err)
}
@@ -368,41 +538,134 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session,
Email: email,
Name: name,
IsAdmin: isAdmin,
ExpiresAt: time.Now().Add(24 * time.Hour),
ExpiresAt: time.Now().Add(SessionDuration),
}
return session, nil
}
// CreateSession creates a new session and returns a session ID
// Sessions are persisted to database and cached in memory
func (a *Auth) CreateSession(session *Session) string {
sessionID := uuid.New().String()
a.sessionStore[sessionID] = session
// Store in database first
err := a.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`INSERT INTO sessions (session_id, user_id, email, name, is_admin, expires_at)
VALUES (?, ?, ?, ?, ?, ?)`,
sessionID, session.UserID, session.Email, session.Name, session.IsAdmin, session.ExpiresAt,
)
return err
})
if err != nil {
log.Printf("Warning: Failed to persist session to database: %v", err)
// Continue anyway - session will work from cache but won't survive restart
}
// Store in cache
a.cacheMu.Lock()
a.sessionCache[sessionID] = session
a.cacheMu.Unlock()
return sessionID
}
// GetSession retrieves a session by ID
// First checks cache, then database if not found
func (a *Auth) GetSession(sessionID string) (*Session, bool) {
session, ok := a.sessionStore[sessionID]
if !ok {
// Check cache first
a.cacheMu.RLock()
session, ok := a.sessionCache[sessionID]
a.cacheMu.RUnlock()
if ok {
if time.Now().After(session.ExpiresAt) {
a.DeleteSession(sessionID)
return nil, false
}
// Refresh admin status from database
var isAdmin bool
err := a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT is_admin FROM users WHERE id = ?", session.UserID).Scan(&isAdmin)
})
if err == nil {
session.IsAdmin = isAdmin
}
return session, true
}
// Not in cache, check database
session = &Session{}
err := a.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
`SELECT user_id, email, name, is_admin, expires_at
FROM sessions WHERE session_id = ?`,
sessionID,
).Scan(&session.UserID, &session.Email, &session.Name, &session.IsAdmin, &session.ExpiresAt)
})
if err == sql.ErrNoRows {
return nil, false
}
if err != nil {
log.Printf("Warning: Failed to query session from database: %v", err)
return nil, false
}
if time.Now().After(session.ExpiresAt) {
delete(a.sessionStore, sessionID)
a.DeleteSession(sessionID)
return nil, false
}
// Refresh admin status from database
var isAdmin bool
err := a.db.QueryRow("SELECT is_admin FROM users WHERE id = ?", session.UserID).Scan(&isAdmin)
err = a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT is_admin FROM users WHERE id = ?", session.UserID).Scan(&isAdmin)
})
if err == nil {
session.IsAdmin = isAdmin
}
// Add to cache
a.cacheMu.Lock()
a.sessionCache[sessionID] = session
a.cacheMu.Unlock()
return session, true
}
// DeleteSession deletes a session
// DeleteSession deletes a session from both cache and database
func (a *Auth) DeleteSession(sessionID string) {
delete(a.sessionStore, sessionID)
// Delete from cache
a.cacheMu.Lock()
delete(a.sessionCache, sessionID)
a.cacheMu.Unlock()
// Delete from database
err := a.db.With(func(conn *sql.DB) error {
_, err := conn.Exec("DELETE FROM sessions WHERE session_id = ?", sessionID)
return err
})
if err != nil {
log.Printf("Warning: Failed to delete session from database: %v", err)
}
}
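
Expired rows will otherwise accumulate in the new sessions table; a periodic sweep is a natural companion to this persistence layer. A minimal sketch of such a helper (hypothetical, not part of this change):

	// CleanupExpiredSessions removes expired sessions from both cache and database.
	// A manager goroutine could call this on a ticker.
	func (a *Auth) CleanupExpiredSessions() {
		now := time.Now()
		a.cacheMu.Lock()
		for id, s := range a.sessionCache {
			if now.After(s.ExpiresAt) {
				delete(a.sessionCache, id)
			}
		}
		a.cacheMu.Unlock()
		if err := a.db.With(func(conn *sql.DB) error {
			_, err := conn.Exec("DELETE FROM sessions WHERE expires_at < ?", now)
			return err
		}); err != nil {
			log.Printf("Warning: Failed to clean up expired sessions: %v", err)
		}
	}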
// IsProductionMode returns true if running in production mode
// This is a package-level function that checks the environment variable
// For config-based checks, use Config.IsProductionMode()
func IsProductionMode() bool {
// Check environment variable first for backwards compatibility
if os.Getenv("PRODUCTION") == "true" {
return true
}
return false
}
// IsProductionModeFromConfig returns true if production mode is enabled in config
func (a *Auth) IsProductionModeFromConfig() bool {
return a.cfg.IsProductionMode()
}
// Middleware creates an authentication middleware
@@ -426,24 +689,24 @@ func (a *Auth) Middleware(next http.HandlerFunc) http.HandlerFunc {
return
}
// Add user info to request context
ctx := context.WithValue(r.Context(), "user_id", session.UserID)
ctx = context.WithValue(ctx, "user_email", session.Email)
ctx = context.WithValue(ctx, "user_name", session.Name)
ctx = context.WithValue(ctx, "is_admin", session.IsAdmin)
// Add user info to request context using typed keys
ctx := context.WithValue(r.Context(), contextKeyUserID, session.UserID)
ctx = context.WithValue(ctx, contextKeyUserEmail, session.Email)
ctx = context.WithValue(ctx, contextKeyUserName, session.Name)
ctx = context.WithValue(ctx, contextKeyIsAdmin, session.IsAdmin)
next(w, r.WithContext(ctx))
}
}
// GetUserID gets the user ID from context
func GetUserID(ctx context.Context) (int64, bool) {
userID, ok := ctx.Value("user_id").(int64)
userID, ok := ctx.Value(contextKeyUserID).(int64)
return userID, ok
}
// IsAdmin checks if the user in context is an admin
func IsAdmin(ctx context.Context) bool {
isAdmin, ok := ctx.Value("is_admin").(bool)
isAdmin, ok := ctx.Value(contextKeyIsAdmin).(bool)
return ok && isAdmin
}
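
The typed keys themselves are defined elsewhere in the file and are not visible in this hunk; a declaration consistent with how they are used here would look roughly like this (names and string values assumed):

	// contextKey is an unexported type so these keys cannot collide with plain
	// string keys set by other packages (sketch; actual definition not in this hunk).
	type contextKey string

	const (
		contextKeyUserID    contextKey = "user_id"
		contextKeyUserEmail contextKey = "user_email"
		contextKeyUserName  contextKey = "user_name"
		contextKeyIsAdmin   contextKey = "is_admin"
	)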
@@ -478,19 +741,19 @@ func (a *Auth) AdminMiddleware(next http.HandlerFunc) http.HandlerFunc {
return
}
// Add user info to request context
ctx := context.WithValue(r.Context(), "user_id", session.UserID)
ctx = context.WithValue(ctx, "user_email", session.Email)
ctx = context.WithValue(ctx, "user_name", session.Name)
ctx = context.WithValue(ctx, "is_admin", session.IsAdmin)
// Add user info to request context using typed keys
ctx := context.WithValue(r.Context(), contextKeyUserID, session.UserID)
ctx = context.WithValue(ctx, contextKeyUserEmail, session.Email)
ctx = context.WithValue(ctx, contextKeyUserName, session.Name)
ctx = context.WithValue(ctx, contextKeyIsAdmin, session.IsAdmin)
next(w, r.WithContext(ctx))
}
}
// IsLocalLoginEnabled returns whether local login is enabled
// Local login is enabled when ENABLE_LOCAL_AUTH environment variable is set to "true"
// Reads the enable_local_auth setting from the database config (seeded from the ENABLE_LOCAL_AUTH env var on first run)
func (a *Auth) IsLocalLoginEnabled() bool {
return os.Getenv("ENABLE_LOCAL_AUTH") == "true"
return a.cfg.IsLocalAuthEnabled()
}
// IsGoogleOAuthConfigured returns whether Google OAuth is configured
@@ -511,10 +774,12 @@ func (a *Auth) LocalLogin(username, password string) (*Session, error) {
var dbEmail, dbName, passwordHash string
var isAdmin bool
err := a.db.QueryRow(
err := a.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
"SELECT id, email, name, password_hash, is_admin FROM users WHERE email = ? AND oauth_provider = 'local'",
email,
).Scan(&userID, &dbEmail, &dbName, &passwordHash, &isAdmin)
})
if err == sql.ErrNoRows {
return nil, fmt.Errorf("invalid credentials")
@@ -539,7 +804,7 @@ func (a *Auth) LocalLogin(username, password string) (*Session, error) {
Email: dbEmail,
Name: dbName,
IsAdmin: isAdmin,
ExpiresAt: time.Now().Add(24 * time.Hour),
ExpiresAt: time.Now().Add(SessionDuration),
}
return session, nil
@@ -558,7 +823,9 @@ func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error)
// Check if user already exists
var exists bool
err = a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ?)", email).Scan(&exists)
err = a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ?)", email).Scan(&exists)
})
if err != nil {
return nil, fmt.Errorf("failed to check if user exists: %w", err)
}
@@ -574,15 +841,27 @@ func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error)
// Check if this is the first user (make them admin)
var userCount int
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
err = a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
})
if err != nil {
return nil, fmt.Errorf("failed to check user count: %w", err)
}
isAdmin := userCount == 0
// Create user
var userID int64
err = a.db.QueryRow(
"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?) RETURNING id",
err = a.db.With(func(conn *sql.DB) error {
result, err := conn.Exec(
"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?)",
email, name, email, string(hashedPassword), isAdmin,
).Scan(&userID)
)
if err != nil {
return err
}
userID, err = result.LastInsertId()
return err
})
if err != nil {
return nil, fmt.Errorf("failed to create user: %w", err)
}
@@ -593,7 +872,7 @@ func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error)
Email: email,
Name: name,
IsAdmin: isAdmin,
ExpiresAt: time.Now().Add(24 * time.Hour),
ExpiresAt: time.Now().Add(SessionDuration),
}
return session, nil
@@ -603,7 +882,9 @@ func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error)
func (a *Auth) ChangePassword(userID int64, oldPassword, newPassword string) error {
// Get current password hash
var passwordHash string
err := a.db.QueryRow("SELECT password_hash FROM users WHERE id = ? AND oauth_provider = 'local'", userID).Scan(&passwordHash)
err := a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT password_hash FROM users WHERE id = ? AND oauth_provider = 'local'", userID).Scan(&passwordHash)
})
if err == sql.ErrNoRows {
return fmt.Errorf("user not found or not a local user")
}
@@ -628,7 +909,10 @@ func (a *Auth) ChangePassword(userID int64, oldPassword, newPassword string) err
}
// Update password
_, err = a.db.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), userID)
err = a.db.With(func(conn *sql.DB) error {
_, err := conn.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), userID)
return err
})
if err != nil {
return fmt.Errorf("failed to update password: %w", err)
}
@@ -640,7 +924,9 @@ func (a *Auth) ChangePassword(userID int64, oldPassword, newPassword string) err
func (a *Auth) AdminChangePassword(targetUserID int64, newPassword string) error {
// Verify user exists and is a local user
var exists bool
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ? AND oauth_provider = 'local')", targetUserID).Scan(&exists)
err := a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ? AND oauth_provider = 'local')", targetUserID).Scan(&exists)
})
if err != nil {
return fmt.Errorf("failed to check if user exists: %w", err)
}
@@ -655,7 +941,10 @@ func (a *Auth) AdminChangePassword(targetUserID int64, newPassword string) error
}
// Update password
_, err = a.db.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), targetUserID)
err = a.db.With(func(conn *sql.DB) error {
_, err := conn.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), targetUserID)
return err
})
if err != nil {
return fmt.Errorf("failed to update password: %w", err)
}
@@ -666,7 +955,9 @@ func (a *Auth) AdminChangePassword(targetUserID int64, newPassword string) error
// GetFirstUserID returns the ID of the first user (user with the lowest ID)
func (a *Auth) GetFirstUserID() (int64, error) {
var firstUserID int64
err := a.db.QueryRow("SELECT id FROM users ORDER BY id ASC LIMIT 1").Scan(&firstUserID)
err := a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT id FROM users ORDER BY id ASC LIMIT 1").Scan(&firstUserID)
})
if err == sql.ErrNoRows {
return 0, fmt.Errorf("no users found")
}
@@ -680,7 +971,9 @@ func (a *Auth) GetFirstUserID() (int64, error) {
func (a *Auth) SetUserAdminStatus(targetUserID int64, isAdmin bool) error {
// Verify user exists
var exists bool
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", targetUserID).Scan(&exists)
err := a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", targetUserID).Scan(&exists)
})
if err != nil {
return fmt.Errorf("failed to check if user exists: %w", err)
}
@@ -698,7 +991,10 @@ func (a *Auth) SetUserAdminStatus(targetUserID int64, isAdmin bool) error {
}
// Update admin status
_, err = a.db.Exec("UPDATE users SET is_admin = ? WHERE id = ?", isAdmin, targetUserID)
err = a.db.With(func(conn *sql.DB) error {
_, err := conn.Exec("UPDATE users SET is_admin = ? WHERE id = ?", isAdmin, targetUserID)
return err
})
if err != nil {
return fmt.Errorf("failed to update admin status: %w", err)
}

View File

@@ -6,7 +6,8 @@ import (
"database/sql"
"encoding/hex"
"fmt"
"os"
"jiggablend/internal/config"
"jiggablend/internal/database"
"strings"
"sync"
"time"
@@ -14,21 +15,14 @@ import (
// Secrets handles API key management
type Secrets struct {
db *sql.DB
db *database.DB
cfg *config.Config
RegistrationMu sync.Mutex // Protects concurrent runner registrations
fixedAPIKey string // Fixed API key from environment variable (optional)
}
// NewSecrets creates a new secrets manager
func NewSecrets(db *sql.DB) (*Secrets, error) {
s := &Secrets{db: db}
// Check for fixed API key from environment
if fixedKey := os.Getenv("FIXED_API_KEY"); fixedKey != "" {
s.fixedAPIKey = fixedKey
}
return s, nil
func NewSecrets(db *database.DB, cfg *config.Config) (*Secrets, error) {
return &Secrets{db: db, cfg: cfg}, nil
}
// APIKeyInfo represents information about an API key
@@ -61,25 +55,34 @@ func (s *Secrets) GenerateRunnerAPIKey(createdBy int64, name, description string
keyHash := sha256.Sum256([]byte(key))
keyHashStr := hex.EncodeToString(keyHash[:])
_, err = s.db.Exec(
var keyInfo APIKeyInfo
err = s.db.With(func(conn *sql.DB) error {
result, err := conn.Exec(
`INSERT INTO runner_api_keys (key_prefix, key_hash, name, description, scope, is_active, created_by)
VALUES (?, ?, ?, ?, ?, ?, ?)`,
keyPrefix, keyHashStr, name, description, scope, true, createdBy,
)
if err != nil {
return nil, fmt.Errorf("failed to store API key: %w", err)
return fmt.Errorf("failed to store API key: %w", err)
}
keyID, err := result.LastInsertId()
if err != nil {
return fmt.Errorf("failed to get inserted key ID: %w", err)
}
// Get the inserted key info
var keyInfo APIKeyInfo
err = s.db.QueryRow(
err = conn.QueryRow(
`SELECT id, name, description, scope, is_active, created_at, created_by
FROM runner_api_keys WHERE key_prefix = ?`,
keyPrefix,
FROM runner_api_keys WHERE id = ?`,
keyID,
).Scan(&keyInfo.ID, &keyInfo.Name, &keyInfo.Description, &keyInfo.Scope, &keyInfo.IsActive, &keyInfo.CreatedAt, &keyInfo.CreatedBy)
return err
})
if err != nil {
return nil, fmt.Errorf("failed to retrieve created API key: %w", err)
return nil, fmt.Errorf("failed to create API key: %w", err)
}
keyInfo.Key = key
@@ -91,18 +94,25 @@ func (s *Secrets) generateAPIKey() (string, error) {
// Generate random suffix
randomBytes := make([]byte, 16)
if _, err := rand.Read(randomBytes); err != nil {
return "", err
return "", fmt.Errorf("failed to generate random bytes: %w", err)
}
randomStr := hex.EncodeToString(randomBytes)
// Generate a short random prefix (jk_r followed by 1 random digit)
prefixDigit := make([]byte, 1)
if _, err := rand.Read(prefixDigit); err != nil {
return "", err
return "", fmt.Errorf("failed to generate prefix digit: %w", err)
}
prefix := fmt.Sprintf("jk_r%d", prefixDigit[0]%10)
return fmt.Sprintf("%s_%s", prefix, randomStr), nil
key := fmt.Sprintf("%s_%s", prefix, randomStr)
// Validate generated key format
if !strings.HasPrefix(key, "jk_r") {
return "", fmt.Errorf("generated invalid API key format: %s", key)
}
return key, nil
}
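
A generated key therefore has the shape jk_r<digit>_<32 hex chars>, and only its prefix and SHA-256 hash are persisted. An illustrative mapping onto the stored columns (made-up value, and assuming the stored prefix is the first five characters):

	// Illustrative only: how a full key maps onto runner_api_keys columns.
	key := "jk_r7_3f9c2a61d4b85e07a1c6f2d9e8b4a053" // example value
	keyPrefix := key[:5]                            // "jk_r7" -> key_prefix (lookup)
	sum := sha256.Sum256([]byte(key))
	keyHash := hex.EncodeToString(sum[:])           // -> key_hash; the plaintext key is never stored
	_, _ = keyPrefix, keyHash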
// ValidateRunnerAPIKey validates an API key and returns the key ID and scope if valid
@@ -111,8 +121,9 @@ func (s *Secrets) ValidateRunnerAPIKey(apiKey string) (int64, string, error) {
return 0, "", fmt.Errorf("API key is required")
}
// Check fixed API key first (for testing/development)
if s.fixedAPIKey != "" && apiKey == s.fixedAPIKey {
// Check fixed API key first (from database config)
fixedKey := s.cfg.FixedAPIKey()
if fixedKey != "" && apiKey == fixedKey {
// Return a special ID for fixed API key (doesn't exist in database)
return -1, "manager", nil
}
@@ -137,42 +148,50 @@ func (s *Secrets) ValidateRunnerAPIKey(apiKey string) (int64, string, error) {
var scope string
var isActive bool
err := s.db.QueryRow(
err := s.db.With(func(conn *sql.DB) error {
err := conn.QueryRow(
`SELECT id, scope, is_active FROM runner_api_keys
WHERE key_prefix = ? AND key_hash = ?`,
keyPrefix, keyHashStr,
).Scan(&keyID, &scope, &isActive)
if err == sql.ErrNoRows {
return 0, "", fmt.Errorf("API key not found or invalid - please check that the key is correct and active")
return fmt.Errorf("API key not found or invalid - please check that the key is correct and active")
}
if err != nil {
return 0, "", fmt.Errorf("failed to validate API key: %w", err)
return fmt.Errorf("failed to validate API key: %w", err)
}
if !isActive {
return 0, "", fmt.Errorf("API key is inactive")
return fmt.Errorf("API key is inactive")
}
// Update last_used_at (don't fail if this update fails)
s.db.Exec(`UPDATE runner_api_keys SET last_used_at = ? WHERE id = ?`, time.Now(), keyID)
conn.Exec(`UPDATE runner_api_keys SET last_used_at = ? WHERE id = ?`, time.Now(), keyID)
return nil
})
if err != nil {
return 0, "", err
}
return keyID, scope, nil
}
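
Callers receive both the key ID and its scope, and a registration handler would typically branch on the scope. A sketch under an assumed handler context (w, apiKey, and the handler itself are not defined in this diff):

	// Sketch: validating the API key presented by a registering runner.
	keyID, scope, err := s.ValidateRunnerAPIKey(apiKey)
	if err != nil {
		http.Error(w, "invalid API key", http.StatusUnauthorized)
		return
	}
	if scope != "manager" {
		// user-scoped keys should only pick up jobs owned by the key's creator
	}
	_ = keyID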
// ListRunnerAPIKeys lists all runner API keys
func (s *Secrets) ListRunnerAPIKeys() ([]APIKeyInfo, error) {
rows, err := s.db.Query(
var keys []APIKeyInfo
err := s.db.With(func(conn *sql.DB) error {
rows, err := conn.Query(
`SELECT id, key_prefix, name, description, scope, is_active, created_at, created_by
FROM runner_api_keys
ORDER BY created_at DESC`,
)
if err != nil {
return nil, fmt.Errorf("failed to query API keys: %w", err)
return fmt.Errorf("failed to query API keys: %w", err)
}
defer rows.Close()
var keys []APIKeyInfo
for rows.Next() {
var key APIKeyInfo
var description sql.NullString
@@ -188,20 +207,28 @@ func (s *Secrets) ListRunnerAPIKeys() ([]APIKeyInfo, error) {
keys = append(keys, key)
}
return nil
})
if err != nil {
return nil, err
}
return keys, nil
}
// RevokeRunnerAPIKey revokes (deactivates) a runner API key
func (s *Secrets) RevokeRunnerAPIKey(keyID int64) error {
_, err := s.db.Exec("UPDATE runner_api_keys SET is_active = false WHERE id = ?", keyID)
return s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec("UPDATE runner_api_keys SET is_active = false WHERE id = ?", keyID)
return err
})
}
// DeleteRunnerAPIKey deletes a runner API key
func (s *Secrets) DeleteRunnerAPIKey(keyID int64) error {
_, err := s.db.Exec("DELETE FROM runner_api_keys WHERE id = ?", keyID)
return s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec("DELETE FROM runner_api_keys WHERE id = ?", keyID)
return err
})
}

internal/config/config.go Normal file
View File

@@ -0,0 +1,303 @@
package config
import (
"database/sql"
"fmt"
"jiggablend/internal/database"
"log"
"os"
"strconv"
)
// Config keys stored in database
const (
KeyGoogleClientID = "google_client_id"
KeyGoogleClientSecret = "google_client_secret"
KeyGoogleRedirectURL = "google_redirect_url"
KeyDiscordClientID = "discord_client_id"
KeyDiscordClientSecret = "discord_client_secret"
KeyDiscordRedirectURL = "discord_redirect_url"
KeyEnableLocalAuth = "enable_local_auth"
KeyFixedAPIKey = "fixed_api_key"
KeyRegistrationEnabled = "registration_enabled"
KeyProductionMode = "production_mode"
KeyAllowedOrigins = "allowed_origins"
)
// Config manages application configuration stored in the database
type Config struct {
db *database.DB
}
// NewConfig creates a new config manager
func NewConfig(db *database.DB) *Config {
return &Config{db: db}
}
// InitializeFromEnv loads configuration from environment variables on first run
// Environment variables take precedence only if the config key doesn't exist in the database
// This allows first-run setup via env vars, then subsequent runs use database values
func (c *Config) InitializeFromEnv() error {
envMappings := []struct {
envKey string
configKey string
sensitive bool
}{
{"GOOGLE_CLIENT_ID", KeyGoogleClientID, false},
{"GOOGLE_CLIENT_SECRET", KeyGoogleClientSecret, true},
{"GOOGLE_REDIRECT_URL", KeyGoogleRedirectURL, false},
{"DISCORD_CLIENT_ID", KeyDiscordClientID, false},
{"DISCORD_CLIENT_SECRET", KeyDiscordClientSecret, true},
{"DISCORD_REDIRECT_URL", KeyDiscordRedirectURL, false},
{"ENABLE_LOCAL_AUTH", KeyEnableLocalAuth, false},
{"FIXED_API_KEY", KeyFixedAPIKey, true},
{"PRODUCTION", KeyProductionMode, false},
{"ALLOWED_ORIGINS", KeyAllowedOrigins, false},
}
for _, mapping := range envMappings {
envValue := os.Getenv(mapping.envKey)
if envValue == "" {
continue
}
// Check if config already exists in database
exists, err := c.Exists(mapping.configKey)
if err != nil {
return fmt.Errorf("failed to check config %s: %w", mapping.configKey, err)
}
if !exists {
// Store env value in database
if err := c.Set(mapping.configKey, envValue); err != nil {
return fmt.Errorf("failed to store config %s: %w", mapping.configKey, err)
}
if mapping.sensitive {
log.Printf("Stored config from env: %s = [REDACTED]", mapping.configKey)
} else {
log.Printf("Stored config from env: %s = %s", mapping.configKey, envValue)
}
}
}
return nil
}
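
Wiring this up at startup is straightforward; a sketch of the expected call order (variable names and the database path are assumptions, not part of this diff):

	// Sketch of manager startup wiring.
	db, err := database.NewDB("jiggablend.db")
	if err != nil {
		log.Fatalf("failed to open database: %v", err)
	}
	cfg := config.NewConfig(db)
	// Seed settings from environment variables on first run; later runs keep the stored values.
	if err := cfg.InitializeFromEnv(); err != nil {
		log.Fatalf("failed to initialize config: %v", err)
	}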
// Get retrieves a config value from the database
func (c *Config) Get(key string) (string, error) {
var value string
err := c.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT value FROM settings WHERE key = ?", key).Scan(&value)
})
if err == sql.ErrNoRows {
return "", nil
}
if err != nil {
return "", fmt.Errorf("failed to get config %s: %w", key, err)
}
return value, nil
}
// GetWithDefault retrieves a config value or returns a default if not set
func (c *Config) GetWithDefault(key, defaultValue string) string {
value, err := c.Get(key)
if err != nil || value == "" {
return defaultValue
}
return value
}
// GetBool retrieves a boolean config value
func (c *Config) GetBool(key string) (bool, error) {
value, err := c.Get(key)
if err != nil {
return false, err
}
return value == "true" || value == "1", nil
}
// GetBoolWithDefault retrieves a boolean config value or returns a default
func (c *Config) GetBoolWithDefault(key string, defaultValue bool) bool {
value, err := c.GetBool(key)
if err != nil {
return defaultValue
}
// If the key doesn't exist, Get returns empty string which becomes false
// Check if key exists to distinguish between "false" and "not set"
exists, _ := c.Exists(key)
if !exists {
return defaultValue
}
return value
}
// GetInt retrieves an integer config value
func (c *Config) GetInt(key string) (int, error) {
value, err := c.Get(key)
if err != nil {
return 0, err
}
if value == "" {
return 0, nil
}
return strconv.Atoi(value)
}
// GetIntWithDefault retrieves an integer config value or returns a default
func (c *Config) GetIntWithDefault(key string, defaultValue int) int {
value, err := c.GetInt(key)
if err != nil {
return defaultValue
}
exists, _ := c.Exists(key)
if !exists {
return defaultValue
}
return value
}
// Set stores a config value in the database
func (c *Config) Set(key, value string) error {
// Use upsert pattern
exists, err := c.Exists(key)
if err != nil {
return err
}
err = c.db.With(func(conn *sql.DB) error {
if exists {
_, err = conn.Exec(
"UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?",
value, key,
)
} else {
_, err = conn.Exec(
"INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)",
key, value,
)
}
return err
})
if err != nil {
return fmt.Errorf("failed to set config %s: %w", key, err)
}
return nil
}
// SetBool stores a boolean config value
func (c *Config) SetBool(key string, value bool) error {
strValue := "false"
if value {
strValue = "true"
}
return c.Set(key, strValue)
}
// SetInt stores an integer config value
func (c *Config) SetInt(key string, value int) error {
return c.Set(key, strconv.Itoa(value))
}
// Delete removes a config value from the database
func (c *Config) Delete(key string) error {
err := c.db.With(func(conn *sql.DB) error {
_, err := conn.Exec("DELETE FROM settings WHERE key = ?", key)
return err
})
if err != nil {
return fmt.Errorf("failed to delete config %s: %w", key, err)
}
return nil
}
// Exists checks if a config key exists in the database
func (c *Config) Exists(key string) (bool, error) {
var exists bool
err := c.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM settings WHERE key = ?)", key).Scan(&exists)
})
if err != nil {
return false, fmt.Errorf("failed to check config existence %s: %w", key, err)
}
return exists, nil
}
// GetAll returns all config values (for debugging/admin purposes)
func (c *Config) GetAll() (map[string]string, error) {
var result map[string]string
err := c.db.With(func(conn *sql.DB) error {
rows, err := conn.Query("SELECT key, value FROM settings")
if err != nil {
return fmt.Errorf("failed to get all config: %w", err)
}
defer rows.Close()
result = make(map[string]string)
for rows.Next() {
var key, value string
if err := rows.Scan(&key, &value); err != nil {
return fmt.Errorf("failed to scan config row: %w", err)
}
result[key] = value
}
return nil
})
if err != nil {
return nil, err
}
return result, nil
}
// --- Convenience methods for specific config values ---
// GoogleClientID returns the Google OAuth client ID
func (c *Config) GoogleClientID() string {
return c.GetWithDefault(KeyGoogleClientID, "")
}
// GoogleClientSecret returns the Google OAuth client secret
func (c *Config) GoogleClientSecret() string {
return c.GetWithDefault(KeyGoogleClientSecret, "")
}
// GoogleRedirectURL returns the Google OAuth redirect URL
func (c *Config) GoogleRedirectURL() string {
return c.GetWithDefault(KeyGoogleRedirectURL, "")
}
// DiscordClientID returns the Discord OAuth client ID
func (c *Config) DiscordClientID() string {
return c.GetWithDefault(KeyDiscordClientID, "")
}
// DiscordClientSecret returns the Discord OAuth client secret
func (c *Config) DiscordClientSecret() string {
return c.GetWithDefault(KeyDiscordClientSecret, "")
}
// DiscordRedirectURL returns the Discord OAuth redirect URL
func (c *Config) DiscordRedirectURL() string {
return c.GetWithDefault(KeyDiscordRedirectURL, "")
}
// IsLocalAuthEnabled returns whether local authentication is enabled
func (c *Config) IsLocalAuthEnabled() bool {
return c.GetBoolWithDefault(KeyEnableLocalAuth, false)
}
// FixedAPIKey returns the fixed API key for testing
func (c *Config) FixedAPIKey() string {
return c.GetWithDefault(KeyFixedAPIKey, "")
}
// IsProductionMode returns whether production mode is enabled
func (c *Config) IsProductionMode() bool {
return c.GetBoolWithDefault(KeyProductionMode, false)
}
// AllowedOrigins returns the allowed CORS origins
func (c *Config) AllowedOrigins() string {
return c.GetWithDefault(KeyAllowedOrigins, "")
}
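
Usage from handlers is symmetric: write with the Set helpers, read back through the typed convenience methods. A sketch (cfg assumed in scope):

	// Sketch: toggling and reading a boolean setting via the convenience layer.
	if err := cfg.SetBool(config.KeyProductionMode, true); err != nil {
		log.Printf("failed to enable production mode: %v", err)
	}
	if cfg.IsProductionMode() {
		log.Println("production mode is on")
	}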

View File

@@ -4,18 +4,20 @@ import (
"database/sql"
"fmt"
"log"
"sync"
_ "github.com/marcboeker/go-duckdb/v2"
_ "github.com/mattn/go-sqlite3"
)
// DB wraps the database connection
// DB wraps the database connection with mutex protection
type DB struct {
*sql.DB
db *sql.DB
mu sync.Mutex
}
// NewDB creates a new database connection
func NewDB(dbPath string) (*DB, error) {
db, err := sql.Open("duckdb", dbPath)
db, err := sql.Open("sqlite3", dbPath)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
@@ -24,7 +26,12 @@ func NewDB(dbPath string) (*DB, error) {
return nil, fmt.Errorf("failed to ping database: %w", err)
}
database := &DB{DB: db}
// Enable foreign keys for SQLite
if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
return nil, fmt.Errorf("failed to enable foreign keys: %w", err)
}
database := &DB{db: db}
if err := database.migrate(); err != nil {
return nil, fmt.Errorf("failed to migrate database: %w", err)
}
@@ -32,58 +39,74 @@ func NewDB(dbPath string) (*DB, error) {
return database, nil
}
// With executes a function with mutex-protected access to the database
// The function receives the underlying *sql.DB connection
func (db *DB) With(fn func(*sql.DB) error) error {
db.mu.Lock()
defer db.mu.Unlock()
return fn(db.db)
}
// WithTx executes a function within a transaction with mutex protection
// The function receives a *sql.Tx transaction
// If the function returns an error, the transaction is rolled back
// If the function returns nil, the transaction is committed
func (db *DB) WithTx(fn func(*sql.Tx) error) error {
db.mu.Lock()
defer db.mu.Unlock()
tx, err := db.db.Begin()
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
if err := fn(tx); err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
return fmt.Errorf("transaction error: %w, rollback error: %v", err, rbErr)
}
return err
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}
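
Call sites choose between the two depending on whether they need atomicity; a sketch of both patterns (table and column names taken from the schema below, jobID assumed in scope):

	// Single-statement read: With serializes access through the mutex.
	var jobCount int
	if err := db.With(func(conn *sql.DB) error {
		return conn.QueryRow("SELECT COUNT(*) FROM jobs").Scan(&jobCount)
	}); err != nil {
		log.Printf("count query failed: %v", err)
	}

	// Multi-statement write: WithTx commits on nil and rolls back on error.
	if err := db.WithTx(func(tx *sql.Tx) error {
		if _, err := tx.Exec("UPDATE tasks SET status = 'pending', runner_id = NULL WHERE job_id = ?", jobID); err != nil {
			return err
		}
		_, err := tx.Exec("UPDATE jobs SET status = 'pending' WHERE id = ?", jobID)
		return err
	}); err != nil {
		log.Printf("requeue transaction failed: %v", err)
	}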
// migrate runs database migrations
func (db *DB) migrate() error {
// Create sequences for auto-incrementing primary keys
sequences := []string{
`CREATE SEQUENCE IF NOT EXISTS seq_users_id START 1`,
`CREATE SEQUENCE IF NOT EXISTS seq_jobs_id START 1`,
`CREATE SEQUENCE IF NOT EXISTS seq_runners_id START 1`,
`CREATE SEQUENCE IF NOT EXISTS seq_tasks_id START 1`,
`CREATE SEQUENCE IF NOT EXISTS seq_job_files_id START 1`,
`CREATE SEQUENCE IF NOT EXISTS seq_manager_secrets_id START 1`,
`CREATE SEQUENCE IF NOT EXISTS seq_registration_tokens_id START 1`,
`CREATE SEQUENCE IF NOT EXISTS seq_runner_api_keys_id START 1`,
`CREATE SEQUENCE IF NOT EXISTS seq_task_logs_id START 1`,
`CREATE SEQUENCE IF NOT EXISTS seq_task_steps_id START 1`,
}
for _, seq := range sequences {
if _, err := db.Exec(seq); err != nil {
return fmt.Errorf("failed to create sequence: %w", err)
}
}
// SQLite uses INTEGER PRIMARY KEY AUTOINCREMENT instead of sequences
schema := `
CREATE TABLE IF NOT EXISTS users (
id BIGINT PRIMARY KEY DEFAULT nextval('seq_users_id'),
id INTEGER PRIMARY KEY AUTOINCREMENT,
email TEXT UNIQUE NOT NULL,
name TEXT NOT NULL,
oauth_provider TEXT NOT NULL,
oauth_id TEXT NOT NULL,
password_hash TEXT,
is_admin BOOLEAN NOT NULL DEFAULT false,
is_admin INTEGER NOT NULL DEFAULT 0,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
UNIQUE(oauth_provider, oauth_id)
);
CREATE TABLE IF NOT EXISTS runner_api_keys (
id BIGINT PRIMARY KEY DEFAULT nextval('seq_runner_api_keys_id'),
key_prefix TEXT NOT NULL, -- First part of API key (e.g., "jk_r1")
key_hash TEXT NOT NULL, -- SHA256 hash of full API key
name TEXT NOT NULL, -- Human-readable name
description TEXT, -- Optional description
scope TEXT NOT NULL DEFAULT 'user', -- 'manager' or 'user' - manager scope allows all jobs, user scope only allows jobs from key owner
is_active BOOLEAN NOT NULL DEFAULT true,
id INTEGER PRIMARY KEY AUTOINCREMENT,
key_prefix TEXT NOT NULL,
key_hash TEXT NOT NULL,
name TEXT NOT NULL,
description TEXT,
scope TEXT NOT NULL DEFAULT 'user',
is_active INTEGER NOT NULL DEFAULT 1,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
created_by BIGINT,
created_by INTEGER,
FOREIGN KEY (created_by) REFERENCES users(id),
UNIQUE(key_prefix)
);
CREATE TABLE IF NOT EXISTS jobs (
id BIGINT PRIMARY KEY DEFAULT nextval('seq_jobs_id'),
user_id BIGINT NOT NULL,
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
job_type TEXT NOT NULL DEFAULT 'render',
name TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
@@ -91,9 +114,11 @@ func (db *DB) migrate() error {
frame_start INTEGER,
frame_end INTEGER,
output_format TEXT,
allow_parallel_runners BOOLEAN NOT NULL DEFAULT true,
allow_parallel_runners INTEGER NOT NULL DEFAULT 1,
timeout_seconds INTEGER DEFAULT 86400,
blend_metadata TEXT,
retry_count INTEGER NOT NULL DEFAULT 0,
max_retries INTEGER NOT NULL DEFAULT 3,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
started_at TIMESTAMP,
completed_at TIMESTAMP,
@@ -102,25 +127,25 @@ func (db *DB) migrate() error {
);
CREATE TABLE IF NOT EXISTS runners (
id BIGINT PRIMARY KEY DEFAULT nextval('seq_runners_id'),
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
hostname TEXT NOT NULL,
ip_address TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'offline',
last_heartbeat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
capabilities TEXT,
api_key_id BIGINT, -- Reference to the API key used for this runner
api_key_scope TEXT NOT NULL DEFAULT 'user', -- Scope of the API key ('manager' or 'user')
api_key_id INTEGER,
api_key_scope TEXT NOT NULL DEFAULT 'user',
priority INTEGER NOT NULL DEFAULT 100,
fingerprint TEXT, -- Hardware fingerprint (NULL for fixed API keys)
fingerprint TEXT,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (api_key_id) REFERENCES runner_api_keys(id)
);
CREATE TABLE IF NOT EXISTS tasks (
id BIGINT PRIMARY KEY DEFAULT nextval('seq_tasks_id'),
job_id BIGINT NOT NULL,
runner_id BIGINT,
id INTEGER PRIMARY KEY AUTOINCREMENT,
job_id INTEGER NOT NULL,
runner_id INTEGER,
frame_start INTEGER NOT NULL,
frame_end INTEGER NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
@@ -133,44 +158,50 @@ func (db *DB) migrate() error {
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
started_at TIMESTAMP,
completed_at TIMESTAMP,
error_message TEXT
error_message TEXT,
FOREIGN KEY (job_id) REFERENCES jobs(id),
FOREIGN KEY (runner_id) REFERENCES runners(id)
);
CREATE TABLE IF NOT EXISTS job_files (
id BIGINT PRIMARY KEY DEFAULT nextval('seq_job_files_id'),
job_id BIGINT NOT NULL,
id INTEGER PRIMARY KEY AUTOINCREMENT,
job_id INTEGER NOT NULL,
file_type TEXT NOT NULL,
file_path TEXT NOT NULL,
file_name TEXT NOT NULL,
file_size BIGINT NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
file_size INTEGER NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (job_id) REFERENCES jobs(id)
);
CREATE TABLE IF NOT EXISTS manager_secrets (
id BIGINT PRIMARY KEY DEFAULT nextval('seq_manager_secrets_id'),
id INTEGER PRIMARY KEY AUTOINCREMENT,
secret TEXT UNIQUE NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS task_logs (
id BIGINT PRIMARY KEY DEFAULT nextval('seq_task_logs_id'),
task_id BIGINT NOT NULL,
runner_id BIGINT,
id INTEGER PRIMARY KEY AUTOINCREMENT,
task_id INTEGER NOT NULL,
runner_id INTEGER,
log_level TEXT NOT NULL,
message TEXT NOT NULL,
step_name TEXT,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (task_id) REFERENCES tasks(id),
FOREIGN KEY (runner_id) REFERENCES runners(id)
);
CREATE TABLE IF NOT EXISTS task_steps (
id BIGINT PRIMARY KEY DEFAULT nextval('seq_task_steps_id'),
task_id BIGINT NOT NULL,
id INTEGER PRIMARY KEY AUTOINCREMENT,
task_id INTEGER NOT NULL,
step_name TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
started_at TIMESTAMP,
completed_at TIMESTAMP,
duration_ms INTEGER,
error_message TEXT
error_message TEXT,
FOREIGN KEY (task_id) REFERENCES tasks(id)
);
CREATE INDEX IF NOT EXISTS idx_jobs_user_id ON jobs(user_id);
@@ -184,6 +215,7 @@ func (db *DB) migrate() error {
CREATE INDEX IF NOT EXISTS idx_job_files_job_id ON job_files(job_id);
CREATE INDEX IF NOT EXISTS idx_runner_api_keys_prefix ON runner_api_keys(key_prefix);
CREATE INDEX IF NOT EXISTS idx_runner_api_keys_active ON runner_api_keys(is_active);
CREATE INDEX IF NOT EXISTS idx_runner_api_keys_created_by ON runner_api_keys(created_by);
CREATE INDEX IF NOT EXISTS idx_runners_api_key_id ON runners(api_key_id);
CREATE INDEX IF NOT EXISTS idx_task_logs_task_id_created_at ON task_logs(task_id, created_at);
CREATE INDEX IF NOT EXISTS idx_task_logs_task_id_id ON task_logs(task_id, id DESC);
@@ -196,9 +228,28 @@ func (db *DB) migrate() error {
value TEXT NOT NULL,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id TEXT UNIQUE NOT NULL,
user_id INTEGER NOT NULL,
email TEXT NOT NULL,
name TEXT NOT NULL,
is_admin INTEGER NOT NULL DEFAULT 0,
expires_at TIMESTAMP NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users(id)
);
CREATE INDEX IF NOT EXISTS idx_sessions_session_id ON sessions(session_id);
CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id);
CREATE INDEX IF NOT EXISTS idx_sessions_expires_at ON sessions(expires_at);
`
if _, err := db.Exec(schema); err != nil {
if err := db.With(func(conn *sql.DB) error {
_, err := conn.Exec(schema)
return err
}); err != nil {
return fmt.Errorf("failed to create schema: %w", err)
}
@@ -212,19 +263,26 @@ func (db *DB) migrate() error {
}
for _, migration := range migrations {
// SQLite has no IF NOT EXISTS for ALTER TABLE ADD COLUMN, so failures are logged and ignored
if _, err := db.Exec(migration); err != nil {
if err := db.With(func(conn *sql.DB) error {
_, err := conn.Exec(migration)
return err
}); err != nil {
// Log but don't fail - column might already exist or table might not exist yet
// This is fine for migrations that run after schema creation
// For the file_size migration, if it fails (e.g., already BIGINT), that's fine
log.Printf("Migration warning: %v", err)
}
}
// Initialize registration_enabled setting (default: true) if it doesn't exist
var settingCount int
err := db.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount)
err := db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount)
})
if err == nil && settingCount == 0 {
_, err = db.Exec("INSERT INTO settings (key, value) VALUES (?, ?)", "registration_enabled", "true")
err = db.With(func(conn *sql.DB) error {
_, err := conn.Exec("INSERT INTO settings (key, value) VALUES (?, ?)", "registration_enabled", "true")
return err
})
if err != nil {
// Log but don't fail - setting might have been created by another process
log.Printf("Note: Could not initialize registration_enabled setting: %v", err)
@@ -234,7 +292,16 @@ func (db *DB) migrate() error {
return nil
}
// Ping checks the database connection
func (db *DB) Ping() error {
db.mu.Lock()
defer db.mu.Unlock()
return db.db.Ping()
}
// Close closes the database connection
func (db *DB) Close() error {
return db.DB.Close()
db.mu.Lock()
defer db.mu.Unlock()
return db.db.Close()
}

View File

@@ -1,31 +1,87 @@
package logger
import (
"fmt"
"io"
"log"
"os"
"path/filepath"
"sync"
"gopkg.in/natefinch/lumberjack.v2"
)
// Level represents log severity
type Level int
const (
LevelDebug Level = iota
LevelInfo
LevelWarn
LevelError
)
var levelNames = map[Level]string{
LevelDebug: "DEBUG",
LevelInfo: "INFO",
LevelWarn: "WARN",
LevelError: "ERROR",
}
// ParseLevel parses a level string into a Level
func ParseLevel(s string) Level {
switch s {
case "debug", "DEBUG":
return LevelDebug
case "info", "INFO":
return LevelInfo
case "warn", "WARN", "warning", "WARNING":
return LevelWarn
case "error", "ERROR":
return LevelError
default:
return LevelInfo
}
}
var (
defaultLogger *Logger
once sync.Once
currentLevel Level = LevelInfo
)
// Logger wraps the standard log.Logger with file and stdout output
// Logger wraps the standard log.Logger with optional file output and levels
type Logger struct {
*log.Logger
fileWriter io.WriteCloser
}
// Init initializes the default logger with both file and stdout output
func Init(logDir, logFileName string, maxSizeMB int, maxBackups int, maxAgeDays int) error {
// SetLevel sets the global log level
func SetLevel(level Level) {
currentLevel = level
}
// GetLevel returns the current log level
func GetLevel() Level {
return currentLevel
}
// InitStdout initializes the logger to only write to stdout
func InitStdout() {
once.Do(func() {
log.SetOutput(os.Stdout)
log.SetFlags(log.LstdFlags | log.Lshortfile)
defaultLogger = &Logger{
Logger: log.Default(),
fileWriter: nil,
}
})
}
// InitWithFile initializes the logger with both file and stdout output
// The file is truncated on each start
func InitWithFile(logPath string) error {
var err error
once.Do(func() {
defaultLogger, err = New(logDir, logFileName, maxSizeMB, maxBackups, maxAgeDays)
defaultLogger, err = NewWithFile(logPath)
if err != nil {
return
}
@@ -37,22 +93,19 @@ func Init(logDir, logFileName string, maxSizeMB int, maxBackups int, maxAgeDays
return err
}
// New creates a new logger that writes to both stdout and a log file
func New(logDir, logFileName string, maxSizeMB int, maxBackups int, maxAgeDays int) (*Logger, error) {
// NewWithFile creates a new logger that writes to both stdout and a log file
// The file is truncated on each start
func NewWithFile(logPath string) (*Logger, error) {
// Ensure log directory exists
logDir := filepath.Dir(logPath)
if err := os.MkdirAll(logDir, 0755); err != nil {
return nil, err
}
logPath := filepath.Join(logDir, logFileName)
// Create file writer with rotation
fileWriter := &lumberjack.Logger{
Filename: logPath,
MaxSize: maxSizeMB, // megabytes
MaxBackups: maxBackups, // number of backup files
MaxAge: maxAgeDays, // days
Compress: true, // compress old log files
// Create/truncate the log file
fileWriter, err := os.Create(logPath)
if err != nil {
return nil, err
}
// Create multi-writer that writes to both stdout and file
@@ -80,48 +133,91 @@ func GetDefault() *Logger {
return defaultLogger
}
// Printf logs a formatted message
func Printf(format string, v ...interface{}) {
if defaultLogger != nil {
defaultLogger.Printf(format, v...)
} else {
log.Printf(format, v...)
// logf logs a formatted message at the given level
func logf(level Level, format string, v ...interface{}) {
if level < currentLevel {
return
}
prefix := fmt.Sprintf("[%s] ", levelNames[level])
msg := fmt.Sprintf(format, v...)
log.Print(prefix + msg)
}
// Print logs a message
func Print(v ...interface{}) {
if defaultLogger != nil {
defaultLogger.Print(v...)
} else {
log.Print(v...)
// logln logs a message at the given level
func logln(level Level, v ...interface{}) {
if level < currentLevel {
return
}
prefix := fmt.Sprintf("[%s] ", levelNames[level])
msg := fmt.Sprint(v...)
log.Print(prefix + msg)
}
// Println logs a message with newline
func Println(v ...interface{}) {
if defaultLogger != nil {
defaultLogger.Println(v...)
} else {
log.Println(v...)
}
// Debug logs a debug message
func Debug(v ...interface{}) {
logln(LevelDebug, v...)
}
// Fatal logs a message and exits
// Debugf logs a formatted debug message
func Debugf(format string, v ...interface{}) {
logf(LevelDebug, format, v...)
}
// Info logs an info message
func Info(v ...interface{}) {
logln(LevelInfo, v...)
}
// Infof logs a formatted info message
func Infof(format string, v ...interface{}) {
logf(LevelInfo, format, v...)
}
// Warn logs a warning message
func Warn(v ...interface{}) {
logln(LevelWarn, v...)
}
// Warnf logs a formatted warning message
func Warnf(format string, v ...interface{}) {
logf(LevelWarn, format, v...)
}
// Error logs an error message
func Error(v ...interface{}) {
logln(LevelError, v...)
}
// Errorf logs a formatted error message
func Errorf(format string, v ...interface{}) {
logf(LevelError, format, v...)
}
// Fatal logs an error message and exits
func Fatal(v ...interface{}) {
if defaultLogger != nil {
defaultLogger.Fatal(v...)
} else {
log.Fatal(v...)
}
logln(LevelError, v...)
os.Exit(1)
}
// Fatalf logs a formatted message and exits
// Fatalf logs a formatted error message and exits
func Fatalf(format string, v ...interface{}) {
if defaultLogger != nil {
defaultLogger.Fatalf(format, v...)
} else {
log.Fatalf(format, v...)
}
logf(LevelError, format, v...)
os.Exit(1)
}
// --- Backwards compatibility (maps to Info level) ---
// Printf logs a formatted message at Info level
func Printf(format string, v ...interface{}) {
logf(LevelInfo, format, v...)
}
// Print logs a message at Info level
func Print(v ...interface{}) {
logln(LevelInfo, v...)
}
// Println logs a message at Info level
func Println(v ...interface{}) {
logln(LevelInfo, v...)
}
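
Typical usage from either binary, as a sketch (the log path and the LOG_LEVEL env var name are assumptions):

	// Sketch: initialize file+stdout logging, set the level, then log.
	if err := logger.InitWithFile("logs/manager.log"); err != nil {
		log.Fatalf("failed to init logger: %v", err)
	}
	logger.SetLevel(logger.ParseLevel(os.Getenv("LOG_LEVEL"))) // unknown values fall back to info
	logger.Infof("manager listening on %s", ":8080")
	logger.Debugf("debug output only appears when the level is debug")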

View File

@@ -25,6 +25,7 @@ import (
"sync"
"time"
"jiggablend/pkg/executils"
"jiggablend/pkg/scripts"
"jiggablend/pkg/types"
@@ -45,19 +46,19 @@ type Client struct {
stopChan chan struct{}
stepStartTimes map[string]time.Time // key: "taskID:stepName"
stepTimesMu sync.RWMutex
workspaceDir string // Persistent workspace directory for this runner
runningProcs sync.Map // map[int64]*exec.Cmd - tracks running processes by task ID
capabilities map[string]interface{} // Cached capabilities from initial probe (includes bools and numbers)
capabilitiesMu sync.RWMutex // Protects capabilities
hwAccelCache map[string]bool // Cached hardware acceleration detection results
hwAccelCacheMu sync.RWMutex // Protects hwAccelCache
vaapiDevices []string // Cached VAAPI device paths (all available devices)
vaapiDevicesMu sync.RWMutex // Protects vaapiDevices
allocatedDevices map[int64]string // map[taskID]device - tracks which device is allocated to which task
allocatedDevicesMu sync.RWMutex // Protects allocatedDevices
longRunningClient *http.Client // HTTP client for long-running operations (no timeout)
fingerprint string // Unique hardware fingerprint for this runner
fingerprintMu sync.RWMutex // Protects fingerprint
workspaceDir string // Persistent workspace directory for this runner
processTracker *executils.ProcessTracker // Tracks running processes for cleanup
capabilities map[string]interface{} // Cached capabilities from initial probe (includes bools and numbers)
capabilitiesMu sync.RWMutex // Protects capabilities
hwAccelCache map[string]bool // Cached hardware acceleration detection results
hwAccelCacheMu sync.RWMutex // Protects hwAccelCache
vaapiDevices []string // Cached VAAPI device paths (all available devices)
vaapiDevicesMu sync.RWMutex // Protects vaapiDevices
allocatedDevices map[int64]string // map[taskID]device - tracks which device is allocated to which task
allocatedDevicesMu sync.RWMutex // Protects allocatedDevices
longRunningClient *http.Client // HTTP client for long-running operations (no timeout)
fingerprint string // Unique hardware fingerprint for this runner
fingerprintMu sync.RWMutex // Protects fingerprint
}
// NewClient creates a new runner client
@@ -70,6 +71,7 @@ func NewClient(managerURL, name, hostname string) *Client {
longRunningClient: &http.Client{Timeout: 0}, // No timeout for long-running operations (context downloads, file uploads/downloads)
stopChan: make(chan struct{}),
stepStartTimes: make(map[string]time.Time),
processTracker: executils.NewProcessTracker(),
}
// Generate fingerprint immediately
client.generateFingerprint()
@@ -226,12 +228,6 @@ func (c *Client) probeCapabilities() map[string]interface{} {
c.probeGPUCapabilities(capabilities)
} else {
capabilities["ffmpeg"] = false
// Set defaults when ffmpeg is not available
capabilities["vaapi"] = false
capabilities["vaapi_gpu_count"] = 0
capabilities["nvenc"] = false
capabilities["nvenc_gpu_count"] = 0
capabilities["video_gpu_count"] = 0
}
return capabilities
@@ -256,60 +252,6 @@ func (c *Client) probeGPUCapabilities(capabilities map[string]interface{}) {
log.Printf("Available hardware encoders: %v", getKeys(hwEncoders))
}
// Check for VAAPI devices and count them
log.Printf("Checking for VAAPI hardware acceleration...")
// First check if encoder is listed (more reliable than testing)
cmd := exec.Command("ffmpeg", "-hide_banner", "-encoders")
output, err := cmd.CombinedOutput()
hasVAAPIEncoder := false
if err == nil {
encoderOutput := string(output)
if strings.Contains(encoderOutput, "h264_vaapi") {
hasVAAPIEncoder = true
log.Printf("VAAPI encoder (h264_vaapi) found in ffmpeg encoders list")
}
}
if hasVAAPIEncoder {
// Try to find and test devices
vaapiDevices := c.findVAAPIDevices()
capabilities["vaapi_gpu_count"] = len(vaapiDevices)
if len(vaapiDevices) > 0 {
capabilities["vaapi"] = true
log.Printf("VAAPI detected: %d GPU device(s) available: %v", len(vaapiDevices), vaapiDevices)
} else {
capabilities["vaapi"] = false
log.Printf("VAAPI encoder available but no working devices found")
log.Printf(" This might indicate:")
log.Printf(" - Missing or incorrect GPU drivers")
log.Printf(" - Missing libva or mesa-va-drivers packages")
log.Printf(" - Permission issues accessing /dev/dri devices")
log.Printf(" - GPU not properly initialized")
}
} else {
capabilities["vaapi"] = false
capabilities["vaapi_gpu_count"] = 0
log.Printf("VAAPI encoder not available in ffmpeg")
log.Printf(" This might indicate:")
log.Printf(" - FFmpeg was not compiled with VAAPI support")
log.Printf(" - Missing libva development libraries during FFmpeg compilation")
}
// Check for NVENC (NVIDIA) - try to detect multiple GPUs
log.Printf("Checking for NVENC hardware acceleration...")
if c.checkEncoderAvailable("h264_nvenc") {
capabilities["nvenc"] = true
// Try to detect actual GPU count using nvidia-smi if available
nvencCount := c.detectNVENCCount()
capabilities["nvenc_gpu_count"] = nvencCount
log.Printf("NVENC detected: %d GPU(s)", nvencCount)
} else {
capabilities["nvenc"] = false
capabilities["nvenc_gpu_count"] = 0
log.Printf("NVENC encoder not available")
}
// Check for other hardware encoders (for completeness)
log.Printf("Checking for other hardware encoders...")
if c.checkEncoderAvailable("h264_qsv") {
@@ -368,73 +310,6 @@ func (c *Client) probeGPUCapabilities(capabilities map[string]interface{}) {
capabilities["mediacodec"] = false
capabilities["mediacodec_gpu_count"] = 0
}
// Calculate total GPU count for video encoding
// Priority: VAAPI > NVENC > QSV > VideoToolbox > AMF > others
vaapiCount := 0
if count, ok := capabilities["vaapi_gpu_count"].(int); ok {
vaapiCount = count
}
nvencCount := 0
if count, ok := capabilities["nvenc_gpu_count"].(int); ok {
nvencCount = count
}
qsvCount := 0
if count, ok := capabilities["qsv_gpu_count"].(int); ok {
qsvCount = count
}
videotoolboxCount := 0
if count, ok := capabilities["videotoolbox_gpu_count"].(int); ok {
videotoolboxCount = count
}
amfCount := 0
if count, ok := capabilities["amf_gpu_count"].(int); ok {
amfCount = count
}
// Total GPU count - use the best available (they can't be used simultaneously)
totalGPUs := vaapiCount
if totalGPUs == 0 {
totalGPUs = nvencCount
}
if totalGPUs == 0 {
totalGPUs = qsvCount
}
if totalGPUs == 0 {
totalGPUs = videotoolboxCount
}
if totalGPUs == 0 {
totalGPUs = amfCount
}
capabilities["video_gpu_count"] = totalGPUs
if totalGPUs > 0 {
log.Printf("Total video GPU count: %d", totalGPUs)
} else {
log.Printf("No hardware-accelerated video encoding GPUs detected (will use software encoding)")
}
}
// detectNVENCCount tries to detect the actual number of NVIDIA GPUs using nvidia-smi
func (c *Client) detectNVENCCount() int {
// Try to use nvidia-smi to count GPUs
cmd := exec.Command("nvidia-smi", "--list-gpus")
output, err := cmd.CombinedOutput()
if err == nil {
// Count lines that contain "GPU" (each GPU is listed on a separate line)
lines := strings.Split(string(output), "\n")
count := 0
for _, line := range lines {
if strings.Contains(line, "GPU") {
count++
}
}
if count > 0 {
return count
}
}
// Fallback to 1 if nvidia-smi is not available
return 1
}
// getKeys returns all keys from a map as a slice (helper function)
@@ -926,29 +801,13 @@ func (c *Client) sendLog(taskID int64, logLevel types.LogLevel, message, stepNam
// KillAllProcesses kills all running processes tracked by this client
func (c *Client) KillAllProcesses() {
log.Printf("Killing all running processes...")
var killedCount int
c.runningProcs.Range(func(key, value interface{}) bool {
taskID := key.(int64)
cmd := value.(*exec.Cmd)
if cmd.Process != nil {
log.Printf("Killing process for task %d (PID: %d)", taskID, cmd.Process.Pid)
// Try graceful kill first (SIGTERM)
if err := cmd.Process.Signal(os.Interrupt); err != nil {
log.Printf("Failed to send SIGINT to process %d: %v", cmd.Process.Pid, err)
}
// Give it a moment to clean up
time.Sleep(100 * time.Millisecond)
// Force kill if still running
if err := cmd.Process.Kill(); err != nil {
log.Printf("Failed to kill process %d: %v", cmd.Process.Pid, err)
} else {
killedCount++
}
}
// Release any allocated device for this task
c.releaseVAAPIDevice(taskID)
return true
})
killedCount := c.processTracker.KillAll()
// Release all allocated VAAPI devices
c.allocatedDevicesMu.Lock()
for taskID := range c.allocatedDevices {
delete(c.allocatedDevices, taskID)
}
c.allocatedDevicesMu.Unlock()
log.Printf("Killed %d process(es)", killedCount)
}
@@ -1272,55 +1131,33 @@ func (c *Client) processTask(task map[string]interface{}, jobName string, output
// Run Blender with GPU enabled via Python script
// Use -s (start) and -e (end) for frame ranges, or -f for single frame
// Use Blender's automatic frame numbering with #### pattern
var cmd *exec.Cmd
args := []string{"-b", blendFile, "--python", scriptPath}
if enableExecution {
args = append(args, "--enable-autoexec")
}
// Always render frames individually for precise control over file naming
// This avoids Blender's automatic frame numbering quirks
for frame := frameStart; frame <= frameEnd; frame++ {
// Create temp output pattern for this frame
tempPattern := filepath.Join(outputDir, fmt.Sprintf("temp_frame.%s", strings.ToLower(renderFormat)))
tempAbsPattern, _ := filepath.Abs(tempPattern)
// Build args for this specific frame
frameArgs := []string{"-b", blendFile, "--python", scriptPath}
if enableExecution {
frameArgs = append(frameArgs, "--enable-autoexec")
}
frameArgs = append(frameArgs, "-o", tempAbsPattern, "-f", fmt.Sprintf("%d", frame))
// Output pattern uses #### which Blender will replace with frame numbers
outputPattern := filepath.Join(outputDir, fmt.Sprintf("frame_####.%s", strings.ToLower(renderFormat)))
outputAbsPattern, _ := filepath.Abs(outputPattern)
args = append(args, "-o", outputAbsPattern)
c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Rendering frame %d...", frame), "render_blender")
frameCmd := exec.Command("blender", frameArgs...)
frameCmd.Dir = workDir
frameCmd.Env = os.Environ()
// Run this frame
if output, err := frameCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("blender failed on frame %d: %v (output: %s)", frame, err, string(output))
c.sendLog(taskID, types.LogLevelError, errMsg, "render_blender")
return errors.New(errMsg)
}
// Immediately rename the temp file to the proper frame-numbered name
finalName := fmt.Sprintf("frame_%04d.%s", frame, strings.ToLower(renderFormat))
finalPath := filepath.Join(outputDir, finalName)
tempPath := filepath.Join(outputDir, fmt.Sprintf("temp_frame.%s", strings.ToLower(renderFormat)))
if err := os.Rename(tempPath, finalPath); err != nil {
errMsg := fmt.Sprintf("failed to rename temp file for frame %d: %v", frame, err)
c.sendLog(taskID, types.LogLevelError, errMsg, "render_blender")
return errors.New(errMsg)
}
c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Completed frame %d -> %s", frame, finalName), "render_blender")
if frameStart == frameEnd {
// Single frame
args = append(args, "-f", fmt.Sprintf("%d", frameStart))
c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Rendering frame %d...", frameStart), "render_blender")
} else {
// Frame range
args = append(args, "-s", fmt.Sprintf("%d", frameStart), "-e", fmt.Sprintf("%d", frameEnd))
c.sendLog(taskID, types.LogLevelInfo, fmt.Sprintf("Rendering frames %d-%d...", frameStart, frameEnd), "render_blender")
}
// Skip the rest of the function since we handled all frames above
c.sendStepUpdate(taskID, "render_blender", types.StepStatusCompleted, "")
return nil
// Create and run Blender command
cmd = exec.Command("blender", args...)
cmd.Dir = workDir
cmd.Env = os.Environ()
// Blender will handle headless rendering automatically
// We preserve the environment to allow GPU access if available
@@ -1350,8 +1187,8 @@ func (c *Client) processTask(task map[string]interface{}, jobName string, output
}
// Register process for cleanup on shutdown
c.runningProcs.Store(taskID, cmd)
defer c.runningProcs.Delete(taskID)
c.processTracker.Track(taskID, cmd)
defer c.processTracker.Untrack(taskID)
// Stream stdout line by line
stdoutDone := make(chan bool)
@@ -1396,15 +1233,23 @@ func (c *Client) processTask(task map[string]interface{}, jobName string, output
<-stdoutDone
<-stderrDone
if err != nil {
errMsg := fmt.Sprintf("blender failed: %v", err)
var errMsg string
if exitErr, ok := err.(*exec.ExitError); ok {
if exitErr.ExitCode() == 137 {
errMsg = "Blender was killed due to excessive memory usage (OOM)"
} else {
errMsg = fmt.Sprintf("blender failed: %v", err)
}
} else {
errMsg = fmt.Sprintf("blender failed: %v", err)
}
c.sendLog(taskID, types.LogLevelError, errMsg, "render_blender")
c.sendStepUpdate(taskID, "render_blender", types.StepStatusFailed, errMsg)
return errors.New(errMsg)
}
// For frame ranges, we rendered each frame individually with temp naming
// The files are already properly named during the individual frame rendering
// No additional renaming needed
// Blender has rendered frames with automatic numbering using the #### pattern
// Files will be named like frame_0001.png, frame_0002.png, etc.
// Find rendered output file(s)
// For frame ranges, we'll find all frames in the upload step
@@ -1748,8 +1593,8 @@ func (c *Client) processVideoGenerationTask(task map[string]interface{}, jobID i
// Extract frame number pattern (e.g., frame_2470.exr -> frame_%04d.exr)
baseName := filepath.Base(firstFrame)
// Find the numeric part and replace it with %04d pattern
// Use regex to find digits (including negative) after underscore and before extension
re := regexp.MustCompile(`_(-?\d+)\.`)
// Use regex to find digits (positive only, negative frames not supported) after underscore and before extension
re := regexp.MustCompile(`_(\d+)\.`)
var pattern string
var startNumber int
frameNumStr := re.FindStringSubmatch(baseName)
@@ -1763,6 +1608,7 @@ func (c *Client) processVideoGenerationTask(task map[string]interface{}, jobID i
startNumber = extractFrameNumber(baseName)
pattern = strings.Replace(baseName, fmt.Sprintf("%d", startNumber), "%04d", 1)
}
// Pattern path should be in workDir where the frame files are downloaded
patternPath := filepath.Join(workDir, pattern)
// Allocate a VAAPI device for this task (if available)
@@ -1891,8 +1737,8 @@ func (c *Client) processVideoGenerationTask(task map[string]interface{}, jobID i
}
// Register process for cleanup on shutdown
c.runningProcs.Store(taskID, cmd)
defer c.runningProcs.Delete(taskID)
c.processTracker.Track(taskID, cmd)
defer c.processTracker.Untrack(taskID)
// Stream stdout line by line
stdoutDone := make(chan bool)
@@ -1959,26 +1805,25 @@ func (c *Client) processVideoGenerationTask(task map[string]interface{}, jobID i
<-stderrDone
if err != nil {
var errMsg string
if exitErr, ok := err.(*exec.ExitError); ok {
if exitErr.ExitCode() == 137 {
errMsg = "FFmpeg was killed due to excessive memory usage (OOM)"
} else {
errMsg = fmt.Sprintf("ffmpeg encoding failed: %v", err)
}
} else {
errMsg = fmt.Sprintf("ffmpeg encoding failed: %v", err)
}
// Check for size-related errors and provide helpful messages
if sizeErr := c.checkFFmpegSizeError("ffmpeg encoding failed"); sizeErr != nil {
if sizeErr := c.checkFFmpegSizeError(errMsg); sizeErr != nil {
c.sendLog(taskID, types.LogLevelError, sizeErr.Error(), "generate_video")
c.sendStepUpdate(taskID, "generate_video", types.StepStatusFailed, sizeErr.Error())
return sizeErr
}
// Try alternative method with concat demuxer
c.sendLog(taskID, types.LogLevelWarn, "Primary ffmpeg encoding failed, trying concat method...", "generate_video")
err = c.generateMP4WithConcat(frameFiles, outputMP4, workDir, allocatedDevice, outputFormat, codec, pixFmt, useAlpha, useHardware, frameRate)
if err != nil {
// Check for size errors in concat method too
if sizeErr := c.checkFFmpegSizeError(err.Error()); sizeErr != nil {
c.sendLog(taskID, types.LogLevelError, sizeErr.Error(), "generate_video")
c.sendStepUpdate(taskID, "generate_video", types.StepStatusFailed, sizeErr.Error())
return sizeErr
}
c.sendStepUpdate(taskID, "generate_video", types.StepStatusFailed, err.Error())
return err
}
c.sendLog(taskID, types.LogLevelError, errMsg, "generate_video")
c.sendStepUpdate(taskID, "generate_video", types.StepStatusFailed, errMsg)
return errors.New(errMsg)
}
// Check if MP4 was created
@@ -2771,7 +2616,7 @@ func (c *Client) testGenericEncoder(encoder string) bool {
// generateMP4WithConcat uses ffmpeg concat demuxer as fallback
// device parameter is optional - if provided, it will be used for VAAPI encoding
func (c *Client) generateMP4WithConcat(frameFiles []string, outputMP4, workDir string, device string, outputFormat string, codec string, pixFmt string, useAlpha bool, useHardware bool, frameRate float64) error {
func (c *Client) generateMP4WithConcat(taskID int, frameFiles []string, outputMP4, workDir string, device string, outputFormat string, codec string, pixFmt string, useAlpha bool, useHardware bool, frameRate float64) error {
// Create file list for ffmpeg concat demuxer
listFile := filepath.Join(workDir, "frames.txt")
listFileHandle, err := os.Create(listFile)
@@ -2907,11 +2752,23 @@ func (c *Client) generateMP4WithConcat(frameFiles []string, outputMP4, workDir s
<-stderrDone
if err != nil {
var errMsg string
if exitErr, ok := err.(*exec.ExitError); ok {
if exitErr.ExitCode() == 137 {
errMsg = "FFmpeg was killed due to excessive memory usage (OOM)"
} else {
errMsg = fmt.Sprintf("ffmpeg concat failed: %v", err)
}
} else {
errMsg = fmt.Sprintf("ffmpeg concat failed: %v", err)
}
// Check for size-related errors
if sizeErr := c.checkFFmpegSizeError("ffmpeg concat failed"); sizeErr != nil {
if sizeErr := c.checkFFmpegSizeError(errMsg); sizeErr != nil {
return sizeErr
}
return fmt.Errorf("ffmpeg concat failed: %w", err)
c.sendLog(int64(taskID), types.LogLevelError, errMsg, "generate_video")
c.sendStepUpdate(int64(taskID), "generate_video", types.StepStatusFailed, errMsg)
return errors.New(errMsg)
}
if _, err := os.Stat(outputMP4); os.IsNotExist(err) {
@@ -3695,8 +3552,8 @@ sys.stdout.flush()
}
// Register process for cleanup on shutdown
c.runningProcs.Store(taskID, cmd)
defer c.runningProcs.Delete(taskID)
c.processTracker.Track(taskID, cmd)
defer c.processTracker.Untrack(taskID)
// Stream stdout line by line and collect for JSON parsing
stdoutDone := make(chan bool)
@@ -3743,7 +3600,16 @@ sys.stdout.flush()
<-stdoutDone
<-stderrDone
if err != nil {
errMsg := fmt.Sprintf("blender metadata extraction failed: %v", err)
var errMsg string
if exitErr, ok := err.(*exec.ExitError); ok {
if exitErr.ExitCode() == 137 {
errMsg = "Blender metadata extraction was killed due to excessive memory usage (OOM)"
} else {
errMsg = fmt.Sprintf("blender metadata extraction failed: %v", err)
}
} else {
errMsg = fmt.Sprintf("blender metadata extraction failed: %v", err)
}
c.sendLog(taskID, types.LogLevelError, errMsg, "extract_metadata")
c.sendStepUpdate(taskID, "extract_metadata", types.StepStatusFailed, errMsg)
return errors.New(errMsg)

View File

@@ -378,43 +378,47 @@ func (s *Storage) CreateJobContext(jobID int64) (string, error) {
return "", fmt.Errorf("failed to open file %s: %w", filePath, err)
}
info, err := file.Stat()
// Use a function closure to ensure file is closed even on error
err = func() error {
defer file.Close()
info, err := file.Stat()
if err != nil {
return fmt.Errorf("failed to stat file %s: %w", filePath, err)
}
// Get relative path for tar header
relPath, err := filepath.Rel(jobPath, filePath)
if err != nil {
return fmt.Errorf("failed to get relative path for %s: %w", filePath, err)
}
// Normalize path separators for tar (use forward slashes)
tarPath := filepath.ToSlash(relPath)
// Create tar header
header, err := tar.FileInfoHeader(info, "")
if err != nil {
return fmt.Errorf("failed to create tar header for %s: %w", filePath, err)
}
header.Name = tarPath
// Write header
if err := tarWriter.WriteHeader(header); err != nil {
return fmt.Errorf("failed to write tar header for %s: %w", filePath, err)
}
// Copy file contents using streaming
if _, err := io.Copy(tarWriter, file); err != nil {
return fmt.Errorf("failed to write file %s to tar: %w", filePath, err)
}
return nil
}()
if err != nil {
file.Close()
return "", fmt.Errorf("failed to stat file %s: %w", filePath, err)
return "", err
}
// Get relative path for tar header
relPath, err := filepath.Rel(jobPath, filePath)
if err != nil {
file.Close()
return "", fmt.Errorf("failed to get relative path for %s: %w", filePath, err)
}
// Normalize path separators for tar (use forward slashes)
tarPath := filepath.ToSlash(relPath)
// Create tar header
header, err := tar.FileInfoHeader(info, "")
if err != nil {
file.Close()
return "", fmt.Errorf("failed to create tar header for %s: %w", filePath, err)
}
header.Name = tarPath
// Write header
if err := tarWriter.WriteHeader(header); err != nil {
file.Close()
return "", fmt.Errorf("failed to write tar header for %s: %w", filePath, err)
}
// Copy file contents using streaming
if _, err := io.Copy(tarWriter, file); err != nil {
file.Close()
return "", fmt.Errorf("failed to write file %s to tar: %w", filePath, err)
}
file.Close()
}
// Ensure all data is flushed
@@ -550,42 +554,47 @@ func (s *Storage) CreateJobContextFromDir(sourceDir string, jobID int64, exclude
return "", fmt.Errorf("failed to open file %s: %w", filePath, err)
}
info, err := file.Stat()
// Use a function closure to ensure file is closed even on error
err = func() error {
defer file.Close()
info, err := file.Stat()
if err != nil {
return fmt.Errorf("failed to stat file %s: %w", filePath, err)
}
// Get relative path and strip common prefix if present
relPath := relPaths[i]
tarPath := filepath.ToSlash(relPath)
// Strip common prefix if found
if commonPrefix != "" && strings.HasPrefix(tarPath, commonPrefix) {
tarPath = strings.TrimPrefix(tarPath, commonPrefix)
}
// Create tar header
header, err := tar.FileInfoHeader(info, "")
if err != nil {
return fmt.Errorf("failed to create tar header for %s: %w", filePath, err)
}
header.Name = tarPath
// Write header
if err := tarWriter.WriteHeader(header); err != nil {
return fmt.Errorf("failed to write tar header for %s: %w", filePath, err)
}
// Copy file contents using streaming
if _, err := io.Copy(tarWriter, file); err != nil {
return fmt.Errorf("failed to write file %s to tar: %w", filePath, err)
}
return nil
}()
if err != nil {
file.Close()
return "", fmt.Errorf("failed to stat file %s: %w", filePath, err)
return "", err
}
// Get relative path and strip common prefix if present
relPath := relPaths[i]
tarPath := filepath.ToSlash(relPath)
// Strip common prefix if found
if commonPrefix != "" && strings.HasPrefix(tarPath, commonPrefix) {
tarPath = strings.TrimPrefix(tarPath, commonPrefix)
}
// Create tar header
header, err := tar.FileInfoHeader(info, "")
if err != nil {
file.Close()
return "", fmt.Errorf("failed to create tar header for %s: %w", filePath, err)
}
header.Name = tarPath
// Write header
if err := tarWriter.WriteHeader(header); err != nil {
file.Close()
return "", fmt.Errorf("failed to write tar header for %s: %w", filePath, err)
}
// Copy file contents using streaming
if _, err := io.Copy(tarWriter, file); err != nil {
file.Close()
return "", fmt.Errorf("failed to write file %s to tar: %w", filePath, err)
}
file.Close()
}
// Ensure all data is flushed
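Both archive loops above rely on the same trick: the per-file work is wrapped in an anonymous function so defer file.Close() runs at the end of each iteration instead of piling up until the whole walk finishes. A standalone sketch of the pattern follows (the file name and output path are illustrative, not from the repository):

package main

import (
    "archive/tar"
    "fmt"
    "io"
    "os"
)

// addFileToTar defers Close inside an anonymous function so the handle is
// released on every early-return path of a single iteration.
func addFileToTar(tw *tar.Writer, path string) error {
    f, err := os.Open(path)
    if err != nil {
        return fmt.Errorf("failed to open %s: %w", path, err)
    }
    return func() error {
        defer f.Close()
        info, err := f.Stat()
        if err != nil {
            return fmt.Errorf("failed to stat %s: %w", path, err)
        }
        hdr, err := tar.FileInfoHeader(info, "")
        if err != nil {
            return fmt.Errorf("failed to build tar header for %s: %w", path, err)
        }
        hdr.Name = path
        if err := tw.WriteHeader(hdr); err != nil {
            return err
        }
        _, err = io.Copy(tw, f)
        return err
    }()
}

func main() {
    out, err := os.Create("context.tar")
    if err != nil {
        panic(err)
    }
    defer out.Close()
    tw := tar.NewWriter(out)
    defer tw.Close()
    if err := addFileToTar(tw, "scene.blend"); err != nil {
        fmt.Fprintln(os.Stderr, err)
    }
}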

BIN
jiggablend Executable file

Binary file not shown.

366
pkg/executils/exec.go Normal file
View File

@@ -0,0 +1,366 @@
package executils
import (
"bufio"
"errors"
"fmt"
"os"
"os/exec"
"sync"
"time"
"jiggablend/pkg/types"
)
// DefaultTracker is the global default process tracker
// Use this for processes that should be tracked globally and killed on shutdown
var DefaultTracker = NewProcessTracker()
// ProcessTracker tracks running processes for cleanup
type ProcessTracker struct {
processes sync.Map // map[int64]*exec.Cmd - tracks running processes by task ID
}
// NewProcessTracker creates a new process tracker
func NewProcessTracker() *ProcessTracker {
return &ProcessTracker{}
}
// Track registers a process for tracking
func (pt *ProcessTracker) Track(taskID int64, cmd *exec.Cmd) {
pt.processes.Store(taskID, cmd)
}
// Untrack removes a process from tracking
func (pt *ProcessTracker) Untrack(taskID int64) {
pt.processes.Delete(taskID)
}
// Get returns the command for a task ID if it exists
func (pt *ProcessTracker) Get(taskID int64) (*exec.Cmd, bool) {
if val, ok := pt.processes.Load(taskID); ok {
return val.(*exec.Cmd), true
}
return nil, false
}
// Kill kills a specific process by task ID
// Returns true if the process was found and killed
func (pt *ProcessTracker) Kill(taskID int64) bool {
cmd, ok := pt.Get(taskID)
if !ok || cmd.Process == nil {
return false
}
// Try graceful kill first (SIGINT)
if err := cmd.Process.Signal(os.Interrupt); err != nil {
// If SIGINT fails, try SIGKILL
cmd.Process.Kill()
} else {
// Give it a moment to clean up gracefully
time.Sleep(100 * time.Millisecond)
// Force kill; harmless if the process has already exited
cmd.Process.Kill()
}
pt.Untrack(taskID)
return true
}
// KillAll kills all tracked processes
// Returns the number of processes killed
func (pt *ProcessTracker) KillAll() int {
var killedCount int
pt.processes.Range(func(key, value interface{}) bool {
taskID := key.(int64)
cmd := value.(*exec.Cmd)
if cmd.Process != nil {
// Try graceful kill first (SIGINT)
if err := cmd.Process.Signal(os.Interrupt); err == nil {
// Give it a moment to clean up
time.Sleep(100 * time.Millisecond)
}
// Force kill
cmd.Process.Kill()
killedCount++
}
pt.processes.Delete(taskID)
return true
})
return killedCount
}
// Count returns the number of tracked processes
func (pt *ProcessTracker) Count() int {
count := 0
pt.processes.Range(func(key, value interface{}) bool {
count++
return true
})
return count
}
// CommandResult holds the output from a command execution
type CommandResult struct {
Stdout string
Stderr string
ExitCode int
}
// RunCommand executes a command and returns the output
// If tracker is provided, the process will be registered for tracking
// This is useful for commands where you need to capture output (like metadata extraction)
func RunCommand(
cmdPath string,
args []string,
dir string,
env []string,
taskID int64,
tracker *ProcessTracker,
) (*CommandResult, error) {
cmd := exec.Command(cmdPath, args...)
cmd.Dir = dir
if env != nil {
cmd.Env = env
}
stdoutPipe, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("failed to create stdout pipe: %w", err)
}
stderrPipe, err := cmd.StderrPipe()
if err != nil {
return nil, fmt.Errorf("failed to create stderr pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("failed to start command: %w", err)
}
// Track the process if tracker is provided
if tracker != nil {
tracker.Track(taskID, cmd)
defer tracker.Untrack(taskID)
}
// Collect stdout and stderr concurrently
var stdoutBuf, stderrBuf []byte
var stdoutErr, stderrErr error
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
stdoutBuf, stdoutErr = readAll(stdoutPipe)
}()
go func() {
defer wg.Done()
stderrBuf, stderrErr = readAll(stderrPipe)
}()
// Finish reading both pipes before calling Wait, since Wait closes them
wg.Wait()
waitErr := cmd.Wait()
// Check for read errors
if stdoutErr != nil {
return nil, fmt.Errorf("failed to read stdout: %w", stdoutErr)
}
if stderrErr != nil {
return nil, fmt.Errorf("failed to read stderr: %w", stderrErr)
}
result := &CommandResult{
Stdout: string(stdoutBuf),
Stderr: string(stderrBuf),
}
if waitErr != nil {
if exitErr, ok := waitErr.(*exec.ExitError); ok {
result.ExitCode = exitErr.ExitCode()
} else {
result.ExitCode = -1
}
return result, waitErr
}
result.ExitCode = 0
return result, nil
}
// readAll reads all data from a reader until EOF
func readAll(r io.Reader) ([]byte, error) {
return io.ReadAll(r)
}
// LogSender is a function type for sending logs
type LogSender func(taskID int, level types.LogLevel, message string, stepName string)
// LineFilter is a function that processes a line and returns whether to filter it out and the log level
type LineFilter func(line string) (shouldFilter bool, level types.LogLevel)
// RunCommandWithStreaming executes a command with streaming output and OOM detection
// If tracker is provided, the process will be registered for tracking
func RunCommandWithStreaming(
cmdPath string,
args []string,
dir string,
env []string,
taskID int,
stepName string,
logSender LogSender,
stdoutFilter LineFilter,
stderrFilter LineFilter,
oomMessage string,
tracker *ProcessTracker,
) error {
cmd := exec.Command(cmdPath, args...)
cmd.Dir = dir
cmd.Env = env
stdoutPipe, err := cmd.StdoutPipe()
if err != nil {
errMsg := fmt.Sprintf("failed to create stdout pipe: %v", err)
logSender(taskID, types.LogLevelError, errMsg, stepName)
return errors.New(errMsg)
}
stderrPipe, err := cmd.StderrPipe()
if err != nil {
errMsg := fmt.Sprintf("failed to create stderr pipe: %v", err)
logSender(taskID, types.LogLevelError, errMsg, stepName)
return errors.New(errMsg)
}
if err := cmd.Start(); err != nil {
errMsg := fmt.Sprintf("failed to start command: %v", err)
logSender(taskID, types.LogLevelError, errMsg, stepName)
return errors.New(errMsg)
}
// Track the process if tracker is provided
if tracker != nil {
tracker.Track(int64(taskID), cmd)
defer tracker.Untrack(int64(taskID))
}
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
scanner := bufio.NewScanner(stdoutPipe)
for scanner.Scan() {
line := scanner.Text()
if line != "" {
shouldFilter, level := stdoutFilter(line)
if !shouldFilter {
logSender(taskID, level, line, stepName)
}
}
}
}()
go func() {
defer wg.Done()
scanner := bufio.NewScanner(stderrPipe)
for scanner.Scan() {
line := scanner.Text()
if line != "" {
shouldFilter, level := stderrFilter(line)
if !shouldFilter {
logSender(taskID, level, line, stepName)
}
}
}
}()
// Let both scanner goroutines drain the pipes before Wait closes them
wg.Wait()
err = cmd.Wait()
if err != nil {
var errMsg string
if exitErr, ok := err.(*exec.ExitError); ok {
if exitErr.ExitCode() == 137 {
errMsg = oomMessage
} else {
errMsg = fmt.Sprintf("command failed: %v", err)
}
} else {
errMsg = fmt.Sprintf("command failed: %v", err)
}
logSender(taskID, types.LogLevelError, errMsg, stepName)
return errors.New(errMsg)
}
return nil
}
// ============================================================================
// Helper functions using DefaultTracker
// ============================================================================
// Run executes a command using the default tracker and returns the output
// This is a convenience wrapper around RunCommand that uses DefaultTracker
func Run(cmdPath string, args []string, dir string, env []string, taskID int64) (*CommandResult, error) {
return RunCommand(cmdPath, args, dir, env, taskID, DefaultTracker)
}
// RunStreaming executes a command with streaming output using the default tracker
// This is a convenience wrapper around RunCommandWithStreaming that uses DefaultTracker
func RunStreaming(
cmdPath string,
args []string,
dir string,
env []string,
taskID int,
stepName string,
logSender LogSender,
stdoutFilter LineFilter,
stderrFilter LineFilter,
oomMessage string,
) error {
return RunCommandWithStreaming(cmdPath, args, dir, env, taskID, stepName, logSender, stdoutFilter, stderrFilter, oomMessage, DefaultTracker)
}
// KillAll kills all processes tracked by the default tracker
// Returns the number of processes killed
func KillAll() int {
return DefaultTracker.KillAll()
}
// Kill kills a specific process by task ID using the default tracker
// Returns true if the process was found and killed
func Kill(taskID int64) bool {
return DefaultTracker.Kill(taskID)
}
// Track registers a process with the default tracker
func Track(taskID int64, cmd *exec.Cmd) {
DefaultTracker.Track(taskID, cmd)
}
// Untrack removes a process from the default tracker
func Untrack(taskID int64) {
DefaultTracker.Untrack(taskID)
}
// GetTrackedCount returns the number of processes tracked by the default tracker
func GetTrackedCount() int {
return DefaultTracker.Count()
}
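A minimal usage sketch of the new package (not code from the repository): a caller streams a Blender render through RunStreaming with pass-through log filters, then kills anything still tracked on shutdown. The task ID, step name, and argument values are placeholders.

package main

import (
    "fmt"
    "os"

    "jiggablend/pkg/executils"
    "jiggablend/pkg/types"
)

func main() {
    // Forward every surviving line to stdout; a real runner would send it to the manager.
    logSender := func(taskID int, level types.LogLevel, message string, stepName string) {
        fmt.Printf("[task %d][%s][%v] %s\n", taskID, stepName, level, message)
    }
    // Keep every non-empty line and log it at Info level.
    passAll := func(line string) (bool, types.LogLevel) { return false, types.LogLevelInfo }

    err := executils.RunStreaming(
        "blender",
        []string{"-b", "scene.blend", "-f", "1"},
        ".",              // working directory
        os.Environ(),     // inherit environment so GPU devices stay visible
        42,               // placeholder task ID used for tracking and log attribution
        "render_blender", // step name attached to every log line
        logSender,
        passAll, // stdout filter
        passAll, // stderr filter
        "Blender was killed due to excessive memory usage (OOM)",
    )
    if err != nil {
        fmt.Fprintln(os.Stderr, err)
    }

    // On shutdown, terminate whatever the default tracker still knows about.
    fmt.Printf("killed %d leftover process(es)\n", executils.KillAll())
}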

View File

@@ -46,6 +46,10 @@ scene = bpy.context.scene
frame_start = scene.frame_start
frame_end = scene.frame_end
# Check for negative frames (not supported)
has_negative_start = frame_start < 0
has_negative_end = frame_end < 0
# Also check for actual animation range (keyframes)
# Find the earliest and latest keyframes across all objects
animation_start = None
@@ -54,15 +58,21 @@ animation_end = None
for obj in scene.objects:
if obj.animation_data and obj.animation_data.action:
action = obj.animation_data.action
if action.fcurves:
for fcurve in action.fcurves:
if fcurve.keyframe_points:
for keyframe in fcurve.keyframe_points:
frame = int(keyframe.co[0])
if animation_start is None or frame < animation_start:
animation_start = frame
if animation_end is None or frame > animation_end:
animation_end = frame
# Check if action has fcurves attribute (varies by Blender version/context)
try:
fcurves = action.fcurves if hasattr(action, 'fcurves') else None
if fcurves:
for fcurve in fcurves:
if fcurve.keyframe_points:
for keyframe in fcurve.keyframe_points:
frame = int(keyframe.co[0])
if animation_start is None or frame < animation_start:
animation_start = frame
if animation_end is None or frame > animation_end:
animation_end = frame
except (AttributeError, TypeError) as e:
# Action doesn't have fcurves or fcurves is not iterable - skip this object
pass
# Use animation range if available, otherwise use scene frame range
# If scene range seems wrong (start == end), prefer animation range
@@ -72,6 +82,11 @@ if animation_start is not None and animation_end is not None:
frame_start = animation_start
frame_end = animation_end
# Check for negative frames (not supported)
has_negative_start = frame_start < 0
has_negative_end = frame_end < 0
has_negative_animation = (animation_start is not None and animation_start < 0) or (animation_end is not None and animation_end < 0)
# Extract render settings
render = scene.render
resolution_x = render.resolution_x
@@ -87,56 +102,230 @@ engine_settings = {}
if engine == 'CYCLES':
cycles = scene.cycles
# Get denoiser settings - in Blender 3.0+ it's on the view layer
denoiser = 'OPENIMAGEDENOISE' # Default
denoising_use_gpu = False
denoising_input_passes = 'RGB_ALBEDO_NORMAL' # Default: Albedo and Normal
denoising_prefilter = 'ACCURATE' # Default
denoising_quality = 'HIGH' # Default (for OpenImageDenoise)
try:
view_layer = bpy.context.view_layer
if hasattr(view_layer, 'cycles'):
vl_cycles = view_layer.cycles
denoiser = getattr(vl_cycles, 'denoiser', 'OPENIMAGEDENOISE')
denoising_use_gpu = getattr(vl_cycles, 'denoising_use_gpu', False)
denoising_input_passes = getattr(vl_cycles, 'denoising_input_passes', 'RGB_ALBEDO_NORMAL')
denoising_prefilter = getattr(vl_cycles, 'denoising_prefilter', 'ACCURATE')
# Quality is only for OpenImageDenoise in Blender 4.0+
denoising_quality = getattr(vl_cycles, 'denoising_quality', 'HIGH')
except:
pass
engine_settings = {
"samples": getattr(cycles, 'samples', 128),
# Sampling settings
"samples": getattr(cycles, 'samples', 4096), # Max Samples
"adaptive_min_samples": getattr(cycles, 'adaptive_min_samples', 0), # Min Samples
"use_adaptive_sampling": getattr(cycles, 'use_adaptive_sampling', True), # Noise Threshold enabled
"adaptive_threshold": getattr(cycles, 'adaptive_threshold', 0.01), # Noise Threshold value
"time_limit": getattr(cycles, 'time_limit', 0.0), # Time Limit (0 = disabled)
# Denoising settings
"use_denoising": getattr(cycles, 'use_denoising', False),
"denoising_radius": getattr(cycles, 'denoising_radius', 0),
"denoising_strength": getattr(cycles, 'denoising_strength', 0.0),
"denoiser": denoiser,
"denoising_use_gpu": denoising_use_gpu,
"denoising_input_passes": denoising_input_passes,
"denoising_prefilter": denoising_prefilter,
"denoising_quality": denoising_quality,
# Path Guiding settings
"use_guiding": getattr(cycles, 'use_guiding', False),
"guiding_training_samples": getattr(cycles, 'guiding_training_samples', 128),
"use_surface_guiding": getattr(cycles, 'use_surface_guiding', True),
"use_volume_guiding": getattr(cycles, 'use_volume_guiding', True),
# Lights settings
"use_light_tree": getattr(cycles, 'use_light_tree', True),
"light_sampling_threshold": getattr(cycles, 'light_sampling_threshold', 0.01),
# Device
"device": getattr(cycles, 'device', 'CPU'),
"use_adaptive_sampling": getattr(cycles, 'use_adaptive_sampling', False),
"adaptive_threshold": getattr(cycles, 'adaptive_threshold', 0.01) if getattr(cycles, 'use_adaptive_sampling', False) else 0.01,
"use_fast_gi": getattr(cycles, 'use_fast_gi', False),
"light_tree": getattr(cycles, 'use_light_tree', False),
"use_light_linking": getattr(cycles, 'use_light_linking', False),
"caustics_reflective": getattr(cycles, 'caustics_reflective', False),
"caustics_refractive": getattr(cycles, 'caustics_refractive', False),
"blur_glossy": getattr(cycles, 'blur_glossy', 0.0),
# Advanced/Seed settings
"seed": getattr(cycles, 'seed', 0),
"use_animated_seed": getattr(cycles, 'use_animated_seed', False),
"sampling_pattern": getattr(cycles, 'sampling_pattern', 'AUTOMATIC'),
"scrambling_distance": getattr(cycles, 'scrambling_distance', 1.0),
"auto_scrambling_distance_multiplier": getattr(cycles, 'auto_scrambling_distance_multiplier', 1.0),
"preview_scrambling_distance": getattr(cycles, 'preview_scrambling_distance', False),
"min_light_bounces": getattr(cycles, 'min_light_bounces', 0),
"min_transparent_bounces": getattr(cycles, 'min_transparent_bounces', 0),
# Clamping
"sample_clamp_direct": getattr(cycles, 'sample_clamp_direct', 0.0),
"sample_clamp_indirect": getattr(cycles, 'sample_clamp_indirect', 0.0),
# Light Paths / Bounces
"max_bounces": getattr(cycles, 'max_bounces', 12),
"diffuse_bounces": getattr(cycles, 'diffuse_bounces', 4),
"glossy_bounces": getattr(cycles, 'glossy_bounces', 4),
"transmission_bounces": getattr(cycles, 'transmission_bounces', 12),
"volume_bounces": getattr(cycles, 'volume_bounces', 0),
"transparent_max_bounces": getattr(cycles, 'transparent_max_bounces', 8),
# Caustics
"caustics_reflective": getattr(cycles, 'caustics_reflective', False),
"caustics_refractive": getattr(cycles, 'caustics_refractive', False),
"blur_glossy": getattr(cycles, 'blur_glossy', 0.0), # Filter Glossy
# Fast GI Approximation
"use_fast_gi": getattr(cycles, 'use_fast_gi', False),
"fast_gi_method": getattr(cycles, 'fast_gi_method', 'REPLACE'), # REPLACE or ADD
"ao_bounces": getattr(cycles, 'ao_bounces', 1), # Viewport bounces
"ao_bounces_render": getattr(cycles, 'ao_bounces_render', 1), # Render bounces
# Volumes
"volume_step_rate": getattr(cycles, 'volume_step_rate', 1.0),
"volume_preview_step_rate": getattr(cycles, 'volume_preview_step_rate', 1.0),
"volume_max_steps": getattr(cycles, 'volume_max_steps', 1024),
# Film
"film_exposure": getattr(cycles, 'film_exposure', 1.0),
"film_transparent": getattr(cycles, 'film_transparent', False),
"film_transparent_glass": getattr(cycles, 'film_transparent_glass', False),
"film_transparent_roughness": getattr(cycles, 'film_transparent_roughness', 0.1),
"filter_type": getattr(cycles, 'filter_type', 'BLACKMAN_HARRIS'), # BOX, GAUSSIAN, BLACKMAN_HARRIS
"filter_width": getattr(cycles, 'filter_width', 1.5),
"pixel_filter_type": getattr(cycles, 'pixel_filter_type', 'BLACKMAN_HARRIS'),
# Performance
"use_auto_tile": getattr(cycles, 'use_auto_tile', True),
"tile_size": getattr(cycles, 'tile_size', 2048),
"use_persistent_data": getattr(cycles, 'use_persistent_data', False),
# Hair/Curves
"use_hair": getattr(cycles, 'use_hair', True),
"hair_subdivisions": getattr(cycles, 'hair_subdivisions', 2),
"hair_shape": getattr(cycles, 'hair_shape', 'THICK'), # ROUND, RIBBONS, THICK
# Simplify (from scene.render)
"use_simplify": getattr(scene.render, 'use_simplify', False),
"simplify_subdivision_render": getattr(scene.render, 'simplify_subdivision_render', 6),
"simplify_child_particles_render": getattr(scene.render, 'simplify_child_particles_render', 1.0),
# Other
"use_light_linking": getattr(cycles, 'use_light_linking', False),
"use_layer_samples": getattr(cycles, 'use_layer_samples', False),
}
elif engine == 'EEVEE' or engine == 'EEVEE_NEXT':
# Treat EEVEE_NEXT as EEVEE (modern Blender uses EEVEE for what was EEVEE_NEXT)
eevee = scene.eevee
engine_settings = {
# Sampling
"taa_render_samples": getattr(eevee, 'taa_render_samples', 64),
"taa_samples": getattr(eevee, 'taa_samples', 16), # Viewport samples
"use_taa_reprojection": getattr(eevee, 'use_taa_reprojection', True),
# Clamping
"clamp_surface_direct": getattr(eevee, 'clamp_surface_direct', 0.0),
"clamp_surface_indirect": getattr(eevee, 'clamp_surface_indirect', 0.0),
"clamp_volume_direct": getattr(eevee, 'clamp_volume_direct', 0.0),
"clamp_volume_indirect": getattr(eevee, 'clamp_volume_indirect', 0.0),
# Shadows
"shadow_cube_size": getattr(eevee, 'shadow_cube_size', '512'),
"shadow_cascade_size": getattr(eevee, 'shadow_cascade_size', '1024'),
"use_shadow_high_bitdepth": getattr(eevee, 'use_shadow_high_bitdepth', False),
"use_soft_shadows": getattr(eevee, 'use_soft_shadows', True),
"light_threshold": getattr(eevee, 'light_threshold', 0.01),
# Raytracing (EEVEE Next / modern EEVEE)
"use_raytracing": getattr(eevee, 'use_raytracing', False),
"ray_tracing_method": getattr(eevee, 'ray_tracing_method', 'SCREEN'), # SCREEN or PROBE
"ray_tracing_options_trace_max_roughness": getattr(eevee, 'ray_tracing_options', {}).get('trace_max_roughness', 0.5) if hasattr(getattr(eevee, 'ray_tracing_options', None), 'get') else 0.5,
# Screen Space Reflections (legacy/fallback)
"use_ssr": getattr(eevee, 'use_ssr', False),
"use_ssr_refraction": getattr(eevee, 'use_ssr_refraction', False),
"use_ssr_halfres": getattr(eevee, 'use_ssr_halfres', True),
"ssr_quality": getattr(eevee, 'ssr_quality', 0.25),
"ssr_max_roughness": getattr(eevee, 'ssr_max_roughness', 0.5),
"ssr_thickness": getattr(eevee, 'ssr_thickness', 0.2),
"ssr_border_fade": getattr(eevee, 'ssr_border_fade', 0.075),
"ssr_firefly_fac": getattr(eevee, 'ssr_firefly_fac', 10.0),
# Ambient Occlusion
"use_gtao": getattr(eevee, 'use_gtao', False),
"gtao_distance": getattr(eevee, 'gtao_distance', 0.2),
"gtao_factor": getattr(eevee, 'gtao_factor', 1.0),
"gtao_quality": getattr(eevee, 'gtao_quality', 0.25),
"use_gtao_bent_normals": getattr(eevee, 'use_gtao_bent_normals', True),
"use_gtao_bounce": getattr(eevee, 'use_gtao_bounce', True),
# Bloom
"use_bloom": getattr(eevee, 'use_bloom', False),
"bloom_threshold": getattr(eevee, 'bloom_threshold', 0.8),
"bloom_intensity": getattr(eevee, 'bloom_intensity', 0.05),
"bloom_knee": getattr(eevee, 'bloom_knee', 0.5),
"bloom_radius": getattr(eevee, 'bloom_radius', 6.5),
"use_ssr": getattr(eevee, 'use_ssr', True),
"use_ssr_refraction": getattr(eevee, 'use_ssr_refraction', False),
"ssr_quality": getattr(eevee, 'ssr_quality', 'MEDIUM'),
"use_ssao": getattr(eevee, 'use_ssao', True),
"ssao_quality": getattr(eevee, 'ssao_quality', 'MEDIUM'),
"ssao_distance": getattr(eevee, 'ssao_distance', 0.2),
"ssao_factor": getattr(eevee, 'ssao_factor', 1.0),
"use_soft_shadows": getattr(eevee, 'use_soft_shadows', True),
"use_shadow_high_bitdepth": getattr(eevee, 'use_shadow_high_bitdepth', True),
"use_volumetric": getattr(eevee, 'use_volumetric', False),
"bloom_color": list(getattr(eevee, 'bloom_color', (1.0, 1.0, 1.0))),
"bloom_intensity": getattr(eevee, 'bloom_intensity', 0.05),
"bloom_clamp": getattr(eevee, 'bloom_clamp', 0.0),
# Depth of Field
"bokeh_max_size": getattr(eevee, 'bokeh_max_size', 100.0),
"bokeh_threshold": getattr(eevee, 'bokeh_threshold', 1.0),
"bokeh_neighbor_max": getattr(eevee, 'bokeh_neighbor_max', 10.0),
"bokeh_denoise_fac": getattr(eevee, 'bokeh_denoise_fac', 0.75),
"use_bokeh_high_quality_slight_defocus": getattr(eevee, 'use_bokeh_high_quality_slight_defocus', False),
"use_bokeh_jittered": getattr(eevee, 'use_bokeh_jittered', False),
"bokeh_overblur": getattr(eevee, 'bokeh_overblur', 5.0),
# Subsurface Scattering
"sss_samples": getattr(eevee, 'sss_samples', 7),
"sss_jitter_threshold": getattr(eevee, 'sss_jitter_threshold', 0.3),
# Volumetrics
"use_volumetric_lights": getattr(eevee, 'use_volumetric_lights', True),
"use_volumetric_shadows": getattr(eevee, 'use_volumetric_shadows', False),
"volumetric_start": getattr(eevee, 'volumetric_start', 0.1),
"volumetric_end": getattr(eevee, 'volumetric_end', 100.0),
"volumetric_tile_size": getattr(eevee, 'volumetric_tile_size', '8'),
"volumetric_samples": getattr(eevee, 'volumetric_samples', 64),
"volumetric_start": getattr(eevee, 'volumetric_start', 0.0),
"volumetric_end": getattr(eevee, 'volumetric_end', 100.0),
"use_volumetric_lights": getattr(eevee, 'use_volumetric_lights', True),
"use_volumetric_shadows": getattr(eevee, 'use_volumetric_shadows', True),
"use_gtao": getattr(eevee, 'use_gtao', False),
"gtao_quality": getattr(eevee, 'gtao_quality', 'MEDIUM'),
"volumetric_sample_distribution": getattr(eevee, 'volumetric_sample_distribution', 0.8),
"volumetric_ray_depth": getattr(eevee, 'volumetric_ray_depth', 16),
# Motion Blur
"use_motion_blur": getattr(eevee, 'use_motion_blur', False),
"motion_blur_position": getattr(eevee, 'motion_blur_position', 'CENTER'),
"motion_blur_shutter": getattr(eevee, 'motion_blur_shutter', 0.5),
"motion_blur_depth_scale": getattr(eevee, 'motion_blur_depth_scale', 100.0),
"motion_blur_max": getattr(eevee, 'motion_blur_max', 32),
"motion_blur_steps": getattr(eevee, 'motion_blur_steps', 1),
# Film
"use_overscan": getattr(eevee, 'use_overscan', False),
"overscan_size": getattr(eevee, 'overscan_size', 3.0),
# Indirect Lighting
"gi_diffuse_bounces": getattr(eevee, 'gi_diffuse_bounces', 3),
"gi_cubemap_resolution": getattr(eevee, 'gi_cubemap_resolution', '512'),
"gi_visibility_resolution": getattr(eevee, 'gi_visibility_resolution', '32'),
"gi_irradiance_smoothing": getattr(eevee, 'gi_irradiance_smoothing', 0.1),
"gi_glossy_clamp": getattr(eevee, 'gi_glossy_clamp', 0.0),
"gi_filter_quality": getattr(eevee, 'gi_filter_quality', 3.0),
"gi_show_irradiance": getattr(eevee, 'gi_show_irradiance', False),
"gi_show_cubemaps": getattr(eevee, 'gi_show_cubemaps', False),
"gi_auto_bake": getattr(eevee, 'gi_auto_bake', False),
# Hair/Curves
"hair_type": getattr(eevee, 'hair_type', 'STRIP'), # STRIP or STRAND
# Performance
"use_shadow_jitter_viewport": getattr(eevee, 'use_shadow_jitter_viewport', True),
# Simplify (from scene.render)
"use_simplify": getattr(scene.render, 'use_simplify', False),
"simplify_subdivision_render": getattr(scene.render, 'simplify_subdivision_render', 6),
"simplify_child_particles_render": getattr(scene.render, 'simplify_child_particles_render', 1.0),
}
else:
# For other engines, extract basic samples if available
@@ -149,10 +338,20 @@ camera_count = len([obj for obj in scene.objects if obj.type == 'CAMERA'])
object_count = len(scene.objects)
material_count = len(bpy.data.materials)
# Extract Blender version info
# bpy.app.version gives the current running Blender version
# For the file's saved version, we check bpy.data.version (version the file was saved with)
blender_version = {
"current": bpy.app.version_string, # Version of Blender running this script
"file_saved_with": ".".join(map(str, bpy.data.version)) if hasattr(bpy.data, 'version') else None, # Version file was saved with
}
# Build metadata dictionary
metadata = {
"frame_start": frame_start,
"frame_end": frame_end,
"has_negative_frames": has_negative_start or has_negative_end or has_negative_animation,
"blender_version": blender_version,
"render_settings": {
"resolution_x": resolution_x,
"resolution_y": resolution_y,

View File

@@ -338,10 +338,28 @@ if current_engine == 'CYCLES':
if gpu_available:
scene.cycles.device = 'GPU'
print(f"Using GPU for rendering (blend file had: {current_device})")
# Auto-enable GPU denoising when using GPU (OpenImageDenoise supports all GPUs)
try:
view_layer = bpy.context.view_layer
if hasattr(view_layer, 'cycles') and hasattr(view_layer.cycles, 'denoising_use_gpu'):
view_layer.cycles.denoising_use_gpu = True
print("Auto-enabled GPU denoising (OpenImageDenoise)")
except Exception as e:
print(f"Could not auto-enable GPU denoising: {e}")
else:
scene.cycles.device = 'CPU'
print(f"GPU not available, using CPU for rendering (blend file had: {current_device})")
# Ensure GPU denoising is disabled when using CPU
try:
view_layer = bpy.context.view_layer
if hasattr(view_layer, 'cycles') and hasattr(view_layer.cycles, 'denoising_use_gpu'):
view_layer.cycles.denoising_use_gpu = False
print("Using CPU denoising")
except Exception as e:
pass
# Verify device setting
if current_engine == 'CYCLES':
final_device = scene.cycles.device

View File

@@ -227,8 +227,9 @@ type TaskLogEntry struct {
// BlendMetadata represents extracted metadata from a blend file
type BlendMetadata struct {
FrameStart int `json:"frame_start"`
FrameEnd int `json:"frame_end"`
FrameStart int `json:"frame_start"`
FrameEnd int `json:"frame_end"`
HasNegativeFrames bool `json:"has_negative_frames"` // True if blend file has negative frame numbers (not supported)
RenderSettings RenderSettings `json:"render_settings"`
SceneInfo SceneInfo `json:"scene_info"`
MissingFilesInfo *MissingFilesInfo `json:"missing_files_info,omitempty"`
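For illustration, a consumer of the extractor's JSON could gate on the new flag as below; a trimmed-down local struct is used instead of the full BlendMetadata so the sketch stays self-contained, and the sample values are made up.

package main

import (
    "encoding/json"
    "fmt"
)

// blendMetadata mirrors only the fields needed for this example.
type blendMetadata struct {
    FrameStart        int  `json:"frame_start"`
    FrameEnd          int  `json:"frame_end"`
    HasNegativeFrames bool `json:"has_negative_frames"`
}

func main() {
    raw := []byte(`{"frame_start":1,"frame_end":250,"has_negative_frames":false}`)
    var meta blendMetadata
    if err := json.Unmarshal(raw, &meta); err != nil {
        panic(err)
    }
    if meta.HasNegativeFrames {
        fmt.Println("rejecting job: negative frame numbers are not supported")
        return
    }
    fmt.Printf("rendering frames %d-%d\n", meta.FrameStart, meta.FrameEnd)
}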

45
web/embed.go Normal file
View File

@@ -0,0 +1,45 @@
package web
import (
"embed"
"io/fs"
"net/http"
"strings"
)
//go:embed dist/*
var distFS embed.FS
// GetFileSystem returns an http.FileSystem for the embedded web UI files
func GetFileSystem() http.FileSystem {
subFS, err := fs.Sub(distFS, "dist")
if err != nil {
panic(err)
}
return http.FS(subFS)
}
// SPAHandler returns an http.Handler that serves the embedded SPA
// It serves static files if they exist, otherwise falls back to index.html
func SPAHandler() http.Handler {
fsys := GetFileSystem()
fileServer := http.FileServer(fsys)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
// Try to open the file
f, err := fsys.Open(strings.TrimPrefix(path, "/"))
if err != nil {
// File doesn't exist, serve index.html for SPA routing
r.URL.Path = "/"
fileServer.ServeHTTP(w, r)
return
}
f.Close()
// File exists, serve it
fileServer.ServeHTTP(w, r)
})
}
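A small sketch of how the embedded UI could be mounted next to API routes (assumed wiring; the real cmd/jiggablend setup may differ):

package main

import (
    "log"
    "net/http"

    "jiggablend/web"
)

func main() {
    mux := http.NewServeMux()
    mux.HandleFunc("/api/health", func(w http.ResponseWriter, r *http.Request) {
        w.Write([]byte("ok"))
    })
    // Anything that is not an API route falls through to the embedded SPA,
    // which serves static assets or index.html for client-side routes.
    mux.Handle("/", web.SPAHandler())
    log.Fatal(http.ListenAndServe(":8080", mux))
}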

View File

@@ -5,6 +5,8 @@ import Layout from './components/Layout';
import JobList from './components/JobList';
import JobSubmission from './components/JobSubmission';
import AdminPanel from './components/AdminPanel';
import ErrorBoundary from './components/ErrorBoundary';
import LoadingSpinner from './components/LoadingSpinner';
import './styles/index.css';
function App() {
@@ -17,7 +19,7 @@ function App() {
if (loading) {
return (
<div className="min-h-screen flex items-center justify-center bg-gray-900">
<div className="animate-spin rounded-full h-12 w-12 border-b-2 border-orange-500"></div>
<LoadingSpinner size="md" />
</div>
);
}
@@ -26,26 +28,20 @@ function App() {
return loginComponent;
}
// Wrapper to check auth before changing tabs
const handleTabChange = async (newTab) => {
// Check auth before allowing navigation
try {
await refresh();
// If refresh succeeds, user is still authenticated
setActiveTab(newTab);
} catch (error) {
// Auth check failed, user will be set to null and login will show
console.error('Auth check failed on navigation:', error);
}
// Wrapper to change tabs - only check auth on mount, not on every navigation
const handleTabChange = (newTab) => {
setActiveTab(newTab);
};
return (
<Layout activeTab={activeTab} onTabChange={handleTabChange}>
{activeTab === 'jobs' && <JobList />}
{activeTab === 'submit' && (
<JobSubmission onSuccess={() => handleTabChange('jobs')} />
)}
{activeTab === 'admin' && <AdminPanel />}
<ErrorBoundary>
{activeTab === 'jobs' && <JobList />}
{activeTab === 'submit' && (
<JobSubmission onSuccess={() => handleTabChange('jobs')} />
)}
{activeTab === 'admin' && <AdminPanel />}
</ErrorBoundary>
</Layout>
);
}

View File

@@ -1,7 +1,9 @@
import { useState, useEffect } from 'react';
import { admin } from '../utils/api';
import { useState, useEffect, useRef } from 'react';
import { admin, jobs, normalizeArrayResponse } from '../utils/api';
import { wsManager } from '../utils/websocket';
import UserJobs from './UserJobs';
import PasswordChange from './PasswordChange';
import LoadingSpinner from './LoadingSpinner';
export default function AdminPanel() {
const [activeSection, setActiveSection] = useState('api-keys');
@@ -16,16 +18,110 @@ export default function AdminPanel() {
const [selectedUser, setSelectedUser] = useState(null);
const [registrationEnabled, setRegistrationEnabled] = useState(true);
const [passwordChangeUser, setPasswordChangeUser] = useState(null);
const listenerIdRef = useRef(null); // Listener ID for shared WebSocket
const subscribedChannelsRef = useRef(new Set()); // Track confirmed subscribed channels
const pendingSubscriptionsRef = useRef(new Set()); // Track pending subscriptions (waiting for confirmation)
// Connect to shared WebSocket on mount
useEffect(() => {
listenerIdRef.current = wsManager.subscribe('adminpanel', {
open: () => {
console.log('AdminPanel: Shared WebSocket connected');
// Subscribe to runners if already viewing runners section
if (activeSection === 'runners') {
subscribeToRunners();
}
},
message: (data) => {
// Handle subscription responses
if (data.type === 'subscribed' && data.channel) {
pendingSubscriptionsRef.current.delete(data.channel);
subscribedChannelsRef.current.add(data.channel);
console.log('Successfully subscribed to channel:', data.channel);
} else if (data.type === 'subscription_error' && data.channel) {
pendingSubscriptionsRef.current.delete(data.channel);
subscribedChannelsRef.current.delete(data.channel);
console.error('Subscription failed for channel:', data.channel, data.error);
}
// Handle runners channel messages
if (data.channel === 'runners' && data.type === 'runner_status') {
// Update runner in list
setRunners(prev => {
const index = prev.findIndex(r => r.id === data.runner_id);
if (index >= 0 && data.data) {
const updated = [...prev];
updated[index] = { ...updated[index], ...data.data };
return updated;
}
return prev;
});
}
},
error: (error) => {
console.error('AdminPanel: Shared WebSocket error:', error);
},
close: (event) => {
console.log('AdminPanel: Shared WebSocket closed:', event);
subscribedChannelsRef.current.clear();
pendingSubscriptionsRef.current.clear();
}
});
// Ensure connection is established
wsManager.connect();
return () => {
// Unsubscribe from all channels before unmounting
unsubscribeFromRunners();
if (listenerIdRef.current) {
wsManager.unsubscribe(listenerIdRef.current);
listenerIdRef.current = null;
}
};
}, []);
const subscribeToRunners = () => {
const channel = 'runners';
if (wsManager.getReadyState() !== WebSocket.OPEN) {
return;
}
// Don't subscribe if already subscribed or pending
if (subscribedChannelsRef.current.has(channel) || pendingSubscriptionsRef.current.has(channel)) {
return;
}
wsManager.send({ type: 'subscribe', channel });
pendingSubscriptionsRef.current.add(channel);
console.log('Subscribing to runners channel');
};
const unsubscribeFromRunners = () => {
const channel = 'runners';
if (wsManager.getReadyState() !== WebSocket.OPEN) {
return;
}
if (!subscribedChannelsRef.current.has(channel)) {
return; // Not subscribed
}
wsManager.send({ type: 'unsubscribe', channel });
subscribedChannelsRef.current.delete(channel);
pendingSubscriptionsRef.current.delete(channel);
console.log('Unsubscribed from runners channel');
};
useEffect(() => {
if (activeSection === 'api-keys') {
loadAPIKeys();
unsubscribeFromRunners();
} else if (activeSection === 'runners') {
loadRunners();
subscribeToRunners();
} else if (activeSection === 'users') {
loadUsers();
unsubscribeFromRunners();
} else if (activeSection === 'settings') {
loadSettings();
unsubscribeFromRunners();
}
}, [activeSection]);
@@ -33,7 +129,7 @@ export default function AdminPanel() {
setLoading(true);
try {
const data = await admin.listAPIKeys();
setApiKeys(Array.isArray(data) ? data : []);
setApiKeys(normalizeArrayResponse(data));
} catch (error) {
console.error('Failed to load API keys:', error);
setApiKeys([]);
@@ -47,7 +143,7 @@ export default function AdminPanel() {
setLoading(true);
try {
const data = await admin.listRunners();
setRunners(Array.isArray(data) ? data : []);
setRunners(normalizeArrayResponse(data));
} catch (error) {
console.error('Failed to load runners:', error);
setRunners([]);
@@ -61,7 +157,7 @@ export default function AdminPanel() {
setLoading(true);
try {
const data = await admin.listUsers();
setUsers(Array.isArray(data) ? data : []);
setUsers(normalizeArrayResponse(data));
} catch (error) {
console.error('Failed to load users:', error);
setUsers([]);
@@ -121,29 +217,22 @@ export default function AdminPanel() {
}
};
const revokeAPIKey = async (keyId) => {
if (!confirm('Are you sure you want to revoke this API key? Revoked keys cannot be used for new runner registrations.')) {
return;
}
try {
await admin.revokeAPIKey(keyId);
await loadAPIKeys();
} catch (error) {
console.error('Failed to revoke API key:', error);
alert('Failed to revoke API key');
}
};
const [deletingKeyId, setDeletingKeyId] = useState(null);
const [deletingRunnerId, setDeletingRunnerId] = useState(null);
const deleteAPIKey = async (keyId) => {
if (!confirm('Are you sure you want to permanently delete this API key? This action cannot be undone.')) {
const revokeAPIKey = async (keyId) => {
if (!confirm('Are you sure you want to delete this API key? This action cannot be undone.')) {
return;
}
setDeletingKeyId(keyId);
try {
await admin.deleteAPIKey(keyId);
await loadAPIKeys();
} catch (error) {
console.error('Failed to delete API key:', error);
alert('Failed to delete API key');
} finally {
setDeletingKeyId(null);
}
};
@@ -152,12 +241,15 @@ export default function AdminPanel() {
if (!confirm('Are you sure you want to delete this runner?')) {
return;
}
setDeletingRunnerId(runnerId);
try {
await admin.deleteRunner(runnerId);
await loadRunners();
} catch (error) {
console.error('Failed to delete runner:', error);
alert('Failed to delete runner');
} finally {
setDeletingRunnerId(null);
}
};
@@ -313,9 +405,7 @@ export default function AdminPanel() {
<div className="bg-gray-800 rounded-lg shadow-md p-6 border border-gray-700">
<h2 className="text-xl font-semibold mb-4 text-gray-100">API Keys</h2>
{loading ? (
<div className="flex justify-center py-8">
<div className="animate-spin rounded-full h-8 w-8 border-b-2 border-orange-500"></div>
</div>
<LoadingSpinner size="sm" className="py-8" />
) : !apiKeys || apiKeys.length === 0 ? (
<p className="text-gray-400 text-center py-8">No API keys generated yet.</p>
) : (
@@ -384,21 +474,13 @@ export default function AdminPanel() {
{new Date(key.created_at).toLocaleString()}
</td>
<td className="px-6 py-4 whitespace-nowrap text-sm space-x-2">
{key.is_active && !expired && (
<button
onClick={() => revokeAPIKey(key.id)}
className="text-yellow-400 hover:text-yellow-300 font-medium"
title="Revoke API key"
>
Revoke
</button>
)}
<button
onClick={() => deleteAPIKey(key.id)}
className="text-red-400 hover:text-red-300 font-medium"
title="Permanently delete API key"
onClick={() => revokeAPIKey(key.id)}
disabled={deletingKeyId === key.id}
className="text-red-400 hover:text-red-300 font-medium disabled:opacity-50 disabled:cursor-not-allowed"
title="Delete API key"
>
Delete
{deletingKeyId === key.id ? 'Deleting...' : 'Delete'}
</button>
</td>
</tr>
@@ -416,9 +498,7 @@ export default function AdminPanel() {
<div className="bg-gray-800 rounded-lg shadow-md p-6 border border-gray-700">
<h2 className="text-xl font-semibold mb-4 text-gray-100">Runner Management</h2>
{loading ? (
<div className="flex justify-center py-8">
<div className="animate-spin rounded-full h-8 w-8 border-b-2 border-orange-500"></div>
</div>
<LoadingSpinner size="sm" className="py-8" />
) : !runners || runners.length === 0 ? (
<p className="text-gray-400 text-center py-8">No runners registered.</p>
) : (
@@ -506,9 +586,10 @@ export default function AdminPanel() {
<td className="px-6 py-4 whitespace-nowrap text-sm">
<button
onClick={() => deleteRunner(runner.id)}
className="text-red-400 hover:text-red-300 font-medium"
disabled={deletingRunnerId === runner.id}
className="text-red-400 hover:text-red-300 font-medium disabled:opacity-50 disabled:cursor-not-allowed"
>
Delete
{deletingRunnerId === runner.id ? 'Deleting...' : 'Delete'}
</button>
</td>
</tr>
@@ -558,9 +639,7 @@ export default function AdminPanel() {
<div className="bg-gray-800 rounded-lg shadow-md p-6 border border-gray-700">
<h2 className="text-xl font-semibold mb-4 text-gray-100">User Management</h2>
{loading ? (
<div className="flex justify-center py-8">
<div className="animate-spin rounded-full h-8 w-8 border-b-2 border-orange-500"></div>
</div>
<LoadingSpinner size="sm" className="py-8" />
) : !users || users.length === 0 ? (
<p className="text-gray-400 text-center py-8">No users found.</p>
) : (

View File

@@ -0,0 +1,41 @@
import React from 'react';
class ErrorBoundary extends React.Component {
constructor(props) {
super(props);
this.state = { hasError: false, error: null };
}
static getDerivedStateFromError(error) {
return { hasError: true, error };
}
componentDidCatch(error, errorInfo) {
console.error('ErrorBoundary caught an error:', error, errorInfo);
}
render() {
if (this.state.hasError) {
return (
<div className="p-6 bg-red-400/20 border border-red-400/50 rounded-lg text-red-400">
<h2 className="text-xl font-semibold mb-2">Something went wrong</h2>
<p className="mb-4">{this.state.error?.message || 'An unexpected error occurred'}</p>
<button
onClick={() => {
this.setState({ hasError: false, error: null });
window.location.reload();
}}
className="px-4 py-2 bg-red-600 text-white rounded-lg hover:bg-red-500 transition-colors"
>
Reload Page
</button>
</div>
);
}
return this.props.children;
}
}
export default ErrorBoundary;

View File

@@ -0,0 +1,26 @@
import React from 'react';
/**
* Shared ErrorMessage component for consistent error display
* Sanitizes error messages to prevent XSS
*/
export default function ErrorMessage({ error, className = '' }) {
if (!error) return null;
// Sanitize error message - escape HTML entities
const sanitize = (text) => {
const div = document.createElement('div');
div.textContent = text;
return div.innerHTML;
};
const sanitizedError = typeof error === 'string' ? sanitize(error) : sanitize(error.message || 'An error occurred');
return (
<div className={`p-4 bg-red-400/20 border border-red-400/50 rounded-lg text-red-400 ${className}`}>
<p className="font-semibold">Error:</p>
<p dangerouslySetInnerHTML={{ __html: sanitizedError }} />
</div>
);
}

View File

@@ -1,7 +1,10 @@
import { useState, useEffect, useRef } from 'react';
import { jobs, REQUEST_SUPERSEDED } from '../utils/api';
import { wsManager } from '../utils/websocket';
import VideoPlayer from './VideoPlayer';
import FileExplorer from './FileExplorer';
import ErrorMessage from './ErrorMessage';
import LoadingSpinner from './LoadingSpinner';
export default function JobDetails({ job, onClose, onUpdate }) {
const [jobDetails, setJobDetails] = useState(job);
@@ -17,80 +20,66 @@ export default function JobDetails({ job, onClose, onUpdate }) {
const [expandedSteps, setExpandedSteps] = useState(new Set());
const [streaming, setStreaming] = useState(false);
const [previewImage, setPreviewImage] = useState(null); // { url, fileName } or null
const wsRef = useRef(null);
const jobWsRef = useRef(null); // Separate ref for job WebSocket
const listenerIdRef = useRef(null); // Listener ID for shared WebSocket
const subscribedChannelsRef = useRef(new Set()); // Track confirmed subscribed channels
const pendingSubscriptionsRef = useRef(new Set()); // Track pending subscriptions (waiting for confirmation)
const logContainerRefs = useRef({}); // Refs for each step's log container
const shouldAutoScrollRefs = useRef({}); // Auto-scroll state per step
const abortControllerRef = useRef(null); // AbortController for HTTP requests
// Sync job prop to state when it changes
useEffect(() => {
setJobDetails(job);
}, [job.id, job.status, job.progress]);
useEffect(() => {
// Create new AbortController for this effect
abortControllerRef.current = new AbortController();
loadDetails();
// Use WebSocket for real-time updates instead of polling
if (jobDetails.status === 'running' || jobDetails.status === 'pending' || !jobDetails.status) {
connectJobWebSocket();
return () => {
if (jobWsRef.current) {
try {
jobWsRef.current.close();
} catch (e) {
// Ignore errors when closing
}
jobWsRef.current = null;
}
if (wsRef.current) {
try {
wsRef.current.close();
} catch (e) {
// Ignore errors when closing
}
wsRef.current = null;
}
};
} else {
// Job is completed/failed/cancelled - close WebSocket
if (jobWsRef.current) {
try {
jobWsRef.current.close();
} catch (e) {
// Ignore errors when closing
}
jobWsRef.current = null;
// Use shared WebSocket manager for real-time updates
listenerIdRef.current = wsManager.subscribe(`jobdetails_${job.id}`, {
open: () => {
console.log('JobDetails: Shared WebSocket connected for job', job.id);
// Subscribe to job channel
subscribe(`job:${job.id}`);
},
message: (data) => {
handleWebSocketMessage(data);
},
error: (error) => {
console.error('JobDetails: Shared WebSocket error:', error);
},
close: (event) => {
console.log('JobDetails: Shared WebSocket closed:', event);
subscribedChannelsRef.current.clear();
pendingSubscriptionsRef.current.clear();
}
if (wsRef.current) {
try {
wsRef.current.close();
} catch (e) {
// Ignore errors when closing
}
wsRef.current = null;
}
}
}, [job.id, jobDetails.status]);
});
// Ensure connection is established
wsManager.connect();
useEffect(() => {
// Load logs and steps for all running tasks
if (jobDetails.status === 'running' && tasks.length > 0) {
const runningTasks = tasks.filter(t => t.status === 'running' || t.status === 'pending');
runningTasks.forEach(task => {
if (!taskData[task.id]) {
loadTaskData(task.id);
}
});
// Start streaming for the first running task (WebSocket supports one at a time)
if (runningTasks.length > 0 && !streaming) {
startLogStream(runningTasks.map(t => t.id));
}
} else if (wsRef.current) {
wsRef.current.close();
wsRef.current = null;
setStreaming(false);
}
return () => {
if (wsRef.current && jobDetails.status !== 'running') {
wsRef.current.close();
wsRef.current = null;
// Cancel any pending HTTP requests
if (abortControllerRef.current) {
abortControllerRef.current.abort();
abortControllerRef.current = null;
}
// Unsubscribe from all channels
unsubscribeAll();
if (listenerIdRef.current) {
wsManager.unsubscribe(listenerIdRef.current);
listenerIdRef.current = null;
}
};
}, [tasks, jobDetails.status]);
}, [job.id]);
useEffect(() => {
// Update log subscriptions based on expanded tasks (not steps)
updateLogSubscriptions();
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [expandedTasks, tasks.length, jobDetails.status]); // Use tasks.length instead of tasks to avoid unnecessary re-runs
// Auto-scroll logs to bottom when new logs arrive
useEffect(() => {
@@ -119,11 +108,17 @@ export default function JobDetails({ job, onClose, onUpdate }) {
try {
setLoading(true);
// Use summary endpoint for tasks initially - much faster
const signal = abortControllerRef.current?.signal;
const [details, fileList, taskListResult] = await Promise.all([
jobs.get(job.id),
jobs.getFiles(job.id, { limit: 50 }), // Only load first page of files
jobs.getTasksSummary(job.id, { limit: 100, sort: 'frame_start:asc' }), // Use summary endpoint
jobs.get(job.id, { signal }),
jobs.getFiles(job.id, { limit: 50, signal }), // Only load first page of files
jobs.getTasksSummary(job.id, { sort: 'frame_start:asc', signal }), // Get all tasks
]);
// Check if request was aborted
if (signal?.aborted) {
return;
}
setJobDetails(details);
// Handle paginated file response - check for superseded sentinel
@@ -159,9 +154,11 @@ export default function JobDetails({ job, onClose, onUpdate }) {
// Fetch context archive contents separately (may not exist for old jobs)
try {
const contextList = await jobs.getContextArchive(job.id);
const contextList = await jobs.getContextArchive(job.id, { signal });
if (signal?.aborted) return;
setContextFiles(contextList || []);
} catch (error) {
if (signal?.aborted) return;
// Context archive may not exist for old jobs
setContextFiles([]);
}
@@ -205,11 +202,17 @@ export default function JobDetails({ job, onClose, onUpdate }) {
const loadTaskData = async (taskId) => {
try {
console.log(`Loading task data for task ${taskId}...`);
const signal = abortControllerRef.current?.signal;
const [logsResult, steps] = await Promise.all([
jobs.getTaskLogs(job.id, taskId, { limit: 1000 }), // Increased limit for completed tasks
jobs.getTaskSteps(job.id, taskId),
jobs.getTaskLogs(job.id, taskId, { limit: 1000, signal }), // Increased limit for completed tasks
jobs.getTaskSteps(job.id, taskId, { signal }),
]);
// Check if request was aborted
if (signal?.aborted) {
return;
}
// Check for superseded sentinel
if (logsResult === REQUEST_SUPERSEDED || steps === REQUEST_SUPERSEDED) {
return; // Request was superseded, skip this update
@@ -247,7 +250,14 @@ export default function JobDetails({ job, onClose, onUpdate }) {
const loadTaskStepsOnly = async (taskId) => {
try {
const steps = await jobs.getTaskSteps(job.id, taskId);
const signal = abortControllerRef.current?.signal;
const steps = await jobs.getTaskSteps(job.id, taskId, { signal });
// Check if request was aborted
if (signal?.aborted) {
return;
}
// Check for superseded sentinel
if (steps === REQUEST_SUPERSEDED) {
return; // Request was superseded, skip this update
@@ -267,356 +277,416 @@ export default function JobDetails({ job, onClose, onUpdate }) {
}
};
const connectJobWebSocket = () => {
try {
// Close existing connection if any
if (jobWsRef.current) {
try {
jobWsRef.current.close();
} catch (e) {
// Ignore errors when closing
}
jobWsRef.current = null;
}
const ws = jobs.streamJobWebSocket(job.id);
jobWsRef.current = ws; // Store reference
ws.onopen = () => {
console.log('Job WebSocket connected for job', job.id);
};
ws.onmessage = (event) => {
try {
const data = JSON.parse(event.data);
console.log('Job WebSocket message received:', data.type, data);
if (data.type === 'job_update' && data.data) {
// Update job details
setJobDetails(prev => ({ ...prev, ...data.data }));
} else if (data.type === 'task_update' && data.data) {
// Update task in list
setTasks(prev => {
// Ensure prev is always an array
const prevArray = Array.isArray(prev) ? prev : [];
if (!data.task_id) {
console.warn('task_update message missing task_id:', data);
return prevArray;
}
const index = prevArray.findIndex(t => t.id === data.task_id);
if (index >= 0) {
const updated = [...prevArray];
updated[index] = { ...updated[index], ...data.data };
return updated;
}
// If task not found, it might be a new task - reload to be safe
if (data.data && (data.data.status === 'running' || data.data.status === 'pending')) {
setTimeout(() => {
const reloadTasks = async () => {
try {
const taskListResult = await jobs.getTasksSummary(job.id, { limit: 100, sort: 'frame_start:asc' });
// Check for superseded sentinel
if (taskListResult === REQUEST_SUPERSEDED) {
return; // Request was superseded, skip this update
}
const taskData = taskListResult.data || taskListResult;
const taskSummaries = Array.isArray(taskData) ? taskData : [];
const tasksForDisplay = taskSummaries.map(summary => ({
id: summary.id,
job_id: job.id,
frame_start: summary.frame_start,
frame_end: summary.frame_end,
status: summary.status,
task_type: summary.task_type,
runner_id: summary.runner_id,
current_step: null,
retry_count: 0,
max_retries: 3,
created_at: new Date().toISOString(),
}));
setTasks(Array.isArray(tasksForDisplay) ? tasksForDisplay : []);
} catch (error) {
console.error('Failed to reload tasks:', error);
}
};
reloadTasks();
}, 100);
}
return prevArray;
});
} else if (data.type === 'task_added' && data.data) {
// New task was added - reload task summaries to get the new task
console.log('task_added message received, reloading tasks...', data);
const reloadTasks = async () => {
try {
const taskListResult = await jobs.getTasksSummary(job.id, { limit: 100, sort: 'frame_start:asc' });
// Check for superseded sentinel
if (taskListResult === REQUEST_SUPERSEDED) {
return; // Request was superseded, skip this update
}
const taskData = taskListResult.data || taskListResult;
const taskSummaries = Array.isArray(taskData) ? taskData : [];
const tasksForDisplay = taskSummaries.map(summary => ({
id: summary.id,
job_id: job.id,
frame_start: summary.frame_start,
frame_end: summary.frame_end,
status: summary.status,
task_type: summary.task_type,
runner_id: summary.runner_id,
current_step: null,
retry_count: 0,
max_retries: 3,
created_at: new Date().toISOString(),
}));
setTasks(Array.isArray(tasksForDisplay) ? tasksForDisplay : []);
} catch (error) {
console.error('Failed to reload tasks:', error);
// Fallback to full reload
loadDetails();
}
};
reloadTasks();
} else if (data.type === 'tasks_added' && data.data) {
// Multiple new tasks were added - reload task summaries
console.log('tasks_added message received, reloading tasks...', data);
const reloadTasks = async () => {
try {
const taskListResult = await jobs.getTasksSummary(job.id, { limit: 100, sort: 'frame_start:asc' });
// Check for superseded sentinel
if (taskListResult === REQUEST_SUPERSEDED) {
return; // Request was superseded, skip this update
}
const taskData = taskListResult.data || taskListResult;
const taskSummaries = Array.isArray(taskData) ? taskData : [];
const tasksForDisplay = taskSummaries.map(summary => ({
id: summary.id,
job_id: job.id,
frame_start: summary.frame_start,
frame_end: summary.frame_end,
status: summary.status,
task_type: summary.task_type,
runner_id: summary.runner_id,
current_step: null,
retry_count: 0,
max_retries: 3,
created_at: new Date().toISOString(),
}));
setTasks(Array.isArray(tasksForDisplay) ? tasksForDisplay : []);
} catch (error) {
console.error('Failed to reload tasks:', error);
// Fallback to full reload
loadDetails();
}
};
reloadTasks();
} else if (data.type === 'file_added' && data.data) {
// New file was added - reload file list
const reloadFiles = async () => {
try {
const fileList = await jobs.getFiles(job.id, { limit: 50 });
// Check for superseded sentinel
if (fileList === REQUEST_SUPERSEDED) {
return; // Request was superseded, skip this update
}
const fileData = fileList.data || fileList;
setFiles(Array.isArray(fileData) ? fileData : []);
} catch (error) {
console.error('Failed to reload files:', error);
}
};
reloadFiles();
} else if (data.type === 'step_update' && data.data && data.task_id) {
// Step was created or updated - update task data
console.log('step_update message received:', data);
setTaskData(prev => {
const taskId = data.task_id;
const current = prev[taskId] || { steps: [], logs: [] };
const stepData = data.data;
// Find if step already exists
const existingSteps = current.steps || [];
const stepIndex = existingSteps.findIndex(s => s.step_name === stepData.step_name);
let updatedSteps;
if (stepIndex >= 0) {
// Update existing step
updatedSteps = [...existingSteps];
updatedSteps[stepIndex] = {
...updatedSteps[stepIndex],
...stepData,
id: stepData.step_id || updatedSteps[stepIndex].id,
};
} else {
// Add new step
updatedSteps = [...existingSteps, {
id: stepData.step_id,
step_name: stepData.step_name,
status: stepData.status,
duration_ms: stepData.duration_ms,
error_message: stepData.error_message,
}];
}
return {
...prev,
[taskId]: {
...current,
steps: updatedSteps,
}
};
});
} else if (data.type === 'connected') {
// Connection established
}
} catch (error) {
console.error('Failed to parse WebSocket message:', error);
}
};
ws.onerror = (error) => {
console.error('Job WebSocket error:', {
error,
readyState: ws.readyState,
url: ws.url,
jobId: job.id,
status: jobDetails.status
});
// WebSocket errors don't provide much detail, but we can check readyState
if (ws.readyState === WebSocket.CLOSED || ws.readyState === WebSocket.CLOSING) {
console.warn('Job WebSocket is closed or closing, will attempt reconnect');
}
};
ws.onclose = (event) => {
console.log('Job WebSocket closed:', {
code: event.code,
reason: event.reason,
wasClean: event.wasClean,
jobId: job.id,
status: jobDetails.status
});
jobWsRef.current = null;
// Code 1006 = Abnormal Closure (connection lost without close frame)
// Code 1000 = Normal Closure
// Code 1001 = Going Away (server restart, etc.)
// We should reconnect for abnormal closures (1006) or unexpected closes
const shouldReconnect = !event.wasClean || event.code === 1006 || event.code === 1001;
// Get current status from state to avoid stale closure
const currentStatus = jobDetails.status;
const isActiveJob = currentStatus === 'running' || currentStatus === 'pending';
if (shouldReconnect && isActiveJob) {
console.log(`Attempting to reconnect job WebSocket in 2 seconds... (code: ${event.code})`);
setTimeout(() => {
// Check status again before reconnecting (might have changed)
// Use a ref or check the current state directly
if ((!jobWsRef.current || jobWsRef.current.readyState === WebSocket.CLOSED)) {
// Re-check if job is still active by reading current state
// We'll check this in connectJobWebSocket if needed
connectJobWebSocket();
}
}, 2000);
} else if (!isActiveJob) {
console.log('Job is no longer active, not reconnecting WebSocket');
}
};
} catch (error) {
console.error('Failed to connect job WebSocket:', error);
const subscribe = (channel) => {
if (wsManager.getReadyState() !== WebSocket.OPEN) {
return;
}
// Don't subscribe if already subscribed or pending
if (subscribedChannelsRef.current.has(channel) || pendingSubscriptionsRef.current.has(channel)) {
return; // Already subscribed or subscription pending
}
wsManager.send({ type: 'subscribe', channel });
pendingSubscriptionsRef.current.add(channel); // Mark as pending
};
const startLogStream = (taskIds) => {
if (taskIds.length === 0 || streaming) return;
const unsubscribe = (channel) => {
if (wsManager.getReadyState() !== WebSocket.OPEN) {
return;
}
if (!subscribedChannelsRef.current.has(channel)) {
return; // Not subscribed
}
wsManager.send({ type: 'unsubscribe', channel });
subscribedChannelsRef.current.delete(channel);
console.log('Unsubscribed from channel:', channel);
};
// Don't start streaming if job is no longer running
if (jobDetails.status !== 'running' && jobDetails.status !== 'pending') {
console.log('Job is not running, skipping log stream');
const unsubscribeAll = () => {
subscribedChannelsRef.current.forEach(channel => {
unsubscribe(channel);
});
};
const updateLogSubscriptions = () => {
if (wsManager.getReadyState() !== WebSocket.OPEN) {
return;
}
setStreaming(true);
// For now, stream the first task's logs (WebSocket supports one task at a time)
// In the future, we could have multiple WebSocket connections
const primaryTaskId = taskIds[0];
const ws = jobs.streamTaskLogsWebSocket(job.id, primaryTaskId);
wsRef.current = ws;
// Determine which log channels should be subscribed
const shouldSubscribe = new Set();
const isRunning = jobDetails.status === 'running' || jobDetails.status === 'pending';
ws.onmessage = (event) => {
try {
const data = JSON.parse(event.data);
// Subscribe to logs when task is expanded (not when step is expanded)
if (isRunning) {
expandedTasks.forEach(taskId => {
const channel = `logs:${job.id}:${taskId}`;
shouldSubscribe.add(channel);
});
}
// Subscribe to new channels
shouldSubscribe.forEach(channel => {
subscribe(channel);
});
// Unsubscribe from channels that shouldn't be subscribed
subscribedChannelsRef.current.forEach(channel => {
if (channel.startsWith('logs:') && !shouldSubscribe.has(channel)) {
unsubscribe(channel);
}
});
};
const handleWebSocketMessage = (data) => {
try {
console.log('JobDetails: Client WebSocket message received:', data.type, data.channel, data);
// Handle subscription responses
if (data.type === 'subscribed' && data.channel) {
pendingSubscriptionsRef.current.delete(data.channel); // Remove from pending
subscribedChannelsRef.current.add(data.channel); // Add to confirmed
console.log('Successfully subscribed to channel:', data.channel, 'Total subscriptions:', subscribedChannelsRef.current.size);
} else if (data.type === 'subscription_error' && data.channel) {
pendingSubscriptionsRef.current.delete(data.channel); // Remove from pending
subscribedChannelsRef.current.delete(data.channel); // Remove from confirmed (if it was there)
console.error('Subscription failed for channel:', data.channel, data.error);
// If it's the job channel, this is a critical error
if (data.channel === `job:${job.id}`) {
console.error('Failed to subscribe to job channel - job may not exist or access denied');
}
}
// Handle job channel messages
// Check both explicit channel and job_id match (for backwards compatibility)
const isJobChannel = data.channel === `job:${job.id}` ||
(data.job_id === job.id && !data.channel);
if (isJobChannel) {
console.log('Job channel message received:', data.type, data);
if (data.type === 'job_update' && data.data) {
// Update job details
console.log('Updating job details:', data.data);
setJobDetails(prev => {
const updated = { ...prev, ...data.data };
console.log('Job details updated:', {
old_progress: prev.progress,
new_progress: updated.progress,
old_status: prev.status,
new_status: updated.status
});
// Notify parent component of update
if (onUpdate) {
onUpdate(data.job_id || job.id, updated);
}
return updated;
});
} else if (data.type === 'task_update') {
// Handle task_update - data.data contains the update fields
const taskId = data.task_id || (data.data && (data.data.id || data.data.task_id));
console.log('Task update received:', { task_id: taskId, data: data.data, full_message: data });
if (!taskId) {
console.warn('task_update message missing task_id:', data);
return;
}
if (!data.data) {
console.warn('task_update message missing data:', data);
return;
}
// Update task in list
setTasks(prev => {
// Ensure prev is always an array
const prevArray = Array.isArray(prev) ? prev : [];
const index = prevArray.findIndex(t => t.id === taskId);
if (index >= 0) {
// Task exists - update it
const updated = [...prevArray];
const oldTask = updated[index];
// Create a completely new task object to ensure React detects the change
const newTask = {
...oldTask,
// Explicitly update each field from data.data to ensure changes are detected
status: data.data.status !== undefined ? data.data.status : oldTask.status,
runner_id: data.data.runner_id !== undefined ? data.data.runner_id : oldTask.runner_id,
started_at: data.data.started_at !== undefined ? data.data.started_at : oldTask.started_at,
completed_at: data.data.completed_at !== undefined ? data.data.completed_at : oldTask.completed_at,
error_message: data.data.error_message !== undefined ? data.data.error_message : oldTask.error_message,
output_path: data.data.output_path !== undefined ? data.data.output_path : oldTask.output_path,
current_step: data.data.current_step !== undefined ? data.data.current_step : oldTask.current_step,
// Merge any other fields
...Object.keys(data.data).reduce((acc, key) => {
if (!['status', 'runner_id', 'started_at', 'completed_at', 'error_message', 'output_path', 'current_step'].includes(key)) {
acc[key] = data.data[key];
}
return acc;
}, {})
};
updated[index] = newTask;
console.log('Updated task at index', index, {
task_id: taskId,
old_status: oldTask.status,
new_status: newTask.status,
old_runner_id: oldTask.runner_id,
new_runner_id: newTask.runner_id,
update_data: data.data,
full_new_task: newTask
});
return updated;
}
// Task not found - check if data contains full task info (from initial state)
// Check both 'id' and 'task_id' fields
const taskIdFromData = data.data.id || data.data.task_id;
if (data.data && typeof data.data === 'object' && taskIdFromData && taskIdFromData === taskId) {
// This is a full task object from initial state - add it
console.log('Adding new task from initial state:', data.data);
return [...prevArray, { ...data.data, id: taskIdFromData }];
}
// If task not found and it's a partial update, reload tasks to get the full list
console.log('Task not found in list, reloading tasks...');
setTimeout(() => {
const reloadTasks = async () => {
try {
const signal = abortControllerRef.current?.signal;
const taskListResult = await jobs.getTasksSummary(job.id, { sort: 'frame_start:asc', signal });
// Check if request was aborted
if (signal?.aborted) {
return;
}
// Check for superseded sentinel
if (taskListResult === REQUEST_SUPERSEDED) {
return; // Request was superseded, skip this update
}
const taskData = taskListResult.data || taskListResult;
const taskSummaries = Array.isArray(taskData) ? taskData : [];
const tasksForDisplay = taskSummaries.map(summary => ({
id: summary.id,
job_id: job.id,
frame_start: summary.frame_start,
frame_end: summary.frame_end,
status: summary.status,
task_type: summary.task_type,
runner_id: summary.runner_id,
current_step: summary.current_step || null,
retry_count: summary.retry_count || 0,
max_retries: summary.max_retries || 3,
created_at: summary.created_at || new Date().toISOString(),
started_at: summary.started_at,
completed_at: summary.completed_at,
error_message: summary.error_message,
output_path: summary.output_path,
}));
setTasks(Array.isArray(tasksForDisplay) ? tasksForDisplay : []);
} catch (error) {
console.error('Failed to reload tasks:', error);
}
};
reloadTasks();
}, 100);
return prevArray;
});
} else if (data.type === 'task_added' && data.data) {
// New task was added - reload task summaries to get the new task
console.log('task_added message received, reloading tasks...', data);
const reloadTasks = async () => {
try {
const signal = abortControllerRef.current?.signal;
const taskListResult = await jobs.getTasksSummary(job.id, { limit: 100, sort: 'frame_start:asc', signal });
// Check if request was aborted
if (signal?.aborted) {
return;
}
// Check for superseded sentinel
if (taskListResult === REQUEST_SUPERSEDED) {
return; // Request was superseded, skip this update
}
const taskData = taskListResult.data || taskListResult;
const taskSummaries = Array.isArray(taskData) ? taskData : [];
const tasksForDisplay = taskSummaries.map(summary => ({
id: summary.id,
job_id: job.id,
frame_start: summary.frame_start,
frame_end: summary.frame_end,
status: summary.status,
task_type: summary.task_type,
runner_id: summary.runner_id,
current_step: null,
retry_count: 0,
max_retries: 3,
created_at: new Date().toISOString(),
}));
setTasks(Array.isArray(tasksForDisplay) ? tasksForDisplay : []);
} catch (error) {
console.error('Failed to reload tasks:', error);
// Fallback to full reload
loadDetails();
}
};
reloadTasks();
} else if (data.type === 'tasks_added' && data.data) {
// Multiple new tasks were added - reload task summaries
console.log('tasks_added message received, reloading tasks...', data);
const reloadTasks = async () => {
try {
const signal = abortControllerRef.current?.signal;
const taskListResult = await jobs.getTasksSummary(job.id, { limit: 100, sort: 'frame_start:asc', signal });
// Check if request was aborted
if (signal?.aborted) {
return;
}
// Check for superseded sentinel
if (taskListResult === REQUEST_SUPERSEDED) {
return; // Request was superseded, skip this update
}
const taskData = taskListResult.data || taskListResult;
const taskSummaries = Array.isArray(taskData) ? taskData : [];
const tasksForDisplay = taskSummaries.map(summary => ({
id: summary.id,
job_id: job.id,
frame_start: summary.frame_start,
frame_end: summary.frame_end,
status: summary.status,
task_type: summary.task_type,
runner_id: summary.runner_id,
current_step: null,
retry_count: 0,
max_retries: 3,
created_at: new Date().toISOString(),
}));
setTasks(Array.isArray(tasksForDisplay) ? tasksForDisplay : []);
} catch (error) {
console.error('Failed to reload tasks:', error);
// Fallback to full reload
loadDetails();
}
};
reloadTasks();
} else if (data.type === 'file_added' && data.data) {
// New file was added - reload file list
const reloadFiles = async () => {
try {
const fileList = await jobs.getFiles(job.id, { limit: 50 });
// Check for superseded sentinel
if (fileList === REQUEST_SUPERSEDED) {
return; // Request was superseded, skip this update
}
const fileData = fileList.data || fileList;
setFiles(Array.isArray(fileData) ? fileData : []);
} catch (error) {
console.error('Failed to reload files:', error);
}
};
reloadFiles();
} else if (data.type === 'step_update' && data.data && data.task_id) {
// Step was created or updated - update task data
console.log('step_update message received:', data);
setTaskData(prev => {
const taskId = data.task_id;
const current = prev[taskId] || { steps: [], logs: [] };
const stepData = data.data;
// Find if step already exists
const existingSteps = current.steps || [];
const stepIndex = existingSteps.findIndex(s => s.step_name === stepData.step_name);
let updatedSteps;
if (stepIndex >= 0) {
// Update existing step
updatedSteps = [...existingSteps];
updatedSteps[stepIndex] = {
...updatedSteps[stepIndex],
...stepData,
id: stepData.step_id || updatedSteps[stepIndex].id,
};
} else {
// Add new step
updatedSteps = [...existingSteps, {
id: stepData.step_id,
step_name: stepData.step_name,
status: stepData.status,
duration_ms: stepData.duration_ms,
error_message: stepData.error_message,
}];
}
return {
...prev,
[taskId]: {
...current,
steps: updatedSteps,
}
};
});
}
} else if (data.channel && data.channel.startsWith('logs:')) {
// Handle log channel messages
if (data.type === 'log' && data.data) {
const log = data.data;
// Get task_id from log data or top-level message
const taskId = log.task_id || data.task_id;
if (!taskId) {
console.warn('Log message missing task_id:', data);
return;
}
console.log('Received log for task:', taskId, log);
setTaskData(prev => {
const taskId = log.task_id;
const current = prev[taskId] || { steps: [], logs: [] };
// If log has a step_name, ensure the step exists in the steps array
let updatedSteps = current.steps || [];
if (log.step_name) {
const stepExists = updatedSteps.some(s => s.step_name === log.step_name);
if (!stepExists) {
// Create placeholder step for logs that arrive before step_update
console.log('Creating placeholder step for:', log.step_name, 'in task:', taskId);
updatedSteps = [...updatedSteps, {
id: null, // Will be updated when step_update arrives
step_name: log.step_name,
status: 'running', // Default to running since we're receiving logs
duration_ms: null,
error_message: null,
}];
}
}
// Check if log already exists (avoid duplicates)
if (!current.logs.find(l => l.id === log.id)) {
return {
...prev,
[taskId]: {
...current,
steps: updatedSteps,
logs: [...current.logs, log]
}
};
}
return prev;
// Even if log is duplicate, update steps if needed
return {
...prev,
[taskId]: {
...current,
steps: updatedSteps,
}
};
});
} else if (data.type === 'connected') {
// Connection established
}
} catch (error) {
console.error('Failed to parse log message:', error);
} else if (data.type === 'connected') {
// Connection established
}
};
ws.onopen = () => {
console.log('Log WebSocket connected for task', primaryTaskId);
};
ws.onerror = (error) => {
console.error('Log WebSocket error:', {
error,
readyState: ws.readyState,
url: ws.url,
taskId: primaryTaskId,
jobId: job.id
});
setStreaming(false);
};
ws.onclose = (event) => {
console.log('Log WebSocket closed:', {
code: event.code,
reason: event.reason,
wasClean: event.wasClean,
taskId: primaryTaskId,
jobId: job.id
});
setStreaming(false);
wsRef.current = null;
// Code 1006 = Abnormal Closure (connection lost without close frame)
// Code 1000 = Normal Closure
// Code 1001 = Going Away (server restart, etc.)
const shouldReconnect = !event.wasClean || event.code === 1006 || event.code === 1001;
// Auto-reconnect if job is still running and close was unexpected
if (shouldReconnect && jobDetails.status === 'running' && taskIds.length > 0) {
console.log(`Attempting to reconnect log WebSocket in 2 seconds... (code: ${event.code})`);
setTimeout(() => {
// Check status again before reconnecting (might have changed)
// The startLogStream function will check if job is still running
if (jobDetails.status === 'running' && taskIds.length > 0) {
startLogStream(taskIds);
}
}, 2000);
}
};
} catch (error) {
console.error('Failed to parse WebSocket message:', error);
}
};
// startLogStream is no longer needed - subscriptions are managed by updateLogSubscriptions
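  // Channel protocol assumed by the handlers above (summarized from this file, not a separate spec):
  //   client -> server: { type: 'subscribe' | 'unsubscribe', channel }
  //     with channel 'job:<jobId>' for job/task/file/step updates
  //     or 'logs:<jobId>:<taskId>' for per-task log lines
  //   server -> client: { type: 'subscribed' | 'subscription_error', channel } to confirm or reject,
  //     then pushes such as { type: 'job_update' | 'task_update' | 'task_added' | 'tasks_added'
  //       | 'file_added' | 'step_update', channel, job_id, task_id?, data }
  //     and { type: 'log', channel, task_id, data } on log channels.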
const toggleTask = async (taskId) => {
const newExpanded = new Set(expandedTasks);
if (newExpanded.has(taskId)) {
@@ -629,10 +699,17 @@ export default function JobDetails({ job, onClose, onUpdate }) {
if (currentTask && !currentTask.created_at) {
// This is a summary - fetch full task details
try {
const signal = abortControllerRef.current?.signal;
const fullTasks = await jobs.getTasks(job.id, {
limit: 1,
signal,
// We can't filter by task ID, so we'll get all and find the one we need
});
// Check if request was aborted
if (signal?.aborted) {
return;
}
const taskData = fullTasks.data || fullTasks;
const fullTask = Array.isArray(taskData) ? taskData.find(t => t.id === taskId) : null;
if (fullTask) {
@@ -834,11 +911,7 @@ export default function JobDetails({ job, onClose, onUpdate }) {
</div>
<div className="p-6 space-y-6">
{loading && (
<div className="flex justify-center py-8">
<div className="animate-spin rounded-full h-8 w-8 border-b-2 border-orange-500"></div>
</div>
)}
{loading && <LoadingSpinner size="sm" className="py-8" />}
{!loading && (
<>
@@ -850,7 +923,7 @@ export default function JobDetails({ job, onClose, onUpdate }) {
<div>
<p className="text-sm text-gray-400">Progress</p>
<p className="font-semibold text-gray-100">
{jobDetails.progress.toFixed(1)}%
{(jobDetails.progress || 0).toFixed(1)}%
</p>
</div>
<div>
@@ -911,12 +984,7 @@ export default function JobDetails({ job, onClose, onUpdate }) {
</div>
)}
{jobDetails.error_message && (
<div className="p-4 bg-red-400/20 border border-red-400/50 rounded-lg text-red-400">
<p className="font-semibold">Error:</p>
<p>{jobDetails.error_message}</p>
</div>
)}
<ErrorMessage error={jobDetails.error_message} />
<div>
<h3 className="text-lg font-semibold text-gray-100 mb-3">

View File

@@ -1,6 +1,8 @@
import { useState, useEffect, useRef } from 'react';
import { jobs } from '../utils/api';
import { jobs, normalizeArrayResponse } from '../utils/api';
import { wsManager } from '../utils/websocket';
import JobDetails from './JobDetails';
import LoadingSpinner from './LoadingSpinner';
export default function JobList() {
const [jobList, setJobList] = useState([]);
@@ -8,140 +10,75 @@ export default function JobList() {
const [selectedJob, setSelectedJob] = useState(null);
const [pagination, setPagination] = useState({ total: 0, limit: 50, offset: 0 });
const [hasMore, setHasMore] = useState(true);
const pollingIntervalRef = useRef(null);
const wsRef = useRef(null);
const listenerIdRef = useRef(null);
useEffect(() => {
loadJobs();
// Use WebSocket for real-time updates instead of polling
connectWebSocket();
return () => {
if (pollingIntervalRef.current) {
clearInterval(pollingIntervalRef.current);
}
if (wsRef.current) {
try {
wsRef.current.close();
} catch (e) {
// Ignore errors when closing
}
wsRef.current = null;
}
};
}, []);
const connectWebSocket = () => {
try {
// Close existing connection if any
if (wsRef.current) {
try {
wsRef.current.close();
} catch (e) {
// Ignore errors when closing
}
wsRef.current = null;
}
const ws = jobs.streamJobsWebSocket();
wsRef.current = ws;
ws.onopen = () => {
console.log('Job list WebSocket connected');
};
ws.onmessage = (event) => {
try {
const data = JSON.parse(event.data);
// Use shared WebSocket manager for real-time updates
listenerIdRef.current = wsManager.subscribe('joblist', {
open: () => {
console.log('JobList: Shared WebSocket connected');
// Load initial job list via HTTP to get current state
loadJobs();
},
message: (data) => {
console.log('JobList: Client WebSocket message received:', data.type, data.channel, data);
// Handle jobs channel messages (always broadcasted)
if (data.channel === 'jobs') {
if (data.type === 'job_update' && data.data) {
console.log('JobList: Updating job:', data.job_id, data.data);
// Update job in list
setJobList(prev => {
const index = prev.findIndex(j => j.id === data.job_id);
const prevArray = Array.isArray(prev) ? prev : [];
const index = prevArray.findIndex(j => j.id === data.job_id);
if (index >= 0) {
const updated = [...prev];
const updated = [...prevArray];
updated[index] = { ...updated[index], ...data.data };
console.log('JobList: Updated job at index', index, updated[index]);
return updated;
}
// If job not in current page, reload to get updated list
if (data.data.status === 'completed' || data.data.status === 'failed') {
loadJobs();
}
return prev;
return prevArray;
});
} else if (data.type === 'job_created' && data.data) {
console.log('JobList: New job created:', data.job_id, data.data);
// New job created - add to list
setJobList(prev => {
const prevArray = Array.isArray(prev) ? prev : [];
// Check if job already exists (avoid duplicates)
if (prevArray.findIndex(j => j.id === data.job_id) >= 0) {
return prevArray;
}
// Add new job at the beginning
return [data.data, ...prevArray];
});
} else if (data.type === 'connected') {
// Connection established
}
} catch (error) {
console.error('Failed to parse WebSocket message:', error);
} else if (data.type === 'connected') {
// Connection established
console.log('JobList: WebSocket connected');
}
};
ws.onerror = (error) => {
console.error('Job list WebSocket error:', {
error,
readyState: ws.readyState,
url: ws.url
});
// WebSocket errors don't provide much detail, but we can check readyState
if (ws.readyState === WebSocket.CLOSED || ws.readyState === WebSocket.CLOSING) {
console.warn('Job list WebSocket is closed or closing, will fallback to polling');
// Fallback to polling on error
startAdaptivePolling();
}
};
ws.onclose = (event) => {
console.log('Job list WebSocket closed:', {
code: event.code,
reason: event.reason,
wasClean: event.wasClean
});
wsRef.current = null;
// Code 1006 = Abnormal Closure (connection lost without close frame)
// Code 1000 = Normal Closure
// Code 1001 = Going Away (server restart, etc.)
// We should reconnect for abnormal closures (1006) or unexpected closes
const shouldReconnect = !event.wasClean || event.code === 1006 || event.code === 1001;
if (shouldReconnect) {
console.log(`Attempting to reconnect job list WebSocket in 2 seconds... (code: ${event.code})`);
setTimeout(() => {
if (wsRef.current === null || (wsRef.current && wsRef.current.readyState === WebSocket.CLOSED)) {
connectWebSocket();
}
}, 2000);
} else {
// Clean close (code 1000) - fallback to polling
console.log('WebSocket closed cleanly, falling back to polling');
startAdaptivePolling();
}
};
} catch (error) {
console.error('Failed to connect WebSocket:', error);
// Fallback to polling
startAdaptivePolling();
}
};
const startAdaptivePolling = () => {
const checkAndPoll = () => {
const hasRunningJobs = jobList.some(job => job.status === 'running' || job.status === 'pending');
const interval = hasRunningJobs ? 5000 : 10000; // 5s for running, 10s for completed
if (pollingIntervalRef.current) {
clearInterval(pollingIntervalRef.current);
},
error: (error) => {
console.error('JobList: Shared WebSocket error:', error);
},
close: (event) => {
console.log('JobList: Shared WebSocket closed:', event);
}
});
pollingIntervalRef.current = setInterval(() => {
loadJobs();
}, interval);
// Ensure connection is established
wsManager.connect();
return () => {
if (listenerIdRef.current) {
wsManager.unsubscribe(listenerIdRef.current);
listenerIdRef.current = null;
}
};
checkAndPoll();
// Re-check interval when job list changes
const checkInterval = setInterval(checkAndPoll, 5000);
return () => clearInterval(checkInterval);
};
}, []);
const loadJobs = async (append = false) => {
try {
@@ -153,20 +90,27 @@ export default function JobList() {
});
// Handle both old format (array) and new format (object with data, total, etc.)
const jobsData = result.data || result;
const total = result.total !== undefined ? result.total : jobsData.length;
const jobsArray = normalizeArrayResponse(result);
const total = result.total !== undefined ? result.total : jobsArray.length;
if (append) {
setJobList(prev => [...prev, ...jobsData]);
setJobList(prev => {
const prevArray = Array.isArray(prev) ? prev : [];
return [...prevArray, ...jobsArray];
});
setPagination(prev => ({ ...prev, offset, total }));
} else {
setJobList(jobsData);
setJobList(jobsArray);
setPagination({ total, limit: result.limit || pagination.limit, offset: result.offset || 0 });
}
setHasMore(offset + jobsData.length < total);
setHasMore(offset + jobsArray.length < total);
} catch (error) {
console.error('Failed to load jobs:', error);
// Ensure jobList is always an array even on error
if (!append) {
setJobList([]);
}
} finally {
setLoading(false);
}
@@ -206,12 +150,21 @@ export default function JobList() {
const handleDelete = async (jobId) => {
if (!confirm('Are you sure you want to permanently delete this job? This action cannot be undone.')) return;
try {
await jobs.delete(jobId);
loadJobs();
// Optimistically update the list
setJobList(prev => {
const prevArray = Array.isArray(prev) ? prev : [];
return prevArray.filter(j => j.id !== jobId);
});
if (selectedJob && selectedJob.id === jobId) {
setSelectedJob(null);
}
// Then actually delete
await jobs.delete(jobId);
// Reload to ensure consistency
loadJobs();
} catch (error) {
// On error, reload to restore correct state
loadJobs();
alert('Failed to delete job: ' + error.message);
}
};
@@ -228,11 +181,7 @@ export default function JobList() {
};
if (loading && jobList.length === 0) {
return (
<div className="flex justify-center items-center h-64">
<div className="animate-spin rounded-full h-12 w-12 border-b-2 border-orange-500"></div>
</div>
);
return <LoadingSpinner size="md" className="h-64" />;
}
if (jobList.length === 0) {

File diff suppressed because it is too large

View File

@@ -0,0 +1,19 @@
import React from 'react';
/**
* Shared LoadingSpinner component with size variants
*/
export default function LoadingSpinner({ size = 'md', className = '', borderColor = 'border-orange-500' }) {
const sizeClasses = {
sm: 'h-8 w-8',
md: 'h-12 w-12',
lg: 'h-16 w-16',
};
return (
<div className={`flex justify-center items-center ${className}`}>
<div className={`animate-spin rounded-full border-b-2 ${borderColor} ${sizeClasses[size]}`}></div>
</div>
);
}
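
A note on the props above, taken from the component itself: className lands on the outer flex wrapper, while the spinner's color comes from borderColor, so a non-default color has to be passed through borderColor. For example:

    <LoadingSpinner size="md" className="h-64" />            // centered inside a fixed-height area
    <LoadingSpinner size="lg" borderColor="border-white" />   // white spinner, e.g. over a dark video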

View File

@@ -1,5 +1,6 @@
import { useState, useEffect } from 'react';
import { auth } from '../utils/api';
import ErrorMessage from './ErrorMessage';
export default function Login() {
const [providers, setProviders] = useState({
@@ -92,11 +93,7 @@ export default function Login() {
</div>
<div className="space-y-4">
{error && (
<div className="p-4 bg-red-400/20 border border-red-400/50 rounded-lg text-red-400 text-sm">
{error}
</div>
)}
<ErrorMessage error={error} className="text-sm" />
{providers.local && (
<div className="pb-4 border-b border-gray-700">
<div className="flex gap-2 mb-4">

View File

@@ -1,5 +1,6 @@
import { useState } from 'react';
import { auth } from '../utils/api';
import ErrorMessage from './ErrorMessage';
import { useAuth } from '../hooks/useAuth';
export default function PasswordChange({ targetUserId = null, targetUserName = null, onSuccess }) {
@@ -64,11 +65,7 @@ export default function PasswordChange({ targetUserId = null, targetUserName = n
{isChangingOtherUser ? `Change Password for ${targetUserName || 'User'}` : 'Change Password'}
</h2>
{error && (
<div className="mb-4 p-3 bg-red-400/20 border border-red-400/50 rounded-lg text-red-400 text-sm">
{error}
</div>
)}
<ErrorMessage error={error} className="mb-4 text-sm" />
{success && (
<div className="mb-4 p-3 bg-green-400/20 border border-green-400/50 rounded-lg text-green-400 text-sm">

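The hunks above (JobDetails, Login, PasswordChange) swap hand-rolled error divs for a shared ErrorMessage component whose own file is not shown in this excerpt. A minimal sketch of what that wrapper plausibly looks like, inferred only from the props used here (error, className) and the styling of the divs it replaces; this is a hypothetical reconstruction, not the committed implementation:

    import React from 'react';

    // Hypothetical reconstruction of the shared ErrorMessage component
    export default function ErrorMessage({ error, className = '' }) {
      if (!error) return null; // render nothing when there is no error to show
      return (
        <div className={`p-4 bg-red-400/20 border border-red-400/50 rounded-lg text-red-400 ${className}`}>
          {error}
        </div>
      );
    }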
View File

@@ -1,22 +1,71 @@
import { useState, useEffect } from 'react';
import { admin } from '../utils/api';
import { useState, useEffect, useRef } from 'react';
import { admin, normalizeArrayResponse } from '../utils/api';
import { wsManager } from '../utils/websocket';
import JobDetails from './JobDetails';
import LoadingSpinner from './LoadingSpinner';
export default function UserJobs({ userId, userName, onBack }) {
const [jobList, setJobList] = useState([]);
const [loading, setLoading] = useState(true);
const [selectedJob, setSelectedJob] = useState(null);
const listenerIdRef = useRef(null);
useEffect(() => {
loadJobs();
const interval = setInterval(loadJobs, 5000);
return () => clearInterval(interval);
// Use shared WebSocket manager for real-time updates instead of polling
listenerIdRef.current = wsManager.subscribe(`userjobs_${userId}`, {
open: () => {
console.log('UserJobs: Shared WebSocket connected');
loadJobs();
},
message: (data) => {
// Handle jobs channel messages (always broadcasted)
if (data.channel === 'jobs') {
if (data.type === 'job_update' && data.data) {
// Update job in list if it belongs to this user
setJobList(prev => {
const prevArray = Array.isArray(prev) ? prev : [];
const index = prevArray.findIndex(j => j.id === data.job_id);
if (index >= 0) {
const updated = [...prevArray];
updated[index] = { ...updated[index], ...data.data };
return updated;
}
// If job not in current list, reload to get updated list
if (data.data.status === 'completed' || data.data.status === 'failed') {
loadJobs();
}
return prevArray;
});
} else if (data.type === 'job_created' && data.data) {
// New job created - reload to check if it belongs to this user
loadJobs();
}
}
},
error: (error) => {
console.error('UserJobs: Shared WebSocket error:', error);
},
close: (event) => {
console.log('UserJobs: Shared WebSocket closed:', event);
}
});
// Ensure connection is established
wsManager.connect();
return () => {
if (listenerIdRef.current) {
wsManager.unsubscribe(listenerIdRef.current);
listenerIdRef.current = null;
}
};
}, [userId]);
const loadJobs = async () => {
try {
const data = await admin.getUserJobs(userId);
setJobList(Array.isArray(data) ? data : []);
setJobList(normalizeArrayResponse(data));
} catch (error) {
console.error('Failed to load jobs:', error);
setJobList([]);
@@ -47,11 +96,7 @@ export default function UserJobs({ userId, userName, onBack }) {
}
if (loading) {
return (
<div className="flex justify-center items-center h-64">
<div className="animate-spin rounded-full h-12 w-12 border-b-2 border-orange-500"></div>
</div>
);
return <LoadingSpinner size="md" className="h-64" />;
}
return (

View File

@@ -1,4 +1,6 @@
import { useState, useRef, useEffect } from 'react';
import ErrorMessage from './ErrorMessage';
import LoadingSpinner from './LoadingSpinner';
export default function VideoPlayer({ videoUrl, onClose }) {
const videoRef = useRef(null);
@@ -55,10 +57,10 @@ export default function VideoPlayer({ videoUrl, onClose }) {
if (error) {
return (
<div className="bg-red-50 border border-red-200 rounded-lg p-4 text-red-700">
{error}
<div className="mt-2 text-sm text-red-600">
<a href={videoUrl} download className="underline">Download video instead</a>
<div>
<ErrorMessage error={error} />
<div className="mt-2 text-sm text-gray-400">
<a href={videoUrl} download className="text-orange-400 hover:text-orange-300 underline">Download video instead</a>
</div>
</div>
);
@@ -68,7 +70,7 @@ export default function VideoPlayer({ videoUrl, onClose }) {
<div className="relative bg-black rounded-lg overflow-hidden">
{loading && (
<div className="absolute inset-0 flex items-center justify-center bg-black bg-opacity-50 z-10">
<div className="animate-spin rounded-full h-12 w-12 border-b-2 border-white"></div>
<LoadingSpinner size="lg" borderColor="border-white" />
</div>
)}
<video

View File

@@ -21,6 +21,12 @@ function getCacheKey(endpoint, options = {}) {
return `${endpoint}${query ? '?' + query : ''}`;
}
// Utility function to normalize array responses (handles both old and new formats)
export function normalizeArrayResponse(response) {
const data = response?.data || response;
return Array.isArray(data) ? data : [];
}
// Sentinel value to indicate a request was superseded (instead of rejecting)
// Export it so components can check for it
export const REQUEST_SUPERSEDED = Symbol('REQUEST_SUPERSEDED');
@@ -36,6 +42,9 @@ function debounceRequest(key, requestFn, delay = DEBOUNCE_DELAY) {
if (pending.timestamp && (now - pending.timestamp) < DEDUPE_WINDOW) {
pending.promise.then(resolve).catch(reject);
return;
} else {
        // Request is older than the dedupe window - remove it and create a new one
pendingRequests.delete(key);
}
}
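
Taken together, the helpers added to this file (normalizeArrayResponse, the REQUEST_SUPERSEDED sentinel, and the options.signal pass-through below) are consumed roughly like this; a sketch assembled from the call sites elsewhere in this commit, not new API surface:

    // sketch: how a component is expected to use these helpers
    import { jobs, normalizeArrayResponse, REQUEST_SUPERSEDED } from '../utils/api';

    async function refreshTasks(jobId, signal) {
      const result = await jobs.getTasksSummary(jobId, { sort: 'frame_start:asc', signal });
      if (signal?.aborted) return [];               // caller unmounted or navigated away
      if (result === REQUEST_SUPERSEDED) return []; // a newer identical request superseded this one
      return normalizeArrayResponse(result);        // tolerates both array and { data, total } shapes
    }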
@@ -74,8 +83,16 @@ export const setAuthErrorHandler = (handler) => {
onAuthError = handler;
};
const handleAuthError = (response) => {
// Whitelist of endpoints that should NOT trigger auth error handling
// These are endpoints that can legitimately return 401/403 without meaning the user is logged out
const AUTH_CHECK_ENDPOINTS = ['/auth/me', '/auth/logout'];
const handleAuthError = (response, endpoint) => {
if (response.status === 401 || response.status === 403) {
// Don't trigger auth error handler for endpoints that check auth status
if (AUTH_CHECK_ENDPOINTS.includes(endpoint)) {
return;
}
// Trigger auth error handler if set (this will clear user state)
if (onAuthError) {
onAuthError();
@@ -89,60 +106,79 @@ const handleAuthError = (response) => {
}
};
// Extract error message from response - centralized to avoid duplication
async function extractErrorMessage(response) {
try {
const errorData = await response.json();
return errorData?.error || response.statusText;
} catch {
return response.statusText;
}
}
export const api = {
async get(endpoint) {
async get(endpoint, options = {}) {
    // options.signal is an AbortSignal supplied by the caller; fall back to a fresh controller's signal
    const signal = options.signal || new AbortController().signal;
const response = await fetch(`${API_BASE}${endpoint}`, {
credentials: 'include', // Include cookies for session
      signal,
});
if (!response.ok) {
// Handle auth errors before parsing response
// Don't redirect on /auth/me - that's the auth check itself
if ((response.status === 401 || response.status === 403) && !endpoint.startsWith('/auth/')) {
handleAuthError(response);
// Don't redirect - let React handle UI change through state
}
const errorData = await response.json().catch(() => null);
const errorMessage = errorData?.error || response.statusText;
handleAuthError(response, endpoint);
const errorMessage = await extractErrorMessage(response);
throw new Error(errorMessage);
}
return response.json();
},
async post(endpoint, data) {
async post(endpoint, data, options = {}) {
    const signal = options.signal || new AbortController().signal;
const response = await fetch(`${API_BASE}${endpoint}`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(data),
credentials: 'include', // Include cookies for session
      signal,
});
if (!response.ok) {
// Handle auth errors before parsing response
// Don't redirect on /auth/* endpoints - those are login/logout
if ((response.status === 401 || response.status === 403) && !endpoint.startsWith('/auth/')) {
handleAuthError(response);
// Don't redirect - let React handle UI change through state
}
const errorData = await response.json().catch(() => null);
const errorMessage = errorData?.error || response.statusText;
handleAuthError(response, endpoint);
const errorMessage = await extractErrorMessage(response);
throw new Error(errorMessage);
}
return response.json();
},
async delete(endpoint) {
async patch(endpoint, data, options = {}) {
    const signal = options.signal || new AbortController().signal;
const response = await fetch(`${API_BASE}${endpoint}`, {
method: 'DELETE',
method: 'PATCH',
headers: { 'Content-Type': 'application/json' },
body: data ? JSON.stringify(data) : undefined,
credentials: 'include', // Include cookies for session
      signal,
});
if (!response.ok) {
// Handle auth errors before parsing response
// Don't redirect on /auth/* endpoints
if ((response.status === 401 || response.status === 403) && !endpoint.startsWith('/auth/')) {
handleAuthError(response);
// Don't redirect - let React handle UI change through state
}
const errorData = await response.json().catch(() => null);
const errorMessage = errorData?.error || response.statusText;
handleAuthError(response, endpoint);
const errorMessage = await extractErrorMessage(response);
throw new Error(errorMessage);
}
return response.json();
},
async delete(endpoint, options = {}) {
    const signal = options.signal || new AbortController().signal;
const response = await fetch(`${API_BASE}${endpoint}`, {
method: 'DELETE',
credentials: 'include', // Include cookies for session
      signal,
});
if (!response.ok) {
// Handle auth errors before parsing response
handleAuthError(response, endpoint);
const errorMessage = await extractErrorMessage(response);
throw new Error(errorMessage);
}
return response.json();
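
Caller-side cancellation with the new options.signal support follows the pattern JobDetails uses above: keep one AbortController per mount and abort it on unmount. A condensed sketch (handleJobs is a placeholder name):

    const controller = new AbortController();
    api.get('/jobs/summary', { signal: controller.signal })
      .then(handleJobs)
      .catch((err) => {
        if (err.name !== 'AbortError') console.error(err); // aborted requests are expected on unmount
      });
    // later, e.g. in a useEffect cleanup:
    controller.abort();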
@@ -179,8 +215,7 @@ export const api = {
} else {
// Handle auth errors
if (xhr.status === 401 || xhr.status === 403) {
handleAuthError({ status: xhr.status });
// Don't redirect - let React handle UI change through state
handleAuthError({ status: xhr.status }, endpoint);
}
try {
const errorData = JSON.parse(xhr.responseText);
@@ -263,7 +298,7 @@ export const jobs = {
if (options.status) params.append('status', options.status);
if (options.sort) params.append('sort', options.sort);
const query = params.toString();
return api.get(`/jobs/summary${query ? '?' + query : ''}`);
return api.get(`/jobs/summary${query ? '?' + query : ''}`, options);
});
},
@@ -286,7 +321,7 @@ export const jobs = {
}
return response.json();
}
return api.get(`/jobs/${id}`);
return api.get(`/jobs/${id}`, options);
});
},
@@ -319,7 +354,7 @@ export const jobs = {
if (options.file_type) params.append('file_type', options.file_type);
if (options.extension) params.append('extension', options.extension);
const query = params.toString();
return api.get(`/jobs/${jobId}/files${query ? '?' + query : ''}`);
return api.get(`/jobs/${jobId}/files${query ? '?' + query : ''}`, options);
});
},
@@ -333,8 +368,8 @@ export const jobs = {
});
},
async getContextArchive(jobId) {
return api.get(`/jobs/${jobId}/context`);
async getContextArchive(jobId, options = {}) {
return api.get(`/jobs/${jobId}/context`, options);
},
downloadFile(jobId, fileId) {
@@ -354,7 +389,7 @@ export const jobs = {
if (options.limit) params.append('limit', options.limit.toString());
if (options.sinceId) params.append('since_id', options.sinceId.toString());
const query = params.toString();
const result = await api.get(`/jobs/${jobId}/tasks/${taskId}/logs${query ? '?' + query : ''}`);
const result = await api.get(`/jobs/${jobId}/tasks/${taskId}/logs${query ? '?' + query : ''}`, options);
// Handle both old format (array) and new format (object with logs, last_id, limit)
if (Array.isArray(result)) {
return { logs: result, last_id: result.length > 0 ? result[result.length - 1].id : 0, limit: options.limit || 100 };
@@ -363,10 +398,21 @@ export const jobs = {
});
},
async getTaskSteps(jobId, taskId) {
return api.get(`/jobs/${jobId}/tasks/${taskId}/steps`);
async getTaskSteps(jobId, taskId, options = {}) {
return api.get(`/jobs/${jobId}/tasks/${taskId}/steps`, options);
},
// New unified client WebSocket - DEPRECATED: Use wsManager from websocket.js instead
// This is kept for backwards compatibility but should not be used
streamClientWebSocket() {
console.warn('streamClientWebSocket() is deprecated - use wsManager from websocket.js instead');
const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const wsHost = window.location.host;
const url = `${wsProtocol}//${wsHost}${API_BASE}/ws`;
return new WebSocket(url);
},
// Old WebSocket methods (to be removed after migration)
streamTaskLogsWebSocket(jobId, taskId, lastId = 0) {
// Convert HTTP to WebSocket URL
const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
@@ -378,7 +424,7 @@ export const jobs = {
streamJobsWebSocket() {
const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const wsHost = window.location.host;
const url = `${wsProtocol}//${wsHost}${API_BASE}/jobs/ws`;
const url = `${wsProtocol}//${wsHost}${API_BASE}/jobs/ws-old`;
return new WebSocket(url);
},
@@ -408,7 +454,7 @@ export const jobs = {
if (options.frameEnd) params.append('frame_end', options.frameEnd.toString());
if (options.sort) params.append('sort', options.sort);
const query = params.toString();
return api.get(`/jobs/${jobId}/tasks${query ? '?' + query : ''}`);
return api.get(`/jobs/${jobId}/tasks${query ? '?' + query : ''}`, options);
});
},
@@ -421,7 +467,7 @@ export const jobs = {
if (options.status) params.append('status', options.status);
if (options.sort) params.append('sort', options.sort);
const query = params.toString();
return api.get(`/jobs/${jobId}/tasks/summary${query ? '?' + query : ''}`);
return api.get(`/jobs/${jobId}/tasks/summary${query ? '?' + query : ''}`, options);
});
},

177
web/src/utils/websocket.js Normal file
View File

@@ -0,0 +1,177 @@
// Shared WebSocket connection manager
// All components should use this instead of creating their own connections
class WebSocketManager {
constructor() {
this.ws = null;
this.listeners = new Map(); // Map of listener IDs to callback functions
this.reconnectTimeout = null;
this.reconnectDelay = 2000;
this.isConnecting = false;
this.listenerIdCounter = 0;
this.verboseLogging = false; // Set to true to enable verbose WebSocket logging
}
connect() {
// If already connected or connecting, don't create a new connection
if (this.ws && (this.ws.readyState === WebSocket.CONNECTING || this.ws.readyState === WebSocket.OPEN)) {
return;
}
if (this.isConnecting) {
return;
}
this.isConnecting = true;
try {
const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const wsHost = window.location.host;
const API_BASE = '/api';
const url = `${wsProtocol}//${wsHost}${API_BASE}/jobs/ws`;
this.ws = new WebSocket(url);
this.ws.onopen = () => {
if (this.verboseLogging) {
console.log('Shared WebSocket connected');
}
this.isConnecting = false;
this.notifyListeners('open', {});
};
this.ws.onmessage = (event) => {
try {
const data = JSON.parse(event.data);
if (this.verboseLogging) {
console.log('WebSocketManager: Message received:', data.type, data.channel || 'no channel', data);
}
this.notifyListeners('message', data);
} catch (error) {
console.error('WebSocketManager: Failed to parse message:', error, 'Raw data:', event.data);
}
};
this.ws.onerror = (error) => {
console.error('Shared WebSocket error:', error);
this.isConnecting = false;
this.notifyListeners('error', error);
};
this.ws.onclose = (event) => {
if (this.verboseLogging) {
console.log('Shared WebSocket closed:', {
code: event.code,
reason: event.reason,
wasClean: event.wasClean
});
}
this.ws = null;
this.isConnecting = false;
this.notifyListeners('close', event);
// Always retry connection
if (this.reconnectTimeout) {
clearTimeout(this.reconnectTimeout);
}
this.reconnectTimeout = setTimeout(() => {
if (!this.ws || this.ws.readyState === WebSocket.CLOSED) {
this.connect();
}
}, this.reconnectDelay);
};
} catch (error) {
console.error('Failed to create WebSocket:', error);
this.isConnecting = false;
// Retry after delay
this.reconnectTimeout = setTimeout(() => {
this.connect();
}, this.reconnectDelay);
}
}
subscribe(listenerId, callbacks) {
// Generate ID if not provided
if (!listenerId) {
listenerId = `listener_${this.listenerIdCounter++}`;
}
if (this.verboseLogging) {
console.log('WebSocketManager: Subscribing listener:', listenerId, 'WebSocket state:', this.ws ? this.ws.readyState : 'no connection');
}
this.listeners.set(listenerId, callbacks);
// Connect if not already connected
if (!this.ws || this.ws.readyState === WebSocket.CLOSED) {
if (this.verboseLogging) {
console.log('WebSocketManager: WebSocket not connected, connecting...');
}
this.connect();
}
// If already open, notify immediately
if (this.ws && this.ws.readyState === WebSocket.OPEN && callbacks.open) {
if (this.verboseLogging) {
console.log('WebSocketManager: WebSocket already open, calling open callback for listener:', listenerId);
}
// Use setTimeout to ensure this happens after the listener is registered
setTimeout(() => {
if (callbacks.open) {
callbacks.open();
}
}, 0);
}
return listenerId;
}
unsubscribe(listenerId) {
this.listeners.delete(listenerId);
// If no more listeners, we could close the connection, but let's keep it open
// in case other components need it
}
send(data) {
if (this.ws && this.ws.readyState === WebSocket.OPEN) {
if (this.verboseLogging) {
console.log('WebSocketManager: Sending message:', data);
}
this.ws.send(JSON.stringify(data));
} else {
console.warn('WebSocketManager: Cannot send message - connection not open. State:', this.ws ? this.ws.readyState : 'no connection', 'Message:', data);
}
}
notifyListeners(eventType, data) {
this.listeners.forEach((callbacks) => {
if (callbacks[eventType]) {
try {
callbacks[eventType](data);
} catch (error) {
console.error('Error in WebSocket listener:', error);
}
}
});
}
getReadyState() {
return this.ws ? this.ws.readyState : WebSocket.CLOSED;
}
disconnect() {
if (this.reconnectTimeout) {
clearTimeout(this.reconnectTimeout);
this.reconnectTimeout = null;
}
if (this.ws) {
this.ws.close();
this.ws = null;
}
this.listeners.clear();
}
}
// Export singleton instance
export const wsManager = new WebSocketManager();
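
Typical consumption of the singleton, mirroring the JobList and UserJobs components earlier in this commit; jobId here is a placeholder for whatever job the component is showing:

    // sketch: attaching a component to the shared connection
    const listenerId = wsManager.subscribe('myComponent', {
      open: () => {
        // per-job and per-task log channels must be subscribed explicitly;
        // the 'jobs' channel is broadcast to every client without a subscribe
        wsManager.send({ type: 'subscribe', channel: `job:${jobId}` });
      },
      message: (data) => {
        if (data.channel === 'jobs' && data.type === 'job_update') {
          // apply data.data to local state here
        }
      },
    });
    wsManager.connect(); // no-op if a connection is already open or connecting

    // on unmount:
    wsManager.unsubscribe(listenerId);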