13 Commits

Author SHA1 Message Date
0a8f40b9cb Add MIT License file to the project
All checks were successful
Release Tag / release (push) Successful in 47s
- Introduced LICENSE file with the MIT License terms.
- Ensured compliance with open-source distribution and usage rights.
2026-01-02 14:55:53 -06:00
7440511740 Add GoReleaser configuration and update Makefile for streamlined builds
Some checks failed
Release Tag / release (push) Failing after 46s
- Introduced .goreleaser.yaml for automated release management.
- Updated Makefile to utilize GoReleaser for building the jiggablend binary.
- Added new workflows for release tagging and pull request checks in Gitea.
- Updated dependencies in go.mod and go.sum, including new packages for versioning.
- Enhanced .gitignore to exclude build artifacts in the dist directory.
2026-01-02 14:28:03 -06:00
c7c8762164 Update README.md to reflect significant changes in architecture and features. Replace DuckDB with SQLite for the database, enhance authentication options, and introduce a modern React-based web UI. Expand job management capabilities, including video encoding support and metadata extraction. Revise installation and configuration instructions, and clarify output formats and storage structure. Improve development guidelines for building and testing the application. 2026-01-02 14:03:03 -06:00
94490237fe Update .gitignore to include log files and database journal files. Modify go.mod to update dependencies for go-sqlite3 and cloud.google.com/go/compute/metadata. Enhance Makefile to include logging options for manager and runner commands. Introduce new job token handling in auth package and implement database migration scripts. Refactor manager and runner components to improve job processing and metadata extraction. Add support for video preview in frontend components and enhance WebSocket management for channel subscriptions. 2026-01-02 13:55:19 -06:00
edc8ea160c something 2025-11-27 00:46:48 -06:00
11e7552b5b Refactor API key handling in runners.go to streamline authorization logic. Remove redundant checks for "Bearer " prefix in API key extraction, enhancing code clarity and maintainability. 2025-11-25 08:23:25 -06:00
690e6b13f8 its a bit broken 2025-11-25 03:48:28 -06:00
a53ea4dce7 Refactor runner and API components to remove IP address handling. Update client and server logic to streamline runner registration and task distribution. Introduce write mutexes for connection management to enhance concurrency control. Clean up whitespace and improve code readability across multiple files. 2025-11-24 22:58:56 -06:00
3217bbfe4d Refactor error handling and improve code formatting in runners.go. Replace fmt.Errorf with errors.New for better error management. Clean up whitespace and enhance readability in various API response structures. 2025-11-24 21:49:17 -06:00
4ac05d50a1 Enhance logging and context handling in job management. Introduce a logger initialization with configurable parameters in the manager and runner commands. Update job context handling to use tar files instead of tar.gz, and implement ETag generation for improved caching. Refactor API endpoints to support new context file structure and enhance error handling in job submissions. Add support for unhide objects and auto-execution options in job creation requests. 2025-11-24 21:48:05 -06:00
a029714e08 Implement context archive handling and metadata extraction for render jobs. Add functionality to check for Blender availability, create context archives, and extract metadata from .blend files. Update job creation and retrieval processes to support new metadata structure and context file management. Enhance client-side components to display context files and integrate new API endpoints for context handling. 2025-11-24 10:02:13 -06:00
f9ff4d0138 Enhance server configuration for large file uploads and improve token handling. Increase request body size limit in the server to 20 GB, update registration token expiration logic to support infinite expiration, and adjust database schema to accommodate larger file sizes. Add detailed logging for file upload processes and error handling improvements. 2025-11-23 16:59:36 -06:00
f7e1766d8b Update go.mod to include golang.org/x/crypto v0.45.0 and remove indirect reference. Refactor task handling in client.go to use switch-case for task types and remove unused VAAPI device functions for cleaner code. 2025-11-23 11:03:54 -06:00
78 changed files with 24940 additions and 9515 deletions

View File

@@ -0,0 +1,24 @@
name: Release Tag
on:
push:
tags:
- '*'
jobs:
release:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@main
with:
fetch-depth: 0
- run: git fetch --force --tags
- uses: actions/setup-go@main
with:
go-version-file: 'go.mod'
- uses: goreleaser/goreleaser-action@master
with:
distribution: goreleaser
version: 'latest'
args: release
env:
GITEA_TOKEN: ${{secrets.RELEASE_TOKEN}}

View File

@@ -0,0 +1,15 @@
name: PR Check
on:
- pull_request
jobs:
check-and-test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@main
- uses: actions/setup-go@main
with:
go-version-file: 'go.mod'
- run: go mod tidy
- run: go build ./...
- run: go test -race -v -shuffle=on ./...

6
.gitignore vendored
View File

@@ -27,7 +27,11 @@ go.work
jiggablend.db
jiggablend.db.wal
jiggablend.db-shm
jiggablend.db-journal
# Log files
*.log
logs/
# Secrets and configuration
runner-secrets.json
runner-secrets-*.json
@@ -61,6 +65,7 @@ lerna-debug.log*
*.o
*.a
*.so
/dist/
# Temporary files
*.tmp
@@ -69,6 +74,7 @@ lerna-debug.log*
# Logs
*.log
/logs/
# OS files
Thumbs.db

48
.goreleaser.yaml Normal file
View File

@@ -0,0 +1,48 @@
version: 2
before:
hooks:
- go mod tidy -v
- sh -c "cd web && npm install && npm run build"
builds:
- id: default
main: ./cmd/jiggablend
binary: jiggablend
ldflags:
- -X jiggablend/version.Version={{.Version}}
- -X jiggablend/version.Date={{.Date}}
env:
- CGO_ENABLED=1
goos:
- linux
goarch:
- amd64
checksum:
name_template: "checksums.txt"
archives:
- id: default
name_template: "{{ .ProjectName }}-{{ .Os }}-{{ .Arch }}"
formats: tar.gz
format_overrides:
- goos: windows
formats: zip
files:
- README.md
- LICENSE
changelog:
sort: asc
filters:
exclude:
- "^docs:"
- "^test:"
release:
name_template: "{{ .ProjectName }}-{{ .Version }}"
gitea_urls:
api: https://git.s1d3sw1ped.com/api/v1
download: https://git.s1d3sw1ped.com

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright © 2026 s1d3sw1ped
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

102
Makefile
View File

@@ -1,58 +1,61 @@
.PHONY: build build-manager build-runner build-web run-manager run-runner run cleanup cleanup-manager cleanup-runner clean test
.PHONY: build build-web run run-manager run-runner cleanup cleanup-manager cleanup-runner clean-bin clean-web test help install
# Build all
build: clean-bin build-manager build-runner
# Build manager
build-manager: clean-bin build-web
go build -o bin/manager ./cmd/manager
# Build runner
build-runner: clean-bin
GOOS=linux GOARCH=amd64 go build -o bin/runner ./cmd/runner
# Build the jiggablend binary (includes embedded web UI)
build:
@echo "Building with GoReleaser..."
goreleaser build --clean --snapshot --single-target
@mkdir -p bin
@find dist -name jiggablend -type f -exec cp {} bin/jiggablend \;
# Build web UI
build-web: clean-web
cd web && npm install && npm run build
# Cleanup manager (database and storage)
# Cleanup manager logs
cleanup-manager:
@echo "Cleaning up manager database and storage..."
@rm -f jiggablend.db 2>/dev/null || true
@rm -f jiggablend.db-shm 2>/dev/null || true
@rm -f jiggablend.db-wal 2>/dev/null || true
@rm -rf jiggablend-storage 2>/dev/null || true
@echo "Cleaning up manager logs..."
@rm -rf logs/manager.log 2>/dev/null || true
@echo "Manager cleanup complete"
# Cleanup runner (workspaces and secrets)
# Cleanup runner logs
cleanup-runner:
@echo "Cleaning up runner workspaces and secrets..."
@rm -rf jiggablend-workspaces jiggablend-workspace* *workspace* runner-secrets*.json 2>/dev/null || true
@echo "Cleaning up runner logs..."
@rm -rf logs/runner*.log 2>/dev/null || true
@echo "Runner cleanup complete"
# Cleanup both manager and runner
# Cleanup both manager and runner logs
cleanup: cleanup-manager cleanup-runner
# Run all parallel
run: cleanup-manager cleanup-runner build-manager build-runner
# Run manager and runner in parallel (for testing)
run: cleanup build init-test
@echo "Starting manager and runner in parallel..."
@echo "Press Ctrl+C to stop both..."
@trap 'kill $$MANAGER_PID $$RUNNER_PID 2>/dev/null; exit' INT TERM; \
FIXED_REGISTRATION_TOKEN=test-token ENABLE_LOCAL_AUTH=true LOCAL_TEST_EMAIL=test@example.com LOCAL_TEST_PASSWORD=testpassword bin/manager & \
bin/jiggablend manager -l manager.log & \
MANAGER_PID=$$!; \
REGISTRATION_TOKEN=test-token bin/runner & \
sleep 2; \
bin/jiggablend runner -l runner.log --api-key=jk_r0_test_key_123456789012345678901234567890 & \
RUNNER_PID=$$!; \
wait $$MANAGER_PID $$RUNNER_PID
# Run manager
# Note: ENABLE_LOCAL_AUTH enables local user registration/login
# LOCAL_TEST_EMAIL and LOCAL_TEST_PASSWORD create a test user on startup (if it doesn't exist)
run-manager: cleanup-manager build-manager
FIXED_REGISTRATION_TOKEN=test-token ENABLE_LOCAL_AUTH=true LOCAL_TEST_EMAIL=test@example.com LOCAL_TEST_PASSWORD=testpassword bin/manager
# Run manager server
run-manager: cleanup-manager build init-test
bin/jiggablend manager -l manager.log
# Run runner
run-runner: cleanup-runner build-runner
REGISTRATION_TOKEN=test-token bin/runner
run-runner: cleanup-runner build
bin/jiggablend runner -l runner.log --api-key=jk_r0_test_key_123456789012345678901234567890
# Initialize for testing (first run setup)
init-test: build
@echo "Initializing test configuration..."
bin/jiggablend manager config enable localauth
bin/jiggablend manager config set fixed-apikey jk_r0_test_key_123456789012345678901234567890 -f -y
bin/jiggablend manager config add user test@example.com testpassword --admin -f -y
@echo "Test configuration complete!"
@echo "fixed api key: jk_r0_test_key_123456789012345678901234567890"
@echo "test user: test@example.com"
@echo "test password: testpassword"
# Clean bin build artifacts
clean-bin:
@@ -66,3 +69,38 @@ clean-web:
test:
go test ./... -timeout 30s
# Show help
help:
@echo "Jiggablend Build and Run Makefile"
@echo ""
@echo "Build targets:"
@echo " build - Build jiggablend binary with embedded web UI"
@echo " build-web - Build web UI only"
@echo ""
@echo "Run targets:"
@echo " run - Run manager and runner in parallel (for testing)"
@echo " run-manager - Run manager server"
@echo " run-runner - Run runner with test API key"
@echo " init-test - Initialize test configuration (run once)"
@echo ""
@echo "Cleanup targets:"
@echo " cleanup - Clean all logs"
@echo " cleanup-manager - Clean manager logs"
@echo " cleanup-runner - Clean runner logs"
@echo ""
@echo "Other targets:"
@echo " clean-bin - Clean build artifacts"
@echo " clean-web - Clean web build artifacts"
@echo " test - Run Go tests"
@echo " help - Show this help"
@echo ""
@echo "CLI Usage:"
@echo " jiggablend manager serve - Start the manager server"
@echo " jiggablend runner - Start a runner"
@echo " jiggablend manager config show - Show configuration"
@echo " jiggablend manager config enable localauth"
@echo " jiggablend manager config add user --email=x --password=y"
@echo " jiggablend manager config add apikey --name=mykey"
@echo " jiggablend manager config set fixed-apikey <key>"
@echo " jiggablend manager config list users"
@echo " jiggablend manager config list apikeys"

266
README.md
View File

@@ -4,28 +4,41 @@ A distributed Blender render farm system built with Go. The system consists of a
## Architecture
- **Manager**: Central server with REST API, web UI, DuckDB database, and local file storage
- **Manager**: Central server with REST API, embedded web UI, SQLite database, and local file storage
- **Runner**: Linux amd64 client that connects to manager, receives jobs, executes Blender renders, and reports back
Both manager and runner are part of a single binary (`jiggablend`) with subcommands.
## Features
- OAuth authentication (Google and Discord)
- Web-based job submission and monitoring
- Distributed rendering across multiple runners
- Real-time job progress tracking
- File upload/download for Blender files and rendered outputs
- Runner health monitoring
- **Authentication**: OAuth (Google and Discord) and local authentication with user management
- **Web UI**: Modern React-based interface for job submission and monitoring
- **Distributed Rendering**: Scale across multiple runners with automatic job distribution
- **Real-time Updates**: WebSocket-based progress tracking and job status updates
- **Video Encoding**: Automatic video encoding from EXR/PNG sequences with multiple codec support:
- H.264 (MP4) - SDR and HDR support
- AV1 (MP4) - With alpha channel support
- VP9 (WebM) - With alpha channel and HDR support
- **Output Formats**: PNG, JPEG, EXR, and video formats (MP4, WebM)
- **Blender Version Management**: Support for multiple Blender versions with automatic detection
- **Metadata Extraction**: Automatic extraction of scene metadata from Blender files
- **Admin Panel**: User and runner management interface
- **Runner Management**: API key-based authentication for runners with health monitoring
- **HDR Support**: Preserve HDR range in video encoding with HLG transfer function
- **Alpha Channel**: Preserve alpha channel in video encoding (AV1 and VP9)
## Prerequisites
### Manager
- Go 1.21 or later
- DuckDB (via Go driver)
- Go 1.25.4 or later
- SQLite (via Go driver)
- Blender installed and in PATH (for metadata extraction)
- ImageMagick installed (for EXR preview conversion)
### Runner
- Linux amd64
- Blender installed and in PATH
- FFmpeg installed (optional, for video processing)
- Blender installed (can use bundled versions from storage)
- FFmpeg installed (required for video encoding)
## Installation
@@ -44,44 +57,87 @@ go mod download
### Manager
Set the following environment variables for authentication (optional):
Configuration is managed through the CLI using `jiggablend manager config` commands. The configuration is stored in the SQLite database.
#### Initial Setup
For testing, use the Makefile helper:
```bash
make init-test
```
This will:
- Enable local authentication
- Set a fixed API key for testing
- Create a test admin user (test@example.com / testpassword)
#### Manual Configuration
```bash
# OAuth Providers (optional)
export GOOGLE_CLIENT_ID="your-google-client-id"
export GOOGLE_CLIENT_SECRET="your-google-client-secret"
export GOOGLE_REDIRECT_URL="http://localhost:8080/api/auth/google/callback"
# Enable local authentication
jiggablend manager config enable localauth
export DISCORD_CLIENT_ID="your-discord-client-id"
export DISCORD_CLIENT_SECRET="your-discord-client-secret"
export DISCORD_REDIRECT_URL="http://localhost:8080/api/auth/discord/callback"
# Add a user
jiggablend manager config add user <email> <password> --admin
# Local Authentication (optional)
export ENABLE_LOCAL_AUTH="true"
# Generate an API key for runners
jiggablend manager config add apikey <name> --scope manager
# Test User (optional, for testing only)
# Creates a local user on startup if it doesn't exist
export LOCAL_TEST_EMAIL="test@example.com"
export LOCAL_TEST_PASSWORD="testpassword"
# Set OAuth credentials
jiggablend manager config set google-oauth <client-id> <client-secret> --redirect-url <url>
jiggablend manager config set discord-oauth <client-id> <client-secret> --redirect-url <url>
# View current configuration
jiggablend manager config show
# List users and API keys
jiggablend manager config list users
jiggablend manager config list apikeys
```
#### Environment Variables
You can also use environment variables with the `JIGGABLEND_` prefix:
- `JIGGABLEND_PORT` - Server port (default: 8080)
- `JIGGABLEND_DB` - Database path (default: jiggablend.db)
- `JIGGABLEND_STORAGE` - Storage path (default: ./jiggablend-storage)
- `JIGGABLEND_LOG_FILE` - Log file path
- `JIGGABLEND_LOG_LEVEL` - Log level (debug, info, warn, error)
- `JIGGABLEND_VERBOSE` - Enable verbose logging
### Runner
No configuration required. Runner will auto-detect hostname and IP.
The runner requires an API key to connect to the manager. The runner will auto-detect hostname and IP.
## Usage
### Building
```bash
# Build the unified binary (includes embedded web UI)
make build
# Or build directly
go build -o bin/jiggablend ./cmd/jiggablend
# Build web UI separately
make build-web
```
### Running the Manager
```bash
# Using make
# Using make (includes test setup)
make run-manager
# Or directly
go run ./cmd/manager
bin/jiggablend manager
# With custom options
go run ./cmd/manager -port 8080 -db jiggablend.db -storage ./storage
bin/jiggablend manager --port 8080 --db jiggablend.db --storage ./jiggablend-storage --log-file manager.log
# Using environment variables
JIGGABLEND_PORT=8080 JIGGABLEND_DB=jiggablend.db bin/jiggablend manager
```
The manager will start on `http://localhost:8080` by default.
@@ -89,26 +145,28 @@ The manager will start on `http://localhost:8080` by default.
### Running a Runner
```bash
# Using make
# Using make (uses test API key)
make run-runner
# Or directly
go run ./cmd/runner
# Or directly (requires API key)
bin/jiggablend runner --api-key <your-api-key>
# With custom options
go run ./cmd/runner -manager http://localhost:8080 -name my-runner
bin/jiggablend runner --manager http://localhost:8080 --name my-runner --api-key <key> --log-file runner.log
# Using environment variables
JIGGABLEND_MANAGER=http://localhost:8080 JIGGABLEND_API_KEY=<key> bin/jiggablend runner
```
### Building
### Running Both (for Testing)
```bash
# Build manager
make build-manager
# Build runner (Linux amd64)
make build-runner
# Run manager and runner in parallel
make run
```
This will start both the manager and a test runner with a fixed API key.
## OAuth Setup
### Google OAuth
@@ -118,7 +176,10 @@ make build-runner
3. Enable Google+ API
4. Create OAuth 2.0 credentials
5. Add authorized redirect URI: `http://localhost:8080/api/auth/google/callback`
6. Set environment variables with Client ID and Secret
6. Configure using CLI:
```bash
jiggablend manager config set google-oauth <client-id> <client-secret> --redirect-url http://localhost:8080/api/auth/google/callback
```
### Discord OAuth
@@ -126,25 +187,39 @@ make build-runner
2. Create a new application
3. Go to OAuth2 section
4. Add redirect URI: `http://localhost:8080/api/auth/discord/callback`
5. Set environment variables with Client ID and Secret
5. Configure using CLI:
```bash
jiggablend manager config set discord-oauth <client-id> <client-secret> --redirect-url http://localhost:8080/api/auth/discord/callback
```
## Project Structure
```
jiggablend/
├── cmd/
── manager/ # Manager server application
└── runner/ # Runner client application
── jiggablend/ # Unified CLI application
├── cmd/ # Cobra command definitions
│ └── main.go # Entry point
├── internal/
│ ├── api/ # REST API handlers
│ ├── auth/ # OAuth authentication
│ ├── database/ # DuckDB database models and migrations
│ ├── queue/ # Job queue management
│ ├── storage/ # File storage operations
── runner/ # Runner management logic
│ ├── auth/ # Authentication (OAuth, local, sessions)
│ ├── config/ # Configuration management
│ ├── database/ # SQLite database models and migrations
│ ├── logger/ # Logging utilities
│ ├── manager/ # Manager server logic
── runner/ # Runner client logic
│ │ ├── api/ # Manager API client
│ │ ├── blender/ # Blender version detection
│ │ ├── encoding/ # Video encoding (H.264, AV1, VP9)
│ │ ├── tasks/ # Task execution (render, encode, process)
│ │ └── workspace/ # Workspace management
│ └── storage/ # File storage operations
├── pkg/
── types/ # Shared types and models
├── web/ # Static web UI files
── executils/ # Execution utilities
│ ├── scripts/ # Python scripts for Blender
│ └── types/ # Shared types and models
├── web/ # React web UI
│ ├── src/ # Source files
│ └── dist/ # Built files (embedded in binary)
├── go.mod
└── Makefile
```
@@ -156,27 +231,106 @@ jiggablend/
- `GET /api/auth/google/callback` - Google OAuth callback
- `GET /api/auth/discord/login` - Initiate Discord OAuth
- `GET /api/auth/discord/callback` - Discord OAuth callback
- `POST /api/auth/login` - Local authentication login
- `POST /api/auth/register` - User registration (if enabled)
- `POST /api/auth/logout` - Logout
- `GET /api/auth/me` - Get current user
- `POST /api/auth/password/change` - Change password
### Jobs
- `POST /api/jobs` - Create a new job
- `GET /api/jobs` - List user's jobs
- `GET /api/jobs/{id}` - Get job details
- `DELETE /api/jobs/{id}` - Cancel a job
- `POST /api/jobs/{id}/upload` - Upload job file
- `POST /api/jobs/{id}/upload` - Upload job file (Blender file)
- `GET /api/jobs/{id}/files` - List job files
- `GET /api/jobs/{id}/files/{fileId}/download` - Download job file
- `GET /api/jobs/{id}/metadata` - Extract metadata from uploaded file
- `GET /api/jobs/{id}/outputs` - List job output files
### Runners
- `GET /api/admin/runners` - List all runners (admin only)
- `POST /api/runner/register` - Register a runner (uses registration token)
- `POST /api/runner/heartbeat` - Update runner heartbeat (runner authenticated)
### Blender
- `GET /api/blender/versions` - List available Blender versions
### Runners (Internal API)
- `POST /api/runner/register` - Register a runner (uses API key)
- `POST /api/runner/heartbeat` - Update runner heartbeat
- `GET /api/runner/tasks` - Get pending tasks for runner
- `POST /api/runner/tasks/{id}/complete` - Mark task as complete
- `GET /api/runner/files/{jobId}/{fileName}` - Download file for runner
- `POST /api/runner/files/{jobId}/upload` - Upload file from runner
### Admin (Admin Only)
- `GET /api/admin/runners` - List all runners
- `GET /api/admin/jobs` - List all jobs
- `GET /api/admin/users` - List all users
- `GET /api/admin/stats` - System statistics
### WebSocket
- `WS /api/ws` - WebSocket connection for real-time updates
- Subscribe to job channels: `job:{jobId}`
- Receive job status updates, progress, and logs
## Output Formats
The system supports the following output formats:
### Image Formats
- **PNG** - Standard PNG output
- **JPEG** - JPEG output
- **EXR** - OpenEXR format (HDR)
### Video Formats
- **EXR_264_MP4** - H.264 encoded MP4 from EXR sequence (SDR or HDR)
- **EXR_AV1_MP4** - AV1 encoded MP4 from EXR sequence (with alpha channel support)
- **EXR_VP9_WEBM** - VP9 encoded WebM from EXR sequence (with alpha channel and HDR support)
Video encoding features:
- 2-pass encoding for optimal quality
- HDR preservation using HLG transfer function
- Alpha channel preservation (AV1 and VP9 only)
- Automatic detection of source format (EXR or PNG)
- Software encoding (libx264, libaom-av1, libvpx-vp9)
## Storage Structure
The manager uses a local storage directory (default: `./jiggablend-storage`) with the following structure:
```
jiggablend-storage/
├── blender-versions/ # Bundled Blender versions
│ └── <version>/
├── jobs/ # Job context files
│ └── <job-id>/
│ └── context.tar
├── outputs/ # Rendered outputs
│ └── <job-id>/
├── temp/ # Temporary files
└── uploads/ # Uploaded files
```
## Development
### Running Tests
```bash
make test
# Or directly
go test ./... -timeout 30s
```
### Web UI Development
The web UI is built with React and Vite. To develop the UI:
```bash
cd web
npm install
npm run dev # Development server
npm run build # Build for production
```
The built files are embedded in the Go binary using `embed.FS`.
## License
MIT

View File

@@ -0,0 +1,177 @@
package cmd
import (
"fmt"
"net/http"
"os/exec"
"strings"
"jiggablend/internal/auth"
"jiggablend/internal/config"
"jiggablend/internal/database"
"jiggablend/internal/logger"
manager "jiggablend/internal/manager"
"jiggablend/internal/storage"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
var managerCmd = &cobra.Command{
Use: "manager",
Short: "Start the Jiggablend manager server",
Long: `Start the Jiggablend manager server to coordinate render jobs.`,
Run: runManager,
}
func init() {
rootCmd.AddCommand(managerCmd)
// Flags with env binding via viper
managerCmd.Flags().StringP("port", "p", "8080", "Server port")
managerCmd.Flags().String("db", "jiggablend.db", "Database path")
managerCmd.Flags().String("storage", "./jiggablend-storage", "Storage path")
managerCmd.Flags().StringP("log-file", "l", "", "Log file path (truncated on start, if not set logs only to stdout)")
managerCmd.Flags().String("log-level", "info", "Log level (debug, info, warn, error)")
managerCmd.Flags().BoolP("verbose", "v", false, "Enable verbose logging (same as --log-level=debug)")
// Bind flags to viper with JIGGABLEND_ prefix
viper.SetEnvPrefix("JIGGABLEND")
viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
viper.AutomaticEnv()
viper.BindPFlag("port", managerCmd.Flags().Lookup("port"))
viper.BindPFlag("db", managerCmd.Flags().Lookup("db"))
viper.BindPFlag("storage", managerCmd.Flags().Lookup("storage"))
viper.BindPFlag("log_file", managerCmd.Flags().Lookup("log-file"))
viper.BindPFlag("log_level", managerCmd.Flags().Lookup("log-level"))
viper.BindPFlag("verbose", managerCmd.Flags().Lookup("verbose"))
}
func runManager(cmd *cobra.Command, args []string) {
// Get config values (flags take precedence over env vars)
port := viper.GetString("port")
dbPath := viper.GetString("db")
storagePath := viper.GetString("storage")
logFile := viper.GetString("log_file")
logLevel := viper.GetString("log_level")
verbose := viper.GetBool("verbose")
// Initialize logger
if logFile != "" {
if err := logger.InitWithFile(logFile); err != nil {
logger.Fatalf("Failed to initialize logger: %v", err)
}
defer func() {
if l := logger.GetDefault(); l != nil {
l.Close()
}
}()
} else {
logger.InitStdout()
}
// Set log level
if verbose {
logger.SetLevel(logger.LevelDebug)
} else {
logger.SetLevel(logger.ParseLevel(logLevel))
}
if logFile != "" {
logger.Infof("Logging to file: %s", logFile)
}
logger.Debugf("Log level: %s", logLevel)
// Initialize database
db, err := database.NewDB(dbPath)
if err != nil {
logger.Fatalf("Failed to initialize database: %v", err)
}
defer db.Close()
// Initialize config from database
cfg := config.NewConfig(db)
if err := cfg.InitializeFromEnv(); err != nil {
logger.Fatalf("Failed to initialize config: %v", err)
}
logger.Info("Configuration loaded from database")
// Initialize auth
authHandler, err := auth.NewAuth(db, cfg)
if err != nil {
logger.Fatalf("Failed to initialize auth: %v", err)
}
// Initialize storage
storageHandler, err := storage.NewStorage(storagePath)
if err != nil {
logger.Fatalf("Failed to initialize storage: %v", err)
}
// Check if Blender is available
if err := checkBlenderAvailable(); err != nil {
logger.Fatalf("Blender is not available: %v\n"+
"The manager requires Blender to be installed and in PATH for metadata extraction.\n"+
"Please install Blender and ensure it's accessible via the 'blender' command.", err)
}
logger.Info("Blender is available")
// Check if ImageMagick is available
if err := checkImageMagickAvailable(); err != nil {
logger.Fatalf("ImageMagick is not available: %v\n"+
"The manager requires ImageMagick to be installed and in PATH for EXR preview conversion.\n"+
"Please install ImageMagick and ensure 'magick' or 'convert' command is accessible.", err)
}
logger.Info("ImageMagick is available")
// Create manager server
server, err := manager.NewManager(db, cfg, authHandler, storageHandler)
if err != nil {
logger.Fatalf("Failed to create server: %v", err)
}
// Start server
addr := fmt.Sprintf(":%s", port)
logger.Infof("Starting manager server on %s", addr)
logger.Infof("Database: %s", dbPath)
logger.Infof("Storage: %s", storagePath)
httpServer := &http.Server{
Addr: addr,
Handler: server,
MaxHeaderBytes: 1 << 20,
ReadTimeout: 0,
WriteTimeout: 0,
}
if err := httpServer.ListenAndServe(); err != nil {
logger.Fatalf("Server failed: %v", err)
}
}
func checkBlenderAvailable() error {
cmd := exec.Command("blender", "--version")
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("failed to run 'blender --version': %w (output: %s)", err, string(output))
}
return nil
}
func checkImageMagickAvailable() error {
// Try 'magick' first (ImageMagick 7+)
cmd := exec.Command("magick", "--version")
output, err := cmd.CombinedOutput()
if err == nil {
return nil
}
// Fall back to 'convert' (ImageMagick 6 or legacy mode)
cmd = exec.Command("convert", "--version")
output, err = cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("failed to run 'magick --version' or 'convert --version': %w (output: %s)", err, string(output))
}
return nil
}

View File

@@ -0,0 +1,621 @@
package cmd
import (
"bufio"
"crypto/rand"
"crypto/sha256"
"database/sql"
"encoding/hex"
"fmt"
"os"
"strings"
"jiggablend/internal/config"
"jiggablend/internal/database"
"github.com/spf13/cobra"
"golang.org/x/crypto/bcrypt"
)
var (
configDBPath string
configYes bool // Auto-confirm prompts
configForce bool // Force override existing
)
var configCmd = &cobra.Command{
Use: "config",
Short: "Configure the manager",
Long: `Configure the Jiggablend manager settings stored in the database.`,
}
// --- Enable/Disable commands ---
var enableCmd = &cobra.Command{
Use: "enable",
Short: "Enable a feature",
}
var disableCmd = &cobra.Command{
Use: "disable",
Short: "Disable a feature",
}
var enableLocalAuthCmd = &cobra.Command{
Use: "localauth",
Short: "Enable local authentication",
Run: func(cmd *cobra.Command, args []string) {
withConfig(func(cfg *config.Config, db *database.DB) {
if err := cfg.SetBool(config.KeyEnableLocalAuth, true); err != nil {
exitWithError("Failed to enable local auth: %v", err)
}
fmt.Println("Local authentication enabled")
})
},
}
var disableLocalAuthCmd = &cobra.Command{
Use: "localauth",
Short: "Disable local authentication",
Run: func(cmd *cobra.Command, args []string) {
withConfig(func(cfg *config.Config, db *database.DB) {
if err := cfg.SetBool(config.KeyEnableLocalAuth, false); err != nil {
exitWithError("Failed to disable local auth: %v", err)
}
fmt.Println("Local authentication disabled")
})
},
}
var enableRegistrationCmd = &cobra.Command{
Use: "registration",
Short: "Enable user registration",
Run: func(cmd *cobra.Command, args []string) {
withConfig(func(cfg *config.Config, db *database.DB) {
if err := cfg.SetBool(config.KeyRegistrationEnabled, true); err != nil {
exitWithError("Failed to enable registration: %v", err)
}
fmt.Println("User registration enabled")
})
},
}
var disableRegistrationCmd = &cobra.Command{
Use: "registration",
Short: "Disable user registration",
Run: func(cmd *cobra.Command, args []string) {
withConfig(func(cfg *config.Config, db *database.DB) {
if err := cfg.SetBool(config.KeyRegistrationEnabled, false); err != nil {
exitWithError("Failed to disable registration: %v", err)
}
fmt.Println("User registration disabled")
})
},
}
var enableProductionCmd = &cobra.Command{
Use: "production",
Short: "Enable production mode",
Run: func(cmd *cobra.Command, args []string) {
withConfig(func(cfg *config.Config, db *database.DB) {
if err := cfg.SetBool(config.KeyProductionMode, true); err != nil {
exitWithError("Failed to enable production mode: %v", err)
}
fmt.Println("Production mode enabled")
})
},
}
var disableProductionCmd = &cobra.Command{
Use: "production",
Short: "Disable production mode",
Run: func(cmd *cobra.Command, args []string) {
withConfig(func(cfg *config.Config, db *database.DB) {
if err := cfg.SetBool(config.KeyProductionMode, false); err != nil {
exitWithError("Failed to disable production mode: %v", err)
}
fmt.Println("Production mode disabled")
})
},
}
// --- Add commands ---
var addCmd = &cobra.Command{
Use: "add",
Short: "Add a resource",
}
var (
addUserName string
addUserAdmin bool
)
var addUserCmd = &cobra.Command{
Use: "user <email> <password>",
Short: "Add a local user",
Long: `Add a new local user account to the database.`,
Args: cobra.ExactArgs(2),
Run: func(cmd *cobra.Command, args []string) {
email := args[0]
password := args[1]
name := addUserName
if name == "" {
// Use email prefix as name
if atIndex := strings.Index(email, "@"); atIndex > 0 {
name = email[:atIndex]
} else {
name = email
}
}
if len(password) < 8 {
exitWithError("Password must be at least 8 characters")
}
withConfig(func(cfg *config.Config, db *database.DB) {
// Check if user exists
var exists bool
err := db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ?)", email).Scan(&exists)
})
if err != nil {
exitWithError("Failed to check user: %v", err)
}
isAdmin := addUserAdmin
if exists {
if !configForce {
exitWithError("User with email %s already exists (use -f to override)", email)
}
// Confirm override
if !configYes && !confirm(fmt.Sprintf("User %s already exists. Override?", email)) {
fmt.Println("Aborted")
return
}
// Update existing user
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
exitWithError("Failed to hash password: %v", err)
}
err = db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
"UPDATE users SET name = ?, password_hash = ?, is_admin = ? WHERE email = ?",
name, string(hashedPassword), isAdmin, email,
)
return err
})
if err != nil {
exitWithError("Failed to update user: %v", err)
}
fmt.Printf("Updated user: %s (admin: %v)\n", email, isAdmin)
return
}
// Hash password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
exitWithError("Failed to hash password: %v", err)
}
// Check if first user (make admin)
var userCount int
db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
})
if userCount == 0 {
isAdmin = true
}
// Confirm creation
if !configYes && !confirm(fmt.Sprintf("Create user %s (admin: %v)?", email, isAdmin)) {
fmt.Println("Aborted")
return
}
// Create user
err = db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?)",
email, name, email, string(hashedPassword), isAdmin,
)
return err
})
if err != nil {
exitWithError("Failed to create user: %v", err)
}
fmt.Printf("Created user: %s (admin: %v)\n", email, isAdmin)
})
},
}
var addAPIKeyScope string
var addAPIKeyCmd = &cobra.Command{
Use: "apikey [name]",
Short: "Add a runner API key",
Long: `Generate a new API key for runner authentication.`,
Args: cobra.MaximumNArgs(1),
Run: func(cmd *cobra.Command, args []string) {
name := "cli-generated"
if len(args) > 0 {
name = args[0]
}
withConfig(func(cfg *config.Config, db *database.DB) {
// Check if API key with same name exists
var exists bool
err := db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM runner_api_keys WHERE name = ?)", name).Scan(&exists)
})
if err != nil {
exitWithError("Failed to check API key: %v", err)
}
if exists {
if !configForce {
exitWithError("API key with name %s already exists (use -f to create another)", name)
}
if !configYes && !confirm(fmt.Sprintf("API key named '%s' already exists. Create another?", name)) {
fmt.Println("Aborted")
return
}
}
// Confirm creation
if !configYes && !confirm(fmt.Sprintf("Generate new API key '%s' (scope: %s)?", name, addAPIKeyScope)) {
fmt.Println("Aborted")
return
}
// Generate API key
key, keyPrefix, keyHash, err := generateAPIKey()
if err != nil {
exitWithError("Failed to generate API key: %v", err)
}
// Get first user ID for created_by (or use 0 if no users)
var createdBy int64
db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT id FROM users ORDER BY id ASC LIMIT 1").Scan(&createdBy)
})
// Store in database
err = db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`INSERT INTO runner_api_keys (key_prefix, key_hash, name, scope, is_active, created_by)
VALUES (?, ?, ?, ?, true, ?)`,
keyPrefix, keyHash, name, addAPIKeyScope, createdBy,
)
return err
})
if err != nil {
exitWithError("Failed to store API key: %v", err)
}
fmt.Printf("Generated API key: %s\n", key)
fmt.Printf("Name: %s, Scope: %s\n", name, addAPIKeyScope)
fmt.Println("\nSave this key - it cannot be retrieved later!")
})
},
}
// --- Set commands ---
var setCmd = &cobra.Command{
Use: "set",
Short: "Set a configuration value",
}
var setFixedAPIKeyCmd = &cobra.Command{
Use: "fixed-apikey [key]",
Short: "Set a fixed API key for testing",
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
withConfig(func(cfg *config.Config, db *database.DB) {
// Check if already set
existing := cfg.FixedAPIKey()
if existing != "" && !configForce {
exitWithError("Fixed API key already set (use -f to override)")
}
if existing != "" && !configYes && !confirm("Fixed API key already set. Override?") {
fmt.Println("Aborted")
return
}
if err := cfg.Set(config.KeyFixedAPIKey, args[0]); err != nil {
exitWithError("Failed to set fixed API key: %v", err)
}
fmt.Println("Fixed API key set")
})
},
}
var setAllowedOriginsCmd = &cobra.Command{
Use: "allowed-origins [origins]",
Short: "Set allowed CORS origins (comma-separated)",
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
withConfig(func(cfg *config.Config, db *database.DB) {
if err := cfg.Set(config.KeyAllowedOrigins, args[0]); err != nil {
exitWithError("Failed to set allowed origins: %v", err)
}
fmt.Printf("Allowed origins set to: %s\n", args[0])
})
},
}
var setGoogleOAuthRedirectURL string
var setGoogleOAuthCmd = &cobra.Command{
Use: "google-oauth <client-id> <client-secret>",
Short: "Set Google OAuth credentials",
Args: cobra.ExactArgs(2),
Run: func(cmd *cobra.Command, args []string) {
clientID := args[0]
clientSecret := args[1]
withConfig(func(cfg *config.Config, db *database.DB) {
// Check if already configured
existing := cfg.GoogleClientID()
if existing != "" && !configForce {
exitWithError("Google OAuth already configured (use -f to override)")
}
if existing != "" && !configYes && !confirm("Google OAuth already configured. Override?") {
fmt.Println("Aborted")
return
}
if err := cfg.Set(config.KeyGoogleClientID, clientID); err != nil {
exitWithError("Failed to set Google client ID: %v", err)
}
if err := cfg.Set(config.KeyGoogleClientSecret, clientSecret); err != nil {
exitWithError("Failed to set Google client secret: %v", err)
}
if setGoogleOAuthRedirectURL != "" {
if err := cfg.Set(config.KeyGoogleRedirectURL, setGoogleOAuthRedirectURL); err != nil {
exitWithError("Failed to set Google redirect URL: %v", err)
}
}
fmt.Println("Google OAuth configured")
})
},
}
var setDiscordOAuthRedirectURL string
var setDiscordOAuthCmd = &cobra.Command{
Use: "discord-oauth <client-id> <client-secret>",
Short: "Set Discord OAuth credentials",
Args: cobra.ExactArgs(2),
Run: func(cmd *cobra.Command, args []string) {
clientID := args[0]
clientSecret := args[1]
withConfig(func(cfg *config.Config, db *database.DB) {
// Check if already configured
existing := cfg.DiscordClientID()
if existing != "" && !configForce {
exitWithError("Discord OAuth already configured (use -f to override)")
}
if existing != "" && !configYes && !confirm("Discord OAuth already configured. Override?") {
fmt.Println("Aborted")
return
}
if err := cfg.Set(config.KeyDiscordClientID, clientID); err != nil {
exitWithError("Failed to set Discord client ID: %v", err)
}
if err := cfg.Set(config.KeyDiscordClientSecret, clientSecret); err != nil {
exitWithError("Failed to set Discord client secret: %v", err)
}
if setDiscordOAuthRedirectURL != "" {
if err := cfg.Set(config.KeyDiscordRedirectURL, setDiscordOAuthRedirectURL); err != nil {
exitWithError("Failed to set Discord redirect URL: %v", err)
}
}
fmt.Println("Discord OAuth configured")
})
},
}
// --- Show command ---
var showCmd = &cobra.Command{
Use: "show",
Short: "Show current configuration",
Run: func(cmd *cobra.Command, args []string) {
withConfig(func(cfg *config.Config, db *database.DB) {
all, err := cfg.GetAll()
if err != nil {
exitWithError("Failed to get config: %v", err)
}
if len(all) == 0 {
fmt.Println("No configuration stored")
return
}
fmt.Println("Current configuration:")
fmt.Println("----------------------")
for key, value := range all {
// Redact sensitive values
if strings.Contains(key, "secret") || strings.Contains(key, "api_key") || strings.Contains(key, "password") {
fmt.Printf(" %s: [REDACTED]\n", key)
} else {
fmt.Printf(" %s: %s\n", key, value)
}
}
})
},
}
// --- List commands ---
var listCmd = &cobra.Command{
Use: "list",
Short: "List resources",
}
var listUsersCmd = &cobra.Command{
Use: "users",
Short: "List all users",
Run: func(cmd *cobra.Command, args []string) {
withConfig(func(cfg *config.Config, db *database.DB) {
var rows *sql.Rows
err := db.With(func(conn *sql.DB) error {
var err error
rows, err = conn.Query("SELECT id, email, name, oauth_provider, is_admin, created_at FROM users ORDER BY id")
return err
})
if err != nil {
exitWithError("Failed to list users: %v", err)
}
defer rows.Close()
fmt.Printf("%-6s %-30s %-20s %-10s %-6s %s\n", "ID", "Email", "Name", "Provider", "Admin", "Created")
fmt.Println(strings.Repeat("-", 100))
for rows.Next() {
var id int64
var email, name, provider string
var isAdmin bool
var createdAt string
if err := rows.Scan(&id, &email, &name, &provider, &isAdmin, &createdAt); err != nil {
continue
}
adminStr := "no"
if isAdmin {
adminStr = "yes"
}
fmt.Printf("%-6d %-30s %-20s %-10s %-6s %s\n", id, email, name, provider, adminStr, createdAt[:19])
}
})
},
}
var listAPIKeysCmd = &cobra.Command{
Use: "apikeys",
Short: "List all API keys",
Run: func(cmd *cobra.Command, args []string) {
withConfig(func(cfg *config.Config, db *database.DB) {
var rows *sql.Rows
err := db.With(func(conn *sql.DB) error {
var err error
rows, err = conn.Query("SELECT id, key_prefix, name, scope, is_active, created_at FROM runner_api_keys ORDER BY id")
return err
})
if err != nil {
exitWithError("Failed to list API keys: %v", err)
}
defer rows.Close()
fmt.Printf("%-6s %-12s %-20s %-10s %-8s %s\n", "ID", "Prefix", "Name", "Scope", "Active", "Created")
fmt.Println(strings.Repeat("-", 80))
for rows.Next() {
var id int64
var prefix, name, scope string
var isActive bool
var createdAt string
if err := rows.Scan(&id, &prefix, &name, &scope, &isActive, &createdAt); err != nil {
continue
}
activeStr := "no"
if isActive {
activeStr = "yes"
}
fmt.Printf("%-6d %-12s %-20s %-10s %-8s %s\n", id, prefix, name, scope, activeStr, createdAt[:19])
}
})
},
}
func init() {
managerCmd.AddCommand(configCmd)
// Global config flags
configCmd.PersistentFlags().StringVar(&configDBPath, "db", "jiggablend.db", "Database path")
configCmd.PersistentFlags().BoolVarP(&configYes, "yes", "y", false, "Auto-confirm prompts")
configCmd.PersistentFlags().BoolVarP(&configForce, "force", "f", false, "Force override existing")
// Enable/Disable
configCmd.AddCommand(enableCmd)
configCmd.AddCommand(disableCmd)
enableCmd.AddCommand(enableLocalAuthCmd)
enableCmd.AddCommand(enableRegistrationCmd)
enableCmd.AddCommand(enableProductionCmd)
disableCmd.AddCommand(disableLocalAuthCmd)
disableCmd.AddCommand(disableRegistrationCmd)
disableCmd.AddCommand(disableProductionCmd)
// Add
configCmd.AddCommand(addCmd)
addCmd.AddCommand(addUserCmd)
addUserCmd.Flags().StringVarP(&addUserName, "name", "n", "", "User display name")
addUserCmd.Flags().BoolVarP(&addUserAdmin, "admin", "a", false, "Make user an admin")
addCmd.AddCommand(addAPIKeyCmd)
addAPIKeyCmd.Flags().StringVarP(&addAPIKeyScope, "scope", "s", "manager", "API key scope (manager or user)")
// Set
configCmd.AddCommand(setCmd)
setCmd.AddCommand(setFixedAPIKeyCmd)
setCmd.AddCommand(setAllowedOriginsCmd)
setCmd.AddCommand(setGoogleOAuthCmd)
setCmd.AddCommand(setDiscordOAuthCmd)
setGoogleOAuthCmd.Flags().StringVarP(&setGoogleOAuthRedirectURL, "redirect-url", "r", "", "Google OAuth redirect URL")
setDiscordOAuthCmd.Flags().StringVarP(&setDiscordOAuthRedirectURL, "redirect-url", "r", "", "Discord OAuth redirect URL")
// Show
configCmd.AddCommand(showCmd)
// List
configCmd.AddCommand(listCmd)
listCmd.AddCommand(listUsersCmd)
listCmd.AddCommand(listAPIKeysCmd)
}
// withConfig opens the database and runs the callback with config access
func withConfig(fn func(cfg *config.Config, db *database.DB)) {
db, err := database.NewDB(configDBPath)
if err != nil {
exitWithError("Failed to open database: %v", err)
}
defer db.Close()
cfg := config.NewConfig(db)
fn(cfg, db)
}
// generateAPIKey generates a new API key
func generateAPIKey() (key, prefix, hash string, err error) {
randomBytes := make([]byte, 16)
if _, err := rand.Read(randomBytes); err != nil {
return "", "", "", err
}
randomStr := hex.EncodeToString(randomBytes)
prefixDigit := make([]byte, 1)
if _, err := rand.Read(prefixDigit); err != nil {
return "", "", "", err
}
prefix = fmt.Sprintf("jk_r%d", prefixDigit[0]%10)
key = fmt.Sprintf("%s_%s", prefix, randomStr)
keyHash := sha256.Sum256([]byte(key))
hash = hex.EncodeToString(keyHash[:])
return key, prefix, hash, nil
}
// confirm prompts the user for confirmation
func confirm(prompt string) bool {
fmt.Printf("%s [y/N]: ", prompt)
reader := bufio.NewReader(os.Stdin)
response, err := reader.ReadString('\n')
if err != nil {
return false
}
response = strings.TrimSpace(strings.ToLower(response))
return response == "y" || response == "yes"
}

View File

@@ -0,0 +1,34 @@
package cmd
import (
"fmt"
"os"
"github.com/spf13/cobra"
)
var rootCmd = &cobra.Command{
Use: "jiggablend",
Short: "Jiggablend - Distributed Blender Render Farm",
Long: `Jiggablend is a distributed render farm for Blender.
Run 'jiggablend manager' to start the manager server.
Run 'jiggablend runner' to start a render runner.
Run 'jiggablend manager config' to configure the manager.`,
}
// Execute runs the root command
func Execute() error {
return rootCmd.Execute()
}
func init() {
// Global flags can be added here if needed
rootCmd.CompletionOptions.DisableDefaultCmd = true
}
// exitWithError prints an error and exits
func exitWithError(msg string, args ...interface{}) {
fmt.Fprintf(os.Stderr, "Error: "+msg+"\n", args...)
os.Exit(1)
}

View File

@@ -0,0 +1,208 @@
package cmd
import (
"crypto/rand"
"encoding/hex"
"fmt"
"os"
"os/signal"
"strings"
"syscall"
"time"
"jiggablend/internal/logger"
"jiggablend/internal/runner"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
var runnerViper = viper.New()
var runnerCmd = &cobra.Command{
Use: "runner",
Short: "Start the Jiggablend render runner",
Long: `Start the Jiggablend render runner that connects to a manager and processes render tasks.`,
Run: runRunner,
}
func init() {
rootCmd.AddCommand(runnerCmd)
runnerCmd.Flags().StringP("manager", "m", "http://localhost:8080", "Manager URL")
runnerCmd.Flags().StringP("name", "n", "", "Runner name")
runnerCmd.Flags().String("hostname", "", "Runner hostname")
runnerCmd.Flags().StringP("api-key", "k", "", "API key for authentication")
runnerCmd.Flags().StringP("log-file", "l", "", "Log file path (truncated on start, if not set logs only to stdout)")
runnerCmd.Flags().String("log-level", "info", "Log level (debug, info, warn, error)")
runnerCmd.Flags().BoolP("verbose", "v", false, "Enable verbose logging (same as --log-level=debug)")
runnerCmd.Flags().Duration("poll-interval", 5*time.Second, "Job polling interval")
// Bind flags to viper with JIGGABLEND_ prefix
runnerViper.SetEnvPrefix("JIGGABLEND")
runnerViper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
runnerViper.AutomaticEnv()
runnerViper.BindPFlag("manager", runnerCmd.Flags().Lookup("manager"))
runnerViper.BindPFlag("name", runnerCmd.Flags().Lookup("name"))
runnerViper.BindPFlag("hostname", runnerCmd.Flags().Lookup("hostname"))
runnerViper.BindPFlag("api_key", runnerCmd.Flags().Lookup("api-key"))
runnerViper.BindPFlag("log_file", runnerCmd.Flags().Lookup("log-file"))
runnerViper.BindPFlag("log_level", runnerCmd.Flags().Lookup("log-level"))
runnerViper.BindPFlag("verbose", runnerCmd.Flags().Lookup("verbose"))
runnerViper.BindPFlag("poll_interval", runnerCmd.Flags().Lookup("poll-interval"))
}
func runRunner(cmd *cobra.Command, args []string) {
// Get config values (flags take precedence over env vars)
managerURL := runnerViper.GetString("manager")
name := runnerViper.GetString("name")
hostname := runnerViper.GetString("hostname")
apiKey := runnerViper.GetString("api_key")
logFile := runnerViper.GetString("log_file")
logLevel := runnerViper.GetString("log_level")
verbose := runnerViper.GetBool("verbose")
pollInterval := runnerViper.GetDuration("poll_interval")
var r *runner.Runner
defer func() {
if rec := recover(); rec != nil {
logger.Errorf("Runner panicked: %v", rec)
if r != nil {
r.Cleanup()
}
os.Exit(1)
}
}()
if hostname == "" {
hostname, _ = os.Hostname()
}
// Generate unique runner ID suffix
runnerIDStr := generateShortID()
// Generate runner name with ID if not provided
if name == "" {
name = fmt.Sprintf("runner-%s-%s", hostname, runnerIDStr)
} else {
name = fmt.Sprintf("%s-%s", name, runnerIDStr)
}
// Initialize logger
if logFile != "" {
if err := logger.InitWithFile(logFile); err != nil {
logger.Fatalf("Failed to initialize logger: %v", err)
}
defer func() {
if l := logger.GetDefault(); l != nil {
l.Close()
}
}()
} else {
logger.InitStdout()
}
// Set log level
if verbose {
logger.SetLevel(logger.LevelDebug)
} else {
logger.SetLevel(logger.ParseLevel(logLevel))
}
logger.Info("Runner starting up...")
logger.Debugf("Generated runner ID suffix: %s", runnerIDStr)
if logFile != "" {
logger.Infof("Logging to file: %s", logFile)
}
// Create runner
r = runner.New(managerURL, name, hostname)
// Check for required tools early to fail fast
if err := r.CheckRequiredTools(); err != nil {
logger.Fatalf("Required tool check failed: %v", err)
}
// Clean up orphaned workspace directories
r.Cleanup()
// Probe capabilities and log them
logger.Debug("Probing runner capabilities...")
capabilities := r.ProbeCapabilities()
capList := []string{}
for cap, value := range capabilities {
if enabled, ok := value.(bool); ok && enabled {
capList = append(capList, cap)
}
}
if len(capList) > 0 {
logger.Infof("Detected capabilities: %s", strings.Join(capList, ", "))
} else {
logger.Warn("No capabilities detected")
}
// Register with API key
if apiKey == "" {
logger.Fatal("API key required (use --api-key or set JIGGABLEND_API_KEY env var)")
}
// Retry registration with exponential backoff
backoff := 1 * time.Second
maxBackoff := 30 * time.Second
maxRetries := 10
retryCount := 0
var runnerID int64
for {
var err error
runnerID, err = r.Register(apiKey)
if err == nil {
logger.Infof("Registered runner with ID: %d", runnerID)
break
}
errMsg := err.Error()
if strings.Contains(errMsg, "token error:") {
logger.Fatalf("Registration failed (token error): %v", err)
}
retryCount++
if retryCount >= maxRetries {
logger.Fatalf("Failed to register runner after %d attempts: %v", maxRetries, err)
}
logger.Warnf("Registration failed (attempt %d/%d): %v, retrying in %v", retryCount, maxRetries, err, backoff)
time.Sleep(backoff)
backoff *= 2
if backoff > maxBackoff {
backoff = maxBackoff
}
}
// Signal handlers
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
go func() {
sig := <-sigChan
logger.Infof("Received signal: %v, killing all processes and cleaning up...", sig)
r.KillAllProcesses()
r.Cleanup()
os.Exit(0)
}()
// Start polling for jobs
logger.Infof("Runner started, polling for jobs (interval: %v)...", pollInterval)
r.Start(pollInterval)
}
func generateShortID() string {
bytes := make([]byte, 4)
if _, err := rand.Read(bytes); err != nil {
return fmt.Sprintf("%x", os.Getpid()^int(time.Now().Unix()))
}
return hex.EncodeToString(bytes)
}

View File

@@ -0,0 +1,25 @@
package cmd
import (
"fmt"
"jiggablend/version"
"github.com/spf13/cobra"
)
var versionCmd = &cobra.Command{
Use: "version",
Short: "Print the version information",
Long: `Print the version and build date of jiggablend.`,
Run: func(cmd *cobra.Command, args []string) {
fmt.Printf("jiggablend version %s\n", version.Version)
if version.Date != "" {
fmt.Printf("Build date: %s\n", version.Date)
}
},
}
func init() {
rootCmd.AddCommand(versionCmd)
}

14
cmd/jiggablend/main.go Normal file
View File

@@ -0,0 +1,14 @@
package main
import (
"os"
"jiggablend/cmd/jiggablend/cmd"
_ "jiggablend/version"
)
func main() {
if err := cmd.Execute(); err != nil {
os.Exit(1)
}
}

View File

@@ -1,65 +0,0 @@
package main
import (
"flag"
"fmt"
"log"
"net/http"
"os"
"jiggablend/internal/api"
"jiggablend/internal/auth"
"jiggablend/internal/database"
"jiggablend/internal/storage"
)
func main() {
var (
port = flag.String("port", getEnv("PORT", "8080"), "Server port")
dbPath = flag.String("db", getEnv("DB_PATH", "jiggablend.db"), "Database path")
storagePath = flag.String("storage", getEnv("STORAGE_PATH", "./jiggablend-storage"), "Storage path")
)
flag.Parse()
// Initialize database
db, err := database.NewDB(*dbPath)
if err != nil {
log.Fatalf("Failed to initialize database: %v", err)
}
defer db.Close()
// Initialize auth
authHandler, err := auth.NewAuth(db.DB)
if err != nil {
log.Fatalf("Failed to initialize auth: %v", err)
}
// Initialize storage
storageHandler, err := storage.NewStorage(*storagePath)
if err != nil {
log.Fatalf("Failed to initialize storage: %v", err)
}
// Create API server
server, err := api.NewServer(db, authHandler, storageHandler)
if err != nil {
log.Fatalf("Failed to create server: %v", err)
}
// Start server
addr := fmt.Sprintf(":%s", *port)
log.Printf("Starting manager server on %s", addr)
log.Printf("Database: %s", *dbPath)
log.Printf("Storage: %s", *storagePath)
if err := http.ListenAndServe(addr, server); err != nil {
log.Fatalf("Server failed: %v", err)
}
}
func getEnv(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}

View File

@@ -1,221 +0,0 @@
package main
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"flag"
"fmt"
"log"
"os"
"os/signal"
"strings"
"syscall"
"time"
"jiggablend/internal/runner"
)
type SecretsFile struct {
RunnerID int64 `json:"runner_id"`
RunnerSecret string `json:"runner_secret"`
ManagerSecret string `json:"manager_secret"`
}
func main() {
var (
managerURL = flag.String("manager", getEnv("MANAGER_URL", "http://localhost:8080"), "Manager URL")
name = flag.String("name", getEnv("RUNNER_NAME", ""), "Runner name")
hostname = flag.String("hostname", getEnv("RUNNER_HOSTNAME", ""), "Runner hostname")
ipAddress = flag.String("ip", getEnv("RUNNER_IP", ""), "Runner IP address")
token = flag.String("token", getEnv("REGISTRATION_TOKEN", ""), "Registration token")
secretsFile = flag.String("secrets-file", getEnv("SECRETS_FILE", ""), "Path to secrets file for persistent storage (default: ./runner-secrets.json, or ./runner-secrets-{id}.json if multiple runners)")
runnerIDSuffix = flag.String("runner-id", getEnv("RUNNER_ID", ""), "Unique runner ID suffix (auto-generated if not provided)")
)
flag.Parse()
if *hostname == "" {
*hostname, _ = os.Hostname()
}
if *ipAddress == "" {
*ipAddress = "127.0.0.1"
}
// Generate or use provided runner ID suffix
runnerIDStr := *runnerIDSuffix
if runnerIDStr == "" {
runnerIDStr = generateShortID()
}
// Generate runner name with ID if not provided
if *name == "" {
*name = fmt.Sprintf("runner-%s-%s", *hostname, runnerIDStr)
} else {
// Append ID to provided name to ensure uniqueness
*name = fmt.Sprintf("%s-%s", *name, runnerIDStr)
}
// Set default secrets file if not provided - always use current directory
if *secretsFile == "" {
if *runnerIDSuffix != "" || getEnv("RUNNER_ID", "") != "" {
// Multiple runners - use local file with ID
*secretsFile = fmt.Sprintf("./runner-secrets-%s.json", runnerIDStr)
} else {
// Single runner - use local file
*secretsFile = "./runner-secrets.json"
}
}
client := runner.NewClient(*managerURL, *name, *hostname, *ipAddress)
// Probe capabilities once at startup (before any registration attempts)
log.Printf("Probing runner capabilities...")
client.ProbeCapabilities()
capabilities := client.GetCapabilities()
capList := []string{}
for cap, value := range capabilities {
// Only show boolean true capabilities and numeric GPU counts
if enabled, ok := value.(bool); ok && enabled {
capList = append(capList, cap)
} else if count, ok := value.(int); ok && count > 0 {
capList = append(capList, fmt.Sprintf("%s=%d", cap, count))
} else if count, ok := value.(float64); ok && count > 0 {
capList = append(capList, fmt.Sprintf("%s=%.0f", cap, count))
}
}
if len(capList) > 0 {
log.Printf("Detected capabilities: %s", strings.Join(capList, ", "))
} else {
log.Printf("Warning: No capabilities detected")
}
// Try to load secrets from file
var runnerID int64
var runnerSecret, managerSecret string
if *secretsFile != "" {
if secrets, err := loadSecrets(*secretsFile); err == nil {
runnerID = secrets.RunnerID
runnerSecret = secrets.RunnerSecret
managerSecret = secrets.ManagerSecret
client.SetSecrets(runnerID, runnerSecret, managerSecret)
log.Printf("Loaded secrets from %s", *secretsFile)
}
}
// If no secrets loaded, register with token (with retry logic)
if runnerID == 0 {
if *token == "" {
log.Fatalf("Registration token required (use --token or set REGISTRATION_TOKEN env var)")
}
// Retry registration with exponential backoff
backoff := 1 * time.Second
maxBackoff := 30 * time.Second
maxRetries := 10
retryCount := 0
for {
var err error
runnerID, runnerSecret, managerSecret, err = client.Register(*token)
if err == nil {
log.Printf("Registered runner with ID: %d", runnerID)
// Always save secrets to file (secretsFile is now always set to a default if not provided)
secrets := SecretsFile{
RunnerID: runnerID,
RunnerSecret: runnerSecret,
ManagerSecret: managerSecret,
}
if err := saveSecrets(*secretsFile, secrets); err != nil {
log.Printf("Warning: Failed to save secrets to %s: %v", *secretsFile, err)
} else {
log.Printf("Saved secrets to %s", *secretsFile)
}
break
}
// Check if it's a token error (invalid/expired/used token) - shutdown immediately
errMsg := err.Error()
if strings.Contains(errMsg, "token error:") {
log.Fatalf("Registration failed (token error): %v", err)
}
// Only retry on connection errors or other retryable errors
retryCount++
if retryCount >= maxRetries {
log.Fatalf("Failed to register runner after %d attempts: %v", maxRetries, err)
}
log.Printf("Registration failed (attempt %d/%d): %v, retrying in %v", retryCount, maxRetries, err, backoff)
time.Sleep(backoff)
backoff *= 2
if backoff > maxBackoff {
backoff = maxBackoff
}
}
}
// Start WebSocket connection with reconnection
go client.ConnectWebSocketWithReconnect()
// Start heartbeat loop (for WebSocket ping/pong and HTTP fallback)
go client.HeartbeatLoop()
// ProcessTasks is now handled via WebSocket, but kept for HTTP fallback
// WebSocket will handle task assignment automatically
log.Printf("Runner started, connecting to manager via WebSocket...")
// Set up signal handlers to kill processes on shutdown
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
go func() {
sig := <-sigChan
log.Printf("Received signal: %v, killing all processes and shutting down...", sig)
client.KillAllProcesses()
os.Exit(0)
}()
// Block forever
select {}
}
func loadSecrets(path string) (*SecretsFile, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var secrets SecretsFile
if err := json.Unmarshal(data, &secrets); err != nil {
return nil, err
}
return &secrets, nil
}
func saveSecrets(path string, secrets SecretsFile) error {
data, err := json.MarshalIndent(secrets, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, data, 0600)
}
func getEnv(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
// generateShortID generates a short random ID (8 hex characters)
func generateShortID() string {
bytes := make([]byte, 4)
if _, err := rand.Read(bytes); err != nil {
// Fallback to timestamp-based ID if crypto/rand fails
return fmt.Sprintf("%x", os.Getpid()^int(time.Now().Unix()))
}
return hex.EncodeToString(bytes)
}

BIN
examples/frame_0800.exr Normal file

Binary file not shown.

BIN
examples/frame_0800.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 24 MiB

44
go.mod
View File

@@ -5,35 +5,33 @@ go 1.25.4
require (
github.com/go-chi/chi/v5 v5.2.3
github.com/go-chi/cors v1.2.2
github.com/golang-migrate/migrate/v4 v4.19.0
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3
github.com/marcboeker/go-duckdb/v2 v2.4.3
github.com/mattn/go-sqlite3 v1.14.32
github.com/spf13/cobra v1.10.1
github.com/spf13/viper v1.21.0
golang.org/x/crypto v0.45.0
golang.org/x/oauth2 v0.33.0
)
require (
cloud.google.com/go/compute/metadata v0.3.0 // indirect
github.com/apache/arrow-go/v18 v18.4.1 // indirect
github.com/duckdb/duckdb-go-bindings v0.1.21 // indirect
github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.21 // indirect
github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.21 // indirect
github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.21 // indirect
github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.21 // indirect
github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.21 // indirect
cloud.google.com/go/compute/metadata v0.5.0 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/goccy/go-json v0.10.5 // indirect
github.com/google/flatbuffers v25.2.10+incompatible // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/marcboeker/go-duckdb/arrowmapping v0.0.21 // indirect
github.com/marcboeker/go-duckdb/mapping v0.0.21 // indirect
github.com/pierrec/lz4/v4 v4.1.22 // indirect
github.com/zeebo/xxh3 v1.0.2 // indirect
golang.org/x/crypto v0.45.0 // indirect
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 // indirect
golang.org/x/mod v0.27.0 // indirect
golang.org/x/sync v0.16.0 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/sagikazarmark/locafero v0.11.0 // indirect
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
github.com/spf13/afero v1.15.0 // indirect
github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/tools v0.36.0 // indirect
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect
golang.org/x/text v0.31.0 // indirect
)

119
go.sum
View File

@@ -1,88 +1,79 @@
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/apache/arrow-go/v18 v18.4.1 h1:q/jVkBWCJOB9reDgaIZIdruLQUb1kbkvOnOFezVH1C4=
github.com/apache/arrow-go/v18 v18.4.1/go.mod h1:tLyFubsAl17bvFdUAy24bsSvA/6ww95Iqi67fTpGu3E=
github.com/apache/thrift v0.22.0 h1:r7mTJdj51TMDe6RtcmNdQxgn9XcyfGDOzegMDRg47uc=
github.com/apache/thrift v0.22.0/go.mod h1:1e7J/O1Ae6ZQMTYdy9xa3w9k+XHWPfRvdPyJeynQ+/g=
cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY=
cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/duckdb/duckdb-go-bindings v0.1.21 h1:bOb/MXNT4PN5JBZ7wpNg6hrj9+cuDjWDa4ee9UdbVyI=
github.com/duckdb/duckdb-go-bindings v0.1.21/go.mod h1:pBnfviMzANT/9hi4bg+zW4ykRZZPCXlVuvBWEcZofkc=
github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.21 h1:Sjjhf2F/zCjPF53c2VXOSKk0PzieMriSoyr5wfvr9d8=
github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.21/go.mod h1:Ezo7IbAfB8NP7CqPIN8XEHKUg5xdRRQhcPPlCXImXYA=
github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.21 h1:IUk0FFUB6dpWLhlN9hY1mmdPX7Hkn3QpyrAmn8pmS8g=
github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.21/go.mod h1:eS7m/mLnPQgVF4za1+xTyorKRBuK0/BA44Oy6DgrGXI=
github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.21 h1:Qpc7ZE3n6Nwz30KTvaAwI6nGkXjXmMxBTdFpC8zDEYI=
github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.21/go.mod h1:1GOuk1PixiESxLaCGFhag+oFi7aP+9W8byymRAvunBk=
github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.21 h1:eX2DhobAZOgjXkh8lPnKAyrxj8gXd2nm+K71f6KV/mo=
github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.21/go.mod h1:o7crKMpT2eOIi5/FY6HPqaXcvieeLSqdXXaXbruGX7w=
github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.21 h1:hhziFnGV7mpA+v5J5G2JnYQ+UWCCP3NQ+OTvxFX10D8=
github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.21/go.mod h1:IlOhJdVKUJCAPj3QsDszUo8DVdvp1nBFp4TUJVdw99s=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE=
github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
github.com/go-chi/cors v1.2.2 h1:Jmey33TE+b+rB7fT8MUy1u0I4L+NARQlK6LhzKPSyQE=
github.com/go-chi/cors v1.2.2/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58=
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs=
github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/flatbuffers v25.2.10+incompatible h1:F3vclr7C3HpB1k9mxCGRMXq6FdUalZ6H/pNX4FP1v0Q=
github.com/google/flatbuffers v25.2.10+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
github.com/golang-migrate/migrate/v4 v4.19.0 h1:RcjOnCGz3Or6HQYEJ/EEVLfWnmw9KnoigPSjzhCuaSE=
github.com/golang-migrate/migrate/v4 v4.19.0/go.mod h1:9dyEcu+hO+G9hPSw8AIg50yg622pXJsoHItQnDGZkI0=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/klauspost/asmfmt v1.3.2 h1:4Ri7ox3EwapiOjCki+hw14RyKk201CN4rzyCJRFLpK4=
github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j0HLHbNSE=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/marcboeker/go-duckdb/arrowmapping v0.0.21 h1:geHnVjlsAJGczSWEqYigy/7ARuD+eBtjd0kLN80SPJQ=
github.com/marcboeker/go-duckdb/arrowmapping v0.0.21/go.mod h1:flFTc9MSqQCh2Xm62RYvG3Kyj29h7OtsTb6zUx1CdK8=
github.com/marcboeker/go-duckdb/mapping v0.0.21 h1:6woNXZn8EfYdc9Vbv0qR6acnt0TM1s1eFqnrJZVrqEs=
github.com/marcboeker/go-duckdb/mapping v0.0.21/go.mod h1:q3smhpLyv2yfgkQd7gGHMd+H/Z905y+WYIUjrl29vT4=
github.com/marcboeker/go-duckdb/v2 v2.4.3 h1:bHUkphPsAp2Bh/VFEdiprGpUekxBNZiWWtK+Bv/ljRk=
github.com/marcboeker/go-duckdb/v2 v2.4.3/go.mod h1:taim9Hktg2igHdNBmg5vgTfHAlV26z3gBI0QXQOcuyI=
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs=
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcsrzbxHt8iiaC+zU4b1ylILSosueou12R++wfY=
github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8DFdX7uMikMLXX4oubIzJF4kv/wI=
github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3/go.mod h1:RagcQ7I8IeTMnF8JTXieKnO4Z6JCsikNEzj0DwauVzE=
github.com/pierrec/lz4/v4 v4.1.22 h1:cKFw6uJDK+/gfw5BcDL0JL5aBsAFdsIT18eRtLj7VIU=
github.com/pierrec/lz4/v4 v4.1.22/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.11.0 h1:ib4sjIrwZKxE5u/Japgo/7SJV3PvgjGiRNAvTVGqQl8=
github.com/stretchr/testify v1.11.0/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ=
github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0=
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s=
github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0=
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM=
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8=
golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ=
golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc=
golang.org/x/oauth2 v0.33.0 h1:4Q+qn+E5z8gPRJfmRy7C2gGG3T4jIprK6aSYgTXGRpo=
golang.org/x/oauth2 v0.33.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg=
golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s=
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da h1:noIWHXmPHxILtqtCOPIhSt0ABwskkZKjD3bXGnZGpNY=
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

File diff suppressed because it is too large Load Diff

View File

@@ -1,158 +0,0 @@
package api
import (
"database/sql"
"encoding/json"
"fmt"
"log"
"net/http"
"jiggablend/pkg/types"
)
// handleSubmitMetadata handles metadata submission from runner
func (s *Server) handleSubmitMetadata(w http.ResponseWriter, r *http.Request) {
jobID, err := parseID(r, "jobId")
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
// Get runner ID from context (set by runnerAuthMiddleware)
runnerID, ok := r.Context().Value(runnerIDContextKey).(int64)
if !ok {
s.respondError(w, http.StatusUnauthorized, "runner_id not found in context")
return
}
var metadata types.BlendMetadata
if err := json.NewDecoder(r.Body).Decode(&metadata); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid metadata JSON")
return
}
// Verify job exists
var jobUserID int64
err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID)
if err == sql.ErrNoRows {
s.respondError(w, http.StatusNotFound, "Job not found")
return
}
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify job: %v", err))
return
}
// Find the metadata extraction task for this job
// First try to find task assigned to this runner, then fall back to any metadata task for this job
var taskID int64
err = s.db.QueryRow(
`SELECT id FROM tasks WHERE job_id = ? AND task_type = ? AND runner_id = ?`,
jobID, types.TaskTypeMetadata, runnerID,
).Scan(&taskID)
if err == sql.ErrNoRows {
// Fall back to any metadata task for this job (in case assignment changed)
err = s.db.QueryRow(
`SELECT id FROM tasks WHERE job_id = ? AND task_type = ? ORDER BY created_at DESC LIMIT 1`,
jobID, types.TaskTypeMetadata,
).Scan(&taskID)
if err == sql.ErrNoRows {
s.respondError(w, http.StatusNotFound, "Metadata extraction task not found")
return
}
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to find task: %v", err))
return
}
// Update the task to be assigned to this runner if it wasn't already
s.db.Exec(
`UPDATE tasks SET runner_id = ? WHERE id = ? AND runner_id IS NULL`,
runnerID, taskID,
)
} else if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to find task: %v", err))
return
}
// Convert metadata to JSON
metadataJSON, err := json.Marshal(metadata)
if err != nil {
s.respondError(w, http.StatusInternalServerError, "Failed to marshal metadata")
return
}
// Update job with metadata
_, err = s.db.Exec(
`UPDATE jobs SET blend_metadata = ? WHERE id = ?`,
string(metadataJSON), jobID,
)
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to update job metadata: %v", err))
return
}
// Mark task as completed
_, err = s.db.Exec(
`UPDATE tasks SET status = ?, completed_at = CURRENT_TIMESTAMP WHERE id = ?`,
types.TaskStatusCompleted, taskID,
)
if err != nil {
log.Printf("Failed to mark metadata task as completed: %v", err)
} else {
// Update job status and progress after metadata task completes
s.updateJobStatusFromTasks(jobID)
}
log.Printf("Metadata extracted for job %d: frame_start=%d, frame_end=%d", jobID, metadata.FrameStart, metadata.FrameEnd)
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Metadata submitted successfully"})
}
// handleGetJobMetadata retrieves metadata for a job
func (s *Server) handleGetJobMetadata(w http.ResponseWriter, r *http.Request) {
userID, err := getUserID(r)
if err != nil {
s.respondError(w, http.StatusUnauthorized, err.Error())
return
}
jobID, err := parseID(r, "id")
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
// Verify job belongs to user
var jobUserID int64
var blendMetadataJSON sql.NullString
err = s.db.QueryRow(
`SELECT user_id, blend_metadata FROM jobs WHERE id = ?`,
jobID,
).Scan(&jobUserID, &blendMetadataJSON)
if err == sql.ErrNoRows {
s.respondError(w, http.StatusNotFound, "Job not found")
return
}
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query job: %v", err))
return
}
if jobUserID != userID {
s.respondError(w, http.StatusForbidden, "Access denied")
return
}
if !blendMetadataJSON.Valid || blendMetadataJSON.String == "" {
s.respondJSON(w, http.StatusOK, nil)
return
}
var metadata types.BlendMetadata
if err := json.Unmarshal([]byte(blendMetadataJSON.String), &metadata); err != nil {
s.respondError(w, http.StatusInternalServerError, "Failed to parse metadata")
return
}
s.respondJSON(w, http.StatusOK, metadata)
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,658 +0,0 @@
package api
import (
"database/sql"
"encoding/json"
"fmt"
"log"
"net/http"
"strconv"
"sync"
"time"
authpkg "jiggablend/internal/auth"
"jiggablend/internal/database"
"jiggablend/internal/storage"
"jiggablend/pkg/types"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/cors"
"github.com/gorilla/websocket"
)
// Server represents the API server
type Server struct {
db *database.DB
auth *authpkg.Auth
secrets *authpkg.Secrets
storage *storage.Storage
router *chi.Mux
// WebSocket connections
wsUpgrader websocket.Upgrader
runnerConns map[int64]*websocket.Conn
runnerConnsMu sync.RWMutex
frontendConns map[string]*websocket.Conn // key: "jobId:taskId"
frontendConnsMu sync.RWMutex
// Mutexes for each frontend connection to serialize writes
frontendConnsWriteMu map[string]*sync.Mutex // key: "jobId:taskId"
frontendConnsWriteMuMu sync.RWMutex
// Throttling for progress updates (per job)
progressUpdateTimes map[int64]time.Time // key: jobID
progressUpdateTimesMu sync.RWMutex
}
// NewServer creates a new API server
func NewServer(db *database.DB, auth *authpkg.Auth, storage *storage.Storage) (*Server, error) {
secrets, err := authpkg.NewSecrets(db.DB)
if err != nil {
return nil, fmt.Errorf("failed to initialize secrets: %w", err)
}
s := &Server{
db: db,
auth: auth,
secrets: secrets,
storage: storage,
router: chi.NewRouter(),
wsUpgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true // Allow all origins for now
},
ReadBufferSize: 1024,
WriteBufferSize: 1024,
},
runnerConns: make(map[int64]*websocket.Conn),
frontendConns: make(map[string]*websocket.Conn),
frontendConnsWriteMu: make(map[string]*sync.Mutex),
progressUpdateTimes: make(map[int64]time.Time),
}
s.setupMiddleware()
s.setupRoutes()
s.StartBackgroundTasks()
return s, nil
}
// setupMiddleware configures middleware
func (s *Server) setupMiddleware() {
s.router.Use(middleware.Logger)
s.router.Use(middleware.Recoverer)
// Note: Timeout middleware is NOT applied globally to avoid conflicts with WebSocket connections
// WebSocket connections are long-lived and should not have HTTP timeouts
s.router.Use(cors.Handler(cors.Options{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "Range"},
ExposedHeaders: []string{"Link", "Content-Range", "Accept-Ranges", "Content-Length"},
AllowCredentials: true,
MaxAge: 300,
}))
}
// setupRoutes configures routes
func (s *Server) setupRoutes() {
// Public routes
s.router.Route("/api/auth", func(r chi.Router) {
r.Get("/providers", s.handleGetAuthProviders)
r.Get("/google/login", s.handleGoogleLogin)
r.Get("/google/callback", s.handleGoogleCallback)
r.Get("/discord/login", s.handleDiscordLogin)
r.Get("/discord/callback", s.handleDiscordCallback)
r.Get("/local/available", s.handleLocalLoginAvailable)
r.Post("/local/register", s.handleLocalRegister)
r.Post("/local/login", s.handleLocalLogin)
r.Post("/logout", s.handleLogout)
r.Get("/me", s.handleGetMe)
r.Post("/change-password", s.handleChangePassword)
})
// Protected routes
s.router.Route("/api/jobs", func(r chi.Router) {
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(s.auth.Middleware(next.ServeHTTP))
})
r.Post("/", s.handleCreateJob)
r.Get("/", s.handleListJobs)
r.Get("/{id}", s.handleGetJob)
r.Delete("/{id}", s.handleCancelJob)
r.Post("/{id}/delete", s.handleDeleteJob)
r.Post("/{id}/upload", s.handleUploadJobFile)
r.Get("/{id}/files", s.handleListJobFiles)
r.Get("/{id}/files/{fileId}/download", s.handleDownloadJobFile)
r.Get("/{id}/video", s.handleStreamVideo)
r.Get("/{id}/metadata", s.handleGetJobMetadata)
r.Get("/{id}/tasks", s.handleListJobTasks)
r.Get("/{id}/tasks/{taskId}/logs", s.handleGetTaskLogs)
// WebSocket route - no timeout middleware (long-lived connection)
r.With(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Remove timeout middleware for WebSocket
next.ServeHTTP(w, r)
})
}).Get("/{id}/tasks/{taskId}/logs/ws", s.handleStreamTaskLogsWebSocket)
r.Get("/{id}/tasks/{taskId}/steps", s.handleGetTaskSteps)
r.Post("/{id}/tasks/{taskId}/retry", s.handleRetryTask)
})
// Admin routes
s.router.Route("/api/admin", func(r chi.Router) {
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(s.auth.AdminMiddleware(next.ServeHTTP))
})
r.Route("/runners", func(r chi.Router) {
r.Route("/tokens", func(r chi.Router) {
r.Post("/", s.handleGenerateRegistrationToken)
r.Get("/", s.handleListRegistrationTokens)
r.Delete("/{id}", s.handleRevokeRegistrationToken)
})
r.Get("/", s.handleListRunnersAdmin)
r.Post("/{id}/verify", s.handleVerifyRunner)
r.Delete("/{id}", s.handleDeleteRunner)
})
r.Route("/users", func(r chi.Router) {
r.Get("/", s.handleListUsers)
r.Get("/{id}/jobs", s.handleGetUserJobs)
r.Post("/{id}/admin", s.handleSetUserAdminStatus)
})
r.Route("/settings", func(r chi.Router) {
r.Get("/registration", s.handleGetRegistrationEnabled)
r.Post("/registration", s.handleSetRegistrationEnabled)
})
})
// Runner API
s.router.Route("/api/runner", func(r chi.Router) {
// Registration doesn't require auth (uses token)
r.With(middleware.Timeout(60*time.Second)).Post("/register", s.handleRegisterRunner)
// WebSocket endpoint (auth handled in handler) - no timeout middleware
r.Get("/ws", s.handleRunnerWebSocket)
// File operations still use HTTP (WebSocket not suitable for large files)
r.Group(func(r chi.Router) {
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(s.runnerAuthMiddleware(next.ServeHTTP))
})
r.Post("/tasks/{id}/progress", s.handleUpdateTaskProgress)
r.Post("/tasks/{id}/steps", s.handleUpdateTaskStep)
r.Get("/files/{jobId}/*", s.handleDownloadFileForRunner)
r.Post("/files/{jobId}/upload", s.handleUploadFileFromRunner)
r.Get("/jobs/{jobId}/status", s.handleGetJobStatusForRunner)
r.Get("/jobs/{jobId}/files", s.handleGetJobFilesForRunner)
r.Post("/jobs/{jobId}/metadata", s.handleSubmitMetadata)
})
})
// Serve static files (built React app)
s.router.Handle("/*", http.FileServer(http.Dir("./web/dist")))
}
// ServeHTTP implements http.Handler
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.router.ServeHTTP(w, r)
}
// JSON response helpers
func (s *Server) respondJSON(w http.ResponseWriter, status int, data interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
if err := json.NewEncoder(w).Encode(data); err != nil {
log.Printf("Failed to encode JSON response: %v", err)
}
}
func (s *Server) respondError(w http.ResponseWriter, status int, message string) {
s.respondJSON(w, status, map[string]string{"error": message})
}
// Auth handlers
func (s *Server) handleGoogleLogin(w http.ResponseWriter, r *http.Request) {
url, err := s.auth.GoogleLoginURL()
if err != nil {
s.respondError(w, http.StatusInternalServerError, err.Error())
return
}
http.Redirect(w, r, url, http.StatusFound)
}
func (s *Server) handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
if code == "" {
s.respondError(w, http.StatusBadRequest, "Missing code parameter")
return
}
session, err := s.auth.GoogleCallback(r.Context(), code)
if err != nil {
// If registration is disabled, redirect back to login with error
if err.Error() == "registration is disabled" {
http.Redirect(w, r, "/?error=registration_disabled", http.StatusFound)
return
}
s.respondError(w, http.StatusInternalServerError, err.Error())
return
}
sessionID := s.auth.CreateSession(session)
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: sessionID,
Path: "/",
MaxAge: 86400,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
})
http.Redirect(w, r, "/", http.StatusFound)
}
func (s *Server) handleDiscordLogin(w http.ResponseWriter, r *http.Request) {
url, err := s.auth.DiscordLoginURL()
if err != nil {
s.respondError(w, http.StatusInternalServerError, err.Error())
return
}
http.Redirect(w, r, url, http.StatusFound)
}
func (s *Server) handleDiscordCallback(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
if code == "" {
s.respondError(w, http.StatusBadRequest, "Missing code parameter")
return
}
session, err := s.auth.DiscordCallback(r.Context(), code)
if err != nil {
// If registration is disabled, redirect back to login with error
if err.Error() == "registration is disabled" {
http.Redirect(w, r, "/?error=registration_disabled", http.StatusFound)
return
}
s.respondError(w, http.StatusInternalServerError, err.Error())
return
}
sessionID := s.auth.CreateSession(session)
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: sessionID,
Path: "/",
MaxAge: 86400,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
})
http.Redirect(w, r, "/", http.StatusFound)
}
func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie("session_id")
if err == nil {
s.auth.DeleteSession(cookie.Value)
}
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: "",
Path: "/",
MaxAge: -1,
HttpOnly: true,
})
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Logged out"})
}
func (s *Server) handleGetMe(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie("session_id")
if err != nil {
s.respondError(w, http.StatusUnauthorized, "Not authenticated")
return
}
session, ok := s.auth.GetSession(cookie.Value)
if !ok {
s.respondError(w, http.StatusUnauthorized, "Invalid session")
return
}
s.respondJSON(w, http.StatusOK, map[string]interface{}{
"id": session.UserID,
"email": session.Email,
"name": session.Name,
"is_admin": session.IsAdmin,
})
}
func (s *Server) handleGetAuthProviders(w http.ResponseWriter, r *http.Request) {
s.respondJSON(w, http.StatusOK, map[string]bool{
"google": s.auth.IsGoogleOAuthConfigured(),
"discord": s.auth.IsDiscordOAuthConfigured(),
"local": s.auth.IsLocalLoginEnabled(),
})
}
func (s *Server) handleLocalLoginAvailable(w http.ResponseWriter, r *http.Request) {
s.respondJSON(w, http.StatusOK, map[string]bool{
"available": s.auth.IsLocalLoginEnabled(),
})
}
func (s *Server) handleLocalRegister(w http.ResponseWriter, r *http.Request) {
var req struct {
Email string `json:"email"`
Name string `json:"name"`
Password string `json:"password"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
if req.Email == "" || req.Name == "" || req.Password == "" {
s.respondError(w, http.StatusBadRequest, "Email, name, and password are required")
return
}
if len(req.Password) < 8 {
s.respondError(w, http.StatusBadRequest, "Password must be at least 8 characters long")
return
}
session, err := s.auth.RegisterLocalUser(req.Email, req.Name, req.Password)
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
sessionID := s.auth.CreateSession(session)
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: sessionID,
Path: "/",
MaxAge: 86400,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
})
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
"message": "Registration successful",
"user": map[string]interface{}{
"id": session.UserID,
"email": session.Email,
"name": session.Name,
"is_admin": session.IsAdmin,
},
})
}
func (s *Server) handleLocalLogin(w http.ResponseWriter, r *http.Request) {
var req struct {
Username string `json:"username"`
Password string `json:"password"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
if req.Username == "" || req.Password == "" {
s.respondError(w, http.StatusBadRequest, "Username and password are required")
return
}
session, err := s.auth.LocalLogin(req.Username, req.Password)
if err != nil {
s.respondError(w, http.StatusUnauthorized, "Invalid credentials")
return
}
sessionID := s.auth.CreateSession(session)
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: sessionID,
Path: "/",
MaxAge: 86400,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
})
s.respondJSON(w, http.StatusOK, map[string]interface{}{
"message": "Login successful",
"user": map[string]interface{}{
"id": session.UserID,
"email": session.Email,
"name": session.Name,
"is_admin": session.IsAdmin,
},
})
}
func (s *Server) handleChangePassword(w http.ResponseWriter, r *http.Request) {
userID, err := getUserID(r)
if err != nil {
s.respondError(w, http.StatusUnauthorized, err.Error())
return
}
var req struct {
OldPassword string `json:"old_password"`
NewPassword string `json:"new_password"`
TargetUserID *int64 `json:"target_user_id,omitempty"` // For admin to change other users' passwords
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
if req.NewPassword == "" {
s.respondError(w, http.StatusBadRequest, "New password is required")
return
}
if len(req.NewPassword) < 8 {
s.respondError(w, http.StatusBadRequest, "Password must be at least 8 characters long")
return
}
isAdmin := authpkg.IsAdmin(r.Context())
// If target_user_id is provided and user is admin, allow changing other user's password
if req.TargetUserID != nil && isAdmin {
if err := s.auth.AdminChangePassword(*req.TargetUserID, req.NewPassword); err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Password changed successfully"})
return
}
// Otherwise, user is changing their own password (requires old password)
if req.OldPassword == "" {
s.respondError(w, http.StatusBadRequest, "Old password is required")
return
}
if err := s.auth.ChangePassword(userID, req.OldPassword, req.NewPassword); err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Password changed successfully"})
}
// Helper to get user ID from context
func getUserID(r *http.Request) (int64, error) {
userID, ok := authpkg.GetUserID(r.Context())
if !ok {
return 0, fmt.Errorf("user ID not found in context")
}
return userID, nil
}
// Helper to parse ID from URL
func parseID(r *http.Request, param string) (int64, error) {
idStr := chi.URLParam(r, param)
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
return 0, fmt.Errorf("invalid ID: %s", idStr)
}
return id, nil
}
// StartBackgroundTasks starts background goroutines for error recovery
func (s *Server) StartBackgroundTasks() {
go s.recoverStuckTasks()
go s.cleanupOldMetadataJobs()
}
// recoverStuckTasks periodically checks for dead runners and stuck tasks
func (s *Server) recoverStuckTasks() {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
// Also distribute tasks every 10 seconds (reduced frequency since we have event-driven distribution)
distributeTicker := time.NewTicker(10 * time.Second)
defer distributeTicker.Stop()
go func() {
for range distributeTicker.C {
s.distributeTasksToRunners()
}
}()
for range ticker.C {
func() {
defer func() {
if r := recover(); r != nil {
log.Printf("Panic in recoverStuckTasks: %v", r)
}
}()
// Find dead runners (no heartbeat for 90 seconds)
// But only mark as dead if they're not actually connected via WebSocket
rows, err := s.db.Query(
`SELECT id FROM runners
WHERE last_heartbeat < CURRENT_TIMESTAMP - INTERVAL '90 seconds'
AND status = ?`,
types.RunnerStatusOnline,
)
if err != nil {
log.Printf("Failed to query dead runners: %v", err)
return
}
defer rows.Close()
var deadRunnerIDs []int64
s.runnerConnsMu.RLock()
for rows.Next() {
var runnerID int64
if err := rows.Scan(&runnerID); err == nil {
// Only mark as dead if not actually connected via WebSocket
// The WebSocket connection is the source of truth
if _, stillConnected := s.runnerConns[runnerID]; !stillConnected {
deadRunnerIDs = append(deadRunnerIDs, runnerID)
}
// If still connected, heartbeat should be updated by pong handler or heartbeat message
// No need to manually update here - if it's stale, the pong handler isn't working
}
}
s.runnerConnsMu.RUnlock()
rows.Close()
if len(deadRunnerIDs) == 0 {
// Check for task timeouts
s.recoverTaskTimeouts()
return
}
// Reset tasks assigned to dead runners
for _, runnerID := range deadRunnerIDs {
s.redistributeRunnerTasks(runnerID)
// Mark runner as offline
_, _ = s.db.Exec(
`UPDATE runners SET status = ? WHERE id = ?`,
types.RunnerStatusOffline, runnerID,
)
}
// Check for task timeouts
s.recoverTaskTimeouts()
// Distribute newly recovered tasks
s.distributeTasksToRunners()
}()
}
}
// recoverTaskTimeouts handles tasks that have exceeded their timeout
func (s *Server) recoverTaskTimeouts() {
// Find tasks running longer than their timeout
rows, err := s.db.Query(
`SELECT t.id, t.runner_id, t.retry_count, t.max_retries, t.timeout_seconds, t.started_at
FROM tasks t
WHERE t.status = ?
AND t.started_at IS NOT NULL
AND (t.timeout_seconds IS NULL OR
t.started_at + INTERVAL (t.timeout_seconds || ' seconds') < CURRENT_TIMESTAMP)`,
types.TaskStatusRunning,
)
if err != nil {
log.Printf("Failed to query timed out tasks: %v", err)
return
}
defer rows.Close()
for rows.Next() {
var taskID int64
var runnerID sql.NullInt64
var retryCount, maxRetries int
var timeoutSeconds sql.NullInt64
var startedAt time.Time
err := rows.Scan(&taskID, &runnerID, &retryCount, &maxRetries, &timeoutSeconds, &startedAt)
if err != nil {
continue
}
// Use default timeout if not set (5 minutes for frame tasks, 24 hours for FFmpeg)
timeout := 300 // 5 minutes default
if timeoutSeconds.Valid {
timeout = int(timeoutSeconds.Int64)
}
// Check if actually timed out
if time.Since(startedAt).Seconds() < float64(timeout) {
continue
}
if retryCount >= maxRetries {
// Mark as failed
_, err = s.db.Exec(
`UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL
WHERE id = ?`,
types.TaskStatusFailed, "Task timeout exceeded, max retries reached", taskID,
)
if err != nil {
log.Printf("Failed to mark task %d as failed: %v", taskID, err)
}
} else {
// Reset to pending
_, err = s.db.Exec(
`UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL,
retry_count = retry_count + 1 WHERE id = ?`,
types.TaskStatusPending, taskID,
)
if err == nil {
// Add log entry using the helper function
s.logTaskEvent(taskID, nil, types.LogLevelWarn, fmt.Sprintf("Task timeout exceeded, resetting (retry %d/%d)", retryCount+1, maxRetries), "")
}
}
}
}

View File

@@ -5,10 +5,13 @@ import (
"database/sql"
"encoding/json"
"fmt"
"jiggablend/internal/config"
"jiggablend/internal/database"
"log"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/google/uuid"
@@ -17,12 +20,31 @@ import (
"golang.org/x/oauth2/google"
)
// Context key types to avoid collisions (typed keys are safer than string keys)
type contextKey int
const (
contextKeyUserID contextKey = iota
contextKeyUserEmail
contextKeyUserName
contextKeyIsAdmin
)
// Configuration constants
const (
SessionDuration = 24 * time.Hour
SessionCleanupInterval = 1 * time.Hour
)
// Auth handles authentication
type Auth struct {
db *sql.DB
db *database.DB
cfg *config.Config
googleConfig *oauth2.Config
discordConfig *oauth2.Config
sessionStore map[string]*Session
sessionCache map[string]*Session // In-memory cache for performance
cacheMu sync.RWMutex
stopCleanup chan struct{}
}
// Session represents a user session
@@ -35,41 +57,53 @@ type Session struct {
}
// NewAuth creates a new auth instance
func NewAuth(db *sql.DB) (*Auth, error) {
func NewAuth(db *database.DB, cfg *config.Config) (*Auth, error) {
auth := &Auth{
db: db,
sessionStore: make(map[string]*Session),
cfg: cfg,
sessionCache: make(map[string]*Session),
stopCleanup: make(chan struct{}),
}
// Initialize Google OAuth
googleClientID := os.Getenv("GOOGLE_CLIENT_ID")
googleClientSecret := os.Getenv("GOOGLE_CLIENT_SECRET")
// Initialize Google OAuth from database config
googleClientID := cfg.GoogleClientID()
googleClientSecret := cfg.GoogleClientSecret()
if googleClientID != "" && googleClientSecret != "" {
auth.googleConfig = &oauth2.Config{
ClientID: googleClientID,
ClientSecret: googleClientSecret,
RedirectURL: os.Getenv("GOOGLE_REDIRECT_URL"),
RedirectURL: cfg.GoogleRedirectURL(),
Scopes: []string{"openid", "profile", "email"},
Endpoint: google.Endpoint,
}
log.Printf("Google OAuth configured")
}
// Initialize Discord OAuth
discordClientID := os.Getenv("DISCORD_CLIENT_ID")
discordClientSecret := os.Getenv("DISCORD_CLIENT_SECRET")
// Initialize Discord OAuth from database config
discordClientID := cfg.DiscordClientID()
discordClientSecret := cfg.DiscordClientSecret()
if discordClientID != "" && discordClientSecret != "" {
auth.discordConfig = &oauth2.Config{
ClientID: discordClientID,
ClientSecret: discordClientSecret,
RedirectURL: os.Getenv("DISCORD_REDIRECT_URL"),
RedirectURL: cfg.DiscordRedirectURL(),
Scopes: []string{"identify", "email"},
Endpoint: oauth2.Endpoint{
AuthURL: "https://discord.com/api/oauth2/authorize",
TokenURL: "https://discord.com/api/oauth2/token",
},
}
log.Printf("Discord OAuth configured")
}
// Load existing sessions from database into cache
if err := auth.loadSessionsFromDB(); err != nil {
log.Printf("Warning: Failed to load sessions from database: %v", err)
}
// Start background cleanup goroutine
go auth.cleanupExpiredSessions()
// Initialize admin settings on startup to ensure they persist between boots
if err := auth.initializeSettings(); err != nil {
log.Printf("Warning: Failed to initialize admin settings: %v", err)
@@ -85,19 +119,119 @@ func NewAuth(db *sql.DB) (*Auth, error) {
return auth, nil
}
// Close stops background goroutines
func (a *Auth) Close() {
close(a.stopCleanup)
}
// loadSessionsFromDB loads all valid sessions from database into cache
func (a *Auth) loadSessionsFromDB() error {
var sessions []struct {
sessionID string
session Session
}
err := a.db.With(func(conn *sql.DB) error {
rows, err := conn.Query(
`SELECT session_id, user_id, email, name, is_admin, expires_at
FROM sessions WHERE expires_at > CURRENT_TIMESTAMP`,
)
if err != nil {
return fmt.Errorf("failed to query sessions: %w", err)
}
defer rows.Close()
for rows.Next() {
var s struct {
sessionID string
session Session
}
err := rows.Scan(&s.sessionID, &s.session.UserID, &s.session.Email, &s.session.Name, &s.session.IsAdmin, &s.session.ExpiresAt)
if err != nil {
log.Printf("Warning: Failed to scan session row: %v", err)
continue
}
sessions = append(sessions, s)
}
return nil
})
if err != nil {
return err
}
a.cacheMu.Lock()
defer a.cacheMu.Unlock()
for _, s := range sessions {
a.sessionCache[s.sessionID] = &s.session
}
if len(sessions) > 0 {
log.Printf("Loaded %d active sessions from database", len(sessions))
}
return nil
}
// cleanupExpiredSessions periodically removes expired sessions from database and cache
func (a *Auth) cleanupExpiredSessions() {
ticker := time.NewTicker(SessionCleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
// Delete expired sessions from database
var deleted int64
err := a.db.With(func(conn *sql.DB) error {
result, err := conn.Exec(`DELETE FROM sessions WHERE expires_at < CURRENT_TIMESTAMP`)
if err != nil {
return err
}
deleted, _ = result.RowsAffected()
return nil
})
if err != nil {
log.Printf("Warning: Failed to cleanup expired sessions: %v", err)
continue
}
// Clean up cache
a.cacheMu.Lock()
now := time.Now()
for sessionID, session := range a.sessionCache {
if now.After(session.ExpiresAt) {
delete(a.sessionCache, sessionID)
}
}
a.cacheMu.Unlock()
if deleted > 0 {
log.Printf("Cleaned up %d expired sessions", deleted)
}
case <-a.stopCleanup:
return
}
}
}
// initializeSettings ensures all admin settings are initialized with defaults if they don't exist
func (a *Auth) initializeSettings() error {
// Initialize registration_enabled setting (default: true) if it doesn't exist
var settingCount int
err := a.db.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount)
err := a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount)
})
if err != nil {
return fmt.Errorf("failed to check registration_enabled setting: %w", err)
}
if settingCount == 0 {
_, err = a.db.Exec(
err = a.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)`,
"registration_enabled", "true",
)
return err
})
if err != nil {
return fmt.Errorf("failed to initialize registration_enabled setting: %w", err)
}
@@ -118,7 +252,9 @@ func (a *Auth) initializeTestUser() error {
// Check if user already exists
var exists bool
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND oauth_provider = 'local')", testEmail).Scan(&exists)
err := a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND oauth_provider = 'local')", testEmail).Scan(&exists)
})
if err != nil {
return fmt.Errorf("failed to check if test user exists: %w", err)
}
@@ -137,7 +273,12 @@ func (a *Auth) initializeTestUser() error {
// Check if this is the first user (make them admin)
var userCount int
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
err = a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
})
if err != nil {
return fmt.Errorf("failed to check user count: %w", err)
}
isAdmin := userCount == 0
// Create test user (use email as name if no name is provided)
@@ -147,10 +288,13 @@ func (a *Auth) initializeTestUser() error {
}
// Create test user
_, err = a.db.Exec(
err = a.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?)",
testEmail, testName, testEmail, string(hashedPassword), isAdmin,
)
return err
})
if err != nil {
return fmt.Errorf("failed to create test user: %w", err)
}
@@ -242,7 +386,9 @@ func (a *Auth) DiscordCallback(ctx context.Context, code string) (*Session, erro
// IsRegistrationEnabled checks if new user registration is enabled
func (a *Auth) IsRegistrationEnabled() (bool, error) {
var value string
err := a.db.QueryRow("SELECT value FROM settings WHERE key = ?", "registration_enabled").Scan(&value)
err := a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT value FROM settings WHERE key = ?", "registration_enabled").Scan(&value)
})
if err == sql.ErrNoRows {
// Default to enabled if setting doesn't exist
return true, nil
@@ -262,29 +408,31 @@ func (a *Auth) SetRegistrationEnabled(enabled bool) error {
// Check if setting exists
var exists bool
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM settings WHERE key = ?)", "registration_enabled").Scan(&exists)
err := a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM settings WHERE key = ?)", "registration_enabled").Scan(&exists)
})
if err != nil {
return fmt.Errorf("failed to check if setting exists: %w", err)
}
err = a.db.With(func(conn *sql.DB) error {
if exists {
// Update existing setting
_, err = a.db.Exec(
_, err = conn.Exec(
"UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?",
value, "registration_enabled",
)
if err != nil {
return fmt.Errorf("failed to update setting: %w", err)
}
} else {
// Insert new setting
_, err = a.db.Exec(
_, err = conn.Exec(
"INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)",
"registration_enabled", value,
)
if err != nil {
return fmt.Errorf("failed to insert setting: %w", err)
}
return err
})
if err != nil {
return fmt.Errorf("failed to set registration_enabled: %w", err)
}
return nil
@@ -299,17 +447,21 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session,
var dbProvider, dbOAuthID string
// First, try to find by provider + oauth_id
err := a.db.QueryRow(
err := a.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
"SELECT id, email, name, is_admin, oauth_provider, oauth_id FROM users WHERE oauth_provider = ? AND oauth_id = ?",
provider, oauthID,
).Scan(&userID, &dbEmail, &dbName, &isAdmin, &dbProvider, &dbOAuthID)
})
if err == sql.ErrNoRows {
// Not found by provider+oauth_id, check by email for account linking
err = a.db.QueryRow(
err = a.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
"SELECT id, email, name, is_admin, oauth_provider, oauth_id FROM users WHERE email = ?",
email,
).Scan(&userID, &dbEmail, &dbName, &isAdmin, &dbProvider, &dbOAuthID)
})
if err == sql.ErrNoRows {
// User doesn't exist, check if registration is enabled
@@ -323,14 +475,26 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session,
// Check if this is the first user
var userCount int
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
err = a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
})
if err != nil {
return nil, fmt.Errorf("failed to check user count: %w", err)
}
isAdmin = userCount == 0
// Create new user
err = a.db.QueryRow(
"INSERT INTO users (email, name, oauth_provider, oauth_id, is_admin) VALUES (?, ?, ?, ?, ?) RETURNING id",
err = a.db.With(func(conn *sql.DB) error {
result, err := conn.Exec(
"INSERT INTO users (email, name, oauth_provider, oauth_id, is_admin) VALUES (?, ?, ?, ?, ?)",
email, name, provider, oauthID, isAdmin,
).Scan(&userID)
)
if err != nil {
return err
}
userID, err = result.LastInsertId()
return err
})
if err != nil {
return nil, fmt.Errorf("failed to create user: %w", err)
}
@@ -339,10 +503,13 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session,
} else {
// User exists with same email but different provider - link accounts by updating provider info
// This allows the user to log in with any provider that has the same email
_, err = a.db.Exec(
err = a.db.With(func(conn *sql.DB) error {
_, err = conn.Exec(
"UPDATE users SET oauth_provider = ?, oauth_id = ?, name = ? WHERE id = ?",
provider, oauthID, name, userID,
)
return err
})
if err != nil {
return nil, fmt.Errorf("failed to link account: %w", err)
}
@@ -353,10 +520,13 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session,
} else {
// User found by provider+oauth_id, update info if changed
if dbEmail != email || dbName != name {
_, err = a.db.Exec(
err = a.db.With(func(conn *sql.DB) error {
_, err = conn.Exec(
"UPDATE users SET email = ?, name = ? WHERE id = ?",
email, name, userID,
)
return err
})
if err != nil {
return nil, fmt.Errorf("failed to update user: %w", err)
}
@@ -368,41 +538,134 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session,
Email: email,
Name: name,
IsAdmin: isAdmin,
ExpiresAt: time.Now().Add(24 * time.Hour),
ExpiresAt: time.Now().Add(SessionDuration),
}
return session, nil
}
// CreateSession creates a new session and returns a session ID
// Sessions are persisted to database and cached in memory
func (a *Auth) CreateSession(session *Session) string {
sessionID := uuid.New().String()
a.sessionStore[sessionID] = session
// Store in database first
err := a.db.With(func(conn *sql.DB) error {
_, err := conn.Exec(
`INSERT INTO sessions (session_id, user_id, email, name, is_admin, expires_at)
VALUES (?, ?, ?, ?, ?, ?)`,
sessionID, session.UserID, session.Email, session.Name, session.IsAdmin, session.ExpiresAt,
)
return err
})
if err != nil {
log.Printf("Warning: Failed to persist session to database: %v", err)
// Continue anyway - session will work from cache but won't survive restart
}
// Store in cache
a.cacheMu.Lock()
a.sessionCache[sessionID] = session
a.cacheMu.Unlock()
return sessionID
}
// GetSession retrieves a session by ID
// First checks cache, then database if not found
func (a *Auth) GetSession(sessionID string) (*Session, bool) {
session, ok := a.sessionStore[sessionID]
if !ok {
// Check cache first
a.cacheMu.RLock()
session, ok := a.sessionCache[sessionID]
a.cacheMu.RUnlock()
if ok {
if time.Now().After(session.ExpiresAt) {
a.DeleteSession(sessionID)
return nil, false
}
// Refresh admin status from database
var isAdmin bool
err := a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT is_admin FROM users WHERE id = ?", session.UserID).Scan(&isAdmin)
})
if err == nil {
session.IsAdmin = isAdmin
}
return session, true
}
// Not in cache, check database
session = &Session{}
err := a.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
`SELECT user_id, email, name, is_admin, expires_at
FROM sessions WHERE session_id = ?`,
sessionID,
).Scan(&session.UserID, &session.Email, &session.Name, &session.IsAdmin, &session.ExpiresAt)
})
if err == sql.ErrNoRows {
return nil, false
}
if err != nil {
log.Printf("Warning: Failed to query session from database: %v", err)
return nil, false
}
if time.Now().After(session.ExpiresAt) {
delete(a.sessionStore, sessionID)
a.DeleteSession(sessionID)
return nil, false
}
// Refresh admin status from database
var isAdmin bool
err := a.db.QueryRow("SELECT is_admin FROM users WHERE id = ?", session.UserID).Scan(&isAdmin)
err = a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT is_admin FROM users WHERE id = ?", session.UserID).Scan(&isAdmin)
})
if err == nil {
session.IsAdmin = isAdmin
}
// Add to cache
a.cacheMu.Lock()
a.sessionCache[sessionID] = session
a.cacheMu.Unlock()
return session, true
}
// DeleteSession deletes a session
// DeleteSession deletes a session from both cache and database
func (a *Auth) DeleteSession(sessionID string) {
delete(a.sessionStore, sessionID)
// Delete from cache
a.cacheMu.Lock()
delete(a.sessionCache, sessionID)
a.cacheMu.Unlock()
// Delete from database
err := a.db.With(func(conn *sql.DB) error {
_, err := conn.Exec("DELETE FROM sessions WHERE session_id = ?", sessionID)
return err
})
if err != nil {
log.Printf("Warning: Failed to delete session from database: %v", err)
}
}
// IsProductionMode returns true if running in production mode
// This is a package-level function that checks the environment variable
// For config-based checks, use Config.IsProductionMode()
func IsProductionMode() bool {
// Check environment variable first for backwards compatibility
if os.Getenv("PRODUCTION") == "true" {
return true
}
return false
}
// IsProductionModeFromConfig returns true if production mode is enabled in config
func (a *Auth) IsProductionModeFromConfig() bool {
return a.cfg.IsProductionMode()
}
// Middleware creates an authentication middleware
@@ -410,6 +673,7 @@ func (a *Auth) Middleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie("session_id")
if err != nil {
log.Printf("Authentication failed: missing session cookie for %s %s", r.Method, r.URL.Path)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
@@ -418,30 +682,31 @@ func (a *Auth) Middleware(next http.HandlerFunc) http.HandlerFunc {
session, ok := a.GetSession(cookie.Value)
if !ok {
log.Printf("Authentication failed: invalid session cookie for %s %s", r.Method, r.URL.Path)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
return
}
// Add user info to request context
ctx := context.WithValue(r.Context(), "user_id", session.UserID)
ctx = context.WithValue(ctx, "user_email", session.Email)
ctx = context.WithValue(ctx, "user_name", session.Name)
ctx = context.WithValue(ctx, "is_admin", session.IsAdmin)
// Add user info to request context using typed keys
ctx := context.WithValue(r.Context(), contextKeyUserID, session.UserID)
ctx = context.WithValue(ctx, contextKeyUserEmail, session.Email)
ctx = context.WithValue(ctx, contextKeyUserName, session.Name)
ctx = context.WithValue(ctx, contextKeyIsAdmin, session.IsAdmin)
next(w, r.WithContext(ctx))
}
}
// GetUserID gets the user ID from context
func GetUserID(ctx context.Context) (int64, bool) {
userID, ok := ctx.Value("user_id").(int64)
userID, ok := ctx.Value(contextKeyUserID).(int64)
return userID, ok
}
// IsAdmin checks if the user in context is an admin
func IsAdmin(ctx context.Context) bool {
isAdmin, ok := ctx.Value("is_admin").(bool)
isAdmin, ok := ctx.Value(contextKeyIsAdmin).(bool)
return ok && isAdmin
}
@@ -451,6 +716,7 @@ func (a *Auth) AdminMiddleware(next http.HandlerFunc) http.HandlerFunc {
// First check authentication
cookie, err := r.Cookie("session_id")
if err != nil {
log.Printf("Admin authentication failed: missing session cookie for %s %s", r.Method, r.URL.Path)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
@@ -459,6 +725,7 @@ func (a *Auth) AdminMiddleware(next http.HandlerFunc) http.HandlerFunc {
session, ok := a.GetSession(cookie.Value)
if !ok {
log.Printf("Admin authentication failed: invalid session cookie for %s %s", r.Method, r.URL.Path)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
@@ -467,25 +734,26 @@ func (a *Auth) AdminMiddleware(next http.HandlerFunc) http.HandlerFunc {
// Then check admin status
if !session.IsAdmin {
log.Printf("Admin access denied: user %d (email: %s) attempted to access admin endpoint %s %s", session.UserID, session.Email, r.Method, r.URL.Path)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
json.NewEncoder(w).Encode(map[string]string{"error": "Forbidden: Admin access required"})
return
}
// Add user info to request context
ctx := context.WithValue(r.Context(), "user_id", session.UserID)
ctx = context.WithValue(ctx, "user_email", session.Email)
ctx = context.WithValue(ctx, "user_name", session.Name)
ctx = context.WithValue(ctx, "is_admin", session.IsAdmin)
// Add user info to request context using typed keys
ctx := context.WithValue(r.Context(), contextKeyUserID, session.UserID)
ctx = context.WithValue(ctx, contextKeyUserEmail, session.Email)
ctx = context.WithValue(ctx, contextKeyUserName, session.Name)
ctx = context.WithValue(ctx, contextKeyIsAdmin, session.IsAdmin)
next(w, r.WithContext(ctx))
}
}
// IsLocalLoginEnabled returns whether local login is enabled
// Local login is enabled when ENABLE_LOCAL_AUTH environment variable is set to "true"
// Checks database config first, falls back to environment variable
func (a *Auth) IsLocalLoginEnabled() bool {
return os.Getenv("ENABLE_LOCAL_AUTH") == "true"
return a.cfg.IsLocalAuthEnabled()
}
// IsGoogleOAuthConfigured returns whether Google OAuth is configured
@@ -506,10 +774,12 @@ func (a *Auth) LocalLogin(username, password string) (*Session, error) {
var dbEmail, dbName, passwordHash string
var isAdmin bool
err := a.db.QueryRow(
err := a.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
"SELECT id, email, name, password_hash, is_admin FROM users WHERE email = ? AND oauth_provider = 'local'",
email,
).Scan(&userID, &dbEmail, &dbName, &passwordHash, &isAdmin)
})
if err == sql.ErrNoRows {
return nil, fmt.Errorf("invalid credentials")
@@ -534,7 +804,7 @@ func (a *Auth) LocalLogin(username, password string) (*Session, error) {
Email: dbEmail,
Name: dbName,
IsAdmin: isAdmin,
ExpiresAt: time.Now().Add(24 * time.Hour),
ExpiresAt: time.Now().Add(SessionDuration),
}
return session, nil
@@ -553,7 +823,9 @@ func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error)
// Check if user already exists
var exists bool
err = a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ?)", email).Scan(&exists)
err = a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ?)", email).Scan(&exists)
})
if err != nil {
return nil, fmt.Errorf("failed to check if user exists: %w", err)
}
@@ -569,15 +841,27 @@ func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error)
// Check if this is the first user (make them admin)
var userCount int
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
err = a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
})
if err != nil {
return nil, fmt.Errorf("failed to check user count: %w", err)
}
isAdmin := userCount == 0
// Create user
var userID int64
err = a.db.QueryRow(
"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?) RETURNING id",
err = a.db.With(func(conn *sql.DB) error {
result, err := conn.Exec(
"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?)",
email, name, email, string(hashedPassword), isAdmin,
).Scan(&userID)
)
if err != nil {
return err
}
userID, err = result.LastInsertId()
return err
})
if err != nil {
return nil, fmt.Errorf("failed to create user: %w", err)
}
@@ -588,7 +872,7 @@ func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error)
Email: email,
Name: name,
IsAdmin: isAdmin,
ExpiresAt: time.Now().Add(24 * time.Hour),
ExpiresAt: time.Now().Add(SessionDuration),
}
return session, nil
@@ -598,7 +882,9 @@ func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error)
func (a *Auth) ChangePassword(userID int64, oldPassword, newPassword string) error {
// Get current password hash
var passwordHash string
err := a.db.QueryRow("SELECT password_hash FROM users WHERE id = ? AND oauth_provider = 'local'", userID).Scan(&passwordHash)
err := a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT password_hash FROM users WHERE id = ? AND oauth_provider = 'local'", userID).Scan(&passwordHash)
})
if err == sql.ErrNoRows {
return fmt.Errorf("user not found or not a local user")
}
@@ -623,7 +909,10 @@ func (a *Auth) ChangePassword(userID int64, oldPassword, newPassword string) err
}
// Update password
_, err = a.db.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), userID)
err = a.db.With(func(conn *sql.DB) error {
_, err := conn.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), userID)
return err
})
if err != nil {
return fmt.Errorf("failed to update password: %w", err)
}
@@ -635,7 +924,9 @@ func (a *Auth) ChangePassword(userID int64, oldPassword, newPassword string) err
func (a *Auth) AdminChangePassword(targetUserID int64, newPassword string) error {
// Verify user exists and is a local user
var exists bool
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ? AND oauth_provider = 'local')", targetUserID).Scan(&exists)
err := a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ? AND oauth_provider = 'local')", targetUserID).Scan(&exists)
})
if err != nil {
return fmt.Errorf("failed to check if user exists: %w", err)
}
@@ -650,7 +941,10 @@ func (a *Auth) AdminChangePassword(targetUserID int64, newPassword string) error
}
// Update password
_, err = a.db.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), targetUserID)
err = a.db.With(func(conn *sql.DB) error {
_, err := conn.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), targetUserID)
return err
})
if err != nil {
return fmt.Errorf("failed to update password: %w", err)
}
@@ -661,7 +955,9 @@ func (a *Auth) AdminChangePassword(targetUserID int64, newPassword string) error
// GetFirstUserID returns the ID of the first user (user with the lowest ID)
func (a *Auth) GetFirstUserID() (int64, error) {
var firstUserID int64
err := a.db.QueryRow("SELECT id FROM users ORDER BY id ASC LIMIT 1").Scan(&firstUserID)
err := a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT id FROM users ORDER BY id ASC LIMIT 1").Scan(&firstUserID)
})
if err == sql.ErrNoRows {
return 0, fmt.Errorf("no users found")
}
@@ -675,7 +971,9 @@ func (a *Auth) GetFirstUserID() (int64, error) {
func (a *Auth) SetUserAdminStatus(targetUserID int64, isAdmin bool) error {
// Verify user exists
var exists bool
err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", targetUserID).Scan(&exists)
err := a.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", targetUserID).Scan(&exists)
})
if err != nil {
return fmt.Errorf("failed to check if user exists: %w", err)
}
@@ -693,7 +991,10 @@ func (a *Auth) SetUserAdminStatus(targetUserID int64, isAdmin bool) error {
}
// Update admin status
_, err = a.db.Exec("UPDATE users SET is_admin = ? WHERE id = ?", isAdmin, targetUserID)
err = a.db.With(func(conn *sql.DB) error {
_, err := conn.Exec("UPDATE users SET is_admin = ? WHERE id = ?", isAdmin, targetUserID)
return err
})
if err != nil {
return fmt.Errorf("failed to update admin status: %w", err)
}

115
internal/auth/jobtoken.go Normal file
View File

@@ -0,0 +1,115 @@
package auth
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"time"
)
// JobTokenDuration is the validity period for job tokens
const JobTokenDuration = 1 * time.Hour
// JobTokenClaims represents the claims in a job token
type JobTokenClaims struct {
JobID int64 `json:"job_id"`
RunnerID int64 `json:"runner_id"`
TaskID int64 `json:"task_id"`
Exp int64 `json:"exp"` // Unix timestamp
}
// jobTokenSecret is the secret used to sign job tokens
// Generated once at startup and kept in memory
var jobTokenSecret []byte
func init() {
// Generate a random secret for signing job tokens
// This means tokens are invalidated on server restart, which is acceptable
// for short-lived job tokens
jobTokenSecret = make([]byte, 32)
if _, err := rand.Read(jobTokenSecret); err != nil {
panic(fmt.Sprintf("failed to generate job token secret: %v", err))
}
}
// GenerateJobToken creates a new job token for a specific job/runner/task combination
func GenerateJobToken(jobID, runnerID, taskID int64) (string, error) {
claims := JobTokenClaims{
JobID: jobID,
RunnerID: runnerID,
TaskID: taskID,
Exp: time.Now().Add(JobTokenDuration).Unix(),
}
// Encode claims to JSON
claimsJSON, err := json.Marshal(claims)
if err != nil {
return "", fmt.Errorf("failed to marshal claims: %w", err)
}
// Create HMAC signature
h := hmac.New(sha256.New, jobTokenSecret)
h.Write(claimsJSON)
signature := h.Sum(nil)
// Combine claims and signature: base64(claims).base64(signature)
token := base64.RawURLEncoding.EncodeToString(claimsJSON) + "." +
base64.RawURLEncoding.EncodeToString(signature)
return token, nil
}
// ValidateJobToken validates a job token and returns the claims if valid
func ValidateJobToken(token string) (*JobTokenClaims, error) {
// Split token into claims and signature
var claimsB64, sigB64 string
dotIdx := -1
for i := len(token) - 1; i >= 0; i-- {
if token[i] == '.' {
dotIdx = i
break
}
}
if dotIdx == -1 {
return nil, fmt.Errorf("invalid token format")
}
claimsB64 = token[:dotIdx]
sigB64 = token[dotIdx+1:]
// Decode claims
claimsJSON, err := base64.RawURLEncoding.DecodeString(claimsB64)
if err != nil {
return nil, fmt.Errorf("invalid token encoding: %w", err)
}
// Decode signature
signature, err := base64.RawURLEncoding.DecodeString(sigB64)
if err != nil {
return nil, fmt.Errorf("invalid signature encoding: %w", err)
}
// Verify signature
h := hmac.New(sha256.New, jobTokenSecret)
h.Write(claimsJSON)
expectedSig := h.Sum(nil)
if !hmac.Equal(signature, expectedSig) {
return nil, fmt.Errorf("invalid signature")
}
// Parse claims
var claims JobTokenClaims
if err := json.Unmarshal(claimsJSON, &claims); err != nil {
return nil, fmt.Errorf("invalid claims: %w", err)
}
// Check expiration
if time.Now().Unix() > claims.Exp {
return nil, fmt.Errorf("token expired")
}
return &claims, nil
}

View File

@@ -1,276 +1,236 @@
package auth
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"database/sql"
"encoding/hex"
"fmt"
"io"
"log"
"net/http"
"os"
"jiggablend/internal/config"
"jiggablend/internal/database"
"strings"
"sync"
"time"
)
// Secrets handles secret and token management
// Secrets handles API key management
type Secrets struct {
db *sql.DB
fixedRegistrationToken string // Fixed token from environment variable (reusable, never expires)
db *database.DB
cfg *config.Config
RegistrationMu sync.Mutex // Protects concurrent runner registrations
}
// NewSecrets creates a new secrets manager
func NewSecrets(db *sql.DB) (*Secrets, error) {
s := &Secrets{db: db}
// Check for fixed registration token from environment
fixedToken := os.Getenv("FIXED_REGISTRATION_TOKEN")
if fixedToken != "" {
s.fixedRegistrationToken = fixedToken
log.Printf("Fixed registration token enabled (from FIXED_REGISTRATION_TOKEN env var)")
log.Printf("WARNING: Fixed registration token is reusable and never expires - use only for testing/development!")
}
// Ensure manager secret exists
if err := s.ensureManagerSecret(); err != nil {
return nil, fmt.Errorf("failed to ensure manager secret: %w", err)
}
return s, nil
func NewSecrets(db *database.DB, cfg *config.Config) (*Secrets, error) {
return &Secrets{db: db, cfg: cfg}, nil
}
// ensureManagerSecret ensures a manager secret exists in the database
func (s *Secrets) ensureManagerSecret() error {
var count int
err := s.db.QueryRow("SELECT COUNT(*) FROM manager_secrets").Scan(&count)
if err != nil {
return fmt.Errorf("failed to check manager secrets: %w", err)
}
if count == 0 {
// Generate new manager secret
secret, err := generateSecret(32)
if err != nil {
return fmt.Errorf("failed to generate manager secret: %w", err)
}
_, err = s.db.Exec("INSERT INTO manager_secrets (secret) VALUES (?)", secret)
if err != nil {
return fmt.Errorf("failed to store manager secret: %w", err)
}
}
return nil
// APIKeyInfo represents information about an API key
type APIKeyInfo struct {
ID int64 `json:"id"`
Key string `json:"key"`
Name string `json:"name"`
Description *string `json:"description,omitempty"`
Scope string `json:"scope"` // 'manager' or 'user'
IsActive bool `json:"is_active"`
CreatedAt time.Time `json:"created_at"`
CreatedBy int64 `json:"created_by"`
}
// GetManagerSecret retrieves the current manager secret
func (s *Secrets) GetManagerSecret() (string, error) {
var secret string
err := s.db.QueryRow("SELECT secret FROM manager_secrets ORDER BY created_at DESC LIMIT 1").Scan(&secret)
// GenerateRunnerAPIKey generates a new API key for runners
func (s *Secrets) GenerateRunnerAPIKey(createdBy int64, name, description string, scope string) (*APIKeyInfo, error) {
// Generate API key in format: jk_r1_abc123def456...
key, err := s.generateAPIKey()
if err != nil {
return "", fmt.Errorf("failed to get manager secret: %w", err)
}
return secret, nil
}
// GenerateRegistrationToken generates a new registration token
func (s *Secrets) GenerateRegistrationToken(createdBy int64, expiresIn time.Duration) (string, error) {
token, err := generateSecret(32)
if err != nil {
return "", fmt.Errorf("failed to generate token: %w", err)
return nil, fmt.Errorf("failed to generate API key: %w", err)
}
expiresAt := time.Now().Add(expiresIn)
// Extract prefix (first 5 chars after "jk_") and hash the full key
parts := strings.Split(key, "_")
if len(parts) < 3 {
return nil, fmt.Errorf("invalid API key format generated")
}
keyPrefix := fmt.Sprintf("%s_%s", parts[0], parts[1])
_, err = s.db.Exec(
"INSERT INTO registration_tokens (token, expires_at, created_by) VALUES (?, ?, ?)",
token, expiresAt, createdBy,
keyHash := sha256.Sum256([]byte(key))
keyHashStr := hex.EncodeToString(keyHash[:])
var keyInfo APIKeyInfo
err = s.db.With(func(conn *sql.DB) error {
result, err := conn.Exec(
`INSERT INTO runner_api_keys (key_prefix, key_hash, name, description, scope, is_active, created_by)
VALUES (?, ?, ?, ?, ?, ?, ?)`,
keyPrefix, keyHashStr, name, description, scope, true, createdBy,
)
if err != nil {
return "", fmt.Errorf("failed to store registration token: %w", err)
return fmt.Errorf("failed to store API key: %w", err)
}
keyID, err := result.LastInsertId()
if err != nil {
return fmt.Errorf("failed to get inserted key ID: %w", err)
}
return token, nil
}
// Get the inserted key info
err = conn.QueryRow(
`SELECT id, name, description, scope, is_active, created_at, created_by
FROM runner_api_keys WHERE id = ?`,
keyID,
).Scan(&keyInfo.ID, &keyInfo.Name, &keyInfo.Description, &keyInfo.Scope, &keyInfo.IsActive, &keyInfo.CreatedAt, &keyInfo.CreatedBy)
// TokenValidationResult represents the result of token validation
type TokenValidationResult struct {
Valid bool
Reason string // "valid", "not_found", "already_used", "expired"
Error error
}
return err
})
// ValidateRegistrationToken validates a registration token
func (s *Secrets) ValidateRegistrationToken(token string) (bool, error) {
result, err := s.ValidateRegistrationTokenDetailed(token)
if err != nil {
return false, err
return nil, fmt.Errorf("failed to create API key: %w", err)
}
// For backward compatibility, return just the valid boolean
return result.Valid, nil
keyInfo.Key = key
return &keyInfo, nil
}
// ValidateRegistrationTokenDetailed validates a registration token and returns detailed result
func (s *Secrets) ValidateRegistrationTokenDetailed(token string) (*TokenValidationResult, error) {
// Check fixed token first (if set) - it's reusable and never expires
if s.fixedRegistrationToken != "" && token == s.fixedRegistrationToken {
log.Printf("Fixed registration token used (from FIXED_REGISTRATION_TOKEN env var)")
return &TokenValidationResult{Valid: true, Reason: "valid"}, nil
// generateAPIKey generates a new API key in format jk_r1_abc123def456...
func (s *Secrets) generateAPIKey() (string, error) {
// Generate random suffix
randomBytes := make([]byte, 16)
if _, err := rand.Read(randomBytes); err != nil {
return "", fmt.Errorf("failed to generate random bytes: %w", err)
}
randomStr := hex.EncodeToString(randomBytes)
// Generate a unique prefix (jk_r followed by 1 random digit)
prefixDigit := make([]byte, 1)
if _, err := rand.Read(prefixDigit); err != nil {
return "", fmt.Errorf("failed to generate prefix digit: %w", err)
}
// Check database tokens
var used bool
var expiresAt time.Time
var id int64
prefix := fmt.Sprintf("jk_r%d", prefixDigit[0]%10)
key := fmt.Sprintf("%s_%s", prefix, randomStr)
err := s.db.QueryRow(
"SELECT id, expires_at, used FROM registration_tokens WHERE token = ?",
token,
).Scan(&id, &expiresAt, &used)
// Validate generated key format
if !strings.HasPrefix(key, "jk_r") {
return "", fmt.Errorf("generated invalid API key format: %s", key)
}
return key, nil
}
// ValidateRunnerAPIKey validates an API key and returns the key ID and scope if valid
func (s *Secrets) ValidateRunnerAPIKey(apiKey string) (int64, string, error) {
if apiKey == "" {
return 0, "", fmt.Errorf("API key is required")
}
// Check fixed API key first (from database config)
fixedKey := s.cfg.FixedAPIKey()
if fixedKey != "" && apiKey == fixedKey {
// Return a special ID for fixed API key (doesn't exist in database)
return -1, "manager", nil
}
// Parse API key format: jk_rX_...
if !strings.HasPrefix(apiKey, "jk_r") {
return 0, "", fmt.Errorf("invalid API key format: expected format 'jk_rX_...' where X is a number (e.g., 'jk_r1_abc123...')")
}
parts := strings.Split(apiKey, "_")
if len(parts) < 3 {
return 0, "", fmt.Errorf("invalid API key format: expected format 'jk_rX_...' with at least 3 parts separated by underscores")
}
keyPrefix := fmt.Sprintf("%s_%s", parts[0], parts[1])
// Hash the full key for comparison
keyHash := sha256.Sum256([]byte(apiKey))
keyHashStr := hex.EncodeToString(keyHash[:])
var keyID int64
var scope string
var isActive bool
err := s.db.With(func(conn *sql.DB) error {
err := conn.QueryRow(
`SELECT id, scope, is_active FROM runner_api_keys
WHERE key_prefix = ? AND key_hash = ?`,
keyPrefix, keyHashStr,
).Scan(&keyID, &scope, &isActive)
if err == sql.ErrNoRows {
return &TokenValidationResult{Valid: false, Reason: "not_found"}, nil
return fmt.Errorf("API key not found or invalid - please check that the key is correct and active")
}
if err != nil {
return nil, fmt.Errorf("failed to query token: %w", err)
return fmt.Errorf("failed to validate API key: %w", err)
}
if used {
return &TokenValidationResult{Valid: false, Reason: "already_used"}, nil
if !isActive {
return fmt.Errorf("API key is inactive")
}
if time.Now().After(expiresAt) {
return &TokenValidationResult{Valid: false, Reason: "expired"}, nil
}
// Update last_used_at (don't fail if this update fails)
conn.Exec(`UPDATE runner_api_keys SET last_used_at = ? WHERE id = ?`, time.Now(), keyID)
return nil
})
// Mark token as used
_, err = s.db.Exec("UPDATE registration_tokens SET used = 1 WHERE id = ?", id)
if err != nil {
return nil, fmt.Errorf("failed to mark token as used: %w", err)
return 0, "", err
}
return &TokenValidationResult{Valid: true, Reason: "valid"}, nil
return keyID, scope, nil
}
// ListRegistrationTokens lists all registration tokens
func (s *Secrets) ListRegistrationTokens() ([]map[string]interface{}, error) {
rows, err := s.db.Query(
`SELECT id, token, expires_at, used, created_at, created_by
FROM registration_tokens
// ListRunnerAPIKeys lists all runner API keys
func (s *Secrets) ListRunnerAPIKeys() ([]APIKeyInfo, error) {
var keys []APIKeyInfo
err := s.db.With(func(conn *sql.DB) error {
rows, err := conn.Query(
`SELECT id, key_prefix, name, description, scope, is_active, created_at, created_by
FROM runner_api_keys
ORDER BY created_at DESC`,
)
if err != nil {
return nil, fmt.Errorf("failed to query tokens: %w", err)
return fmt.Errorf("failed to query API keys: %w", err)
}
defer rows.Close()
var tokens []map[string]interface{}
for rows.Next() {
var id, createdBy sql.NullInt64
var token string
var expiresAt, createdAt time.Time
var used bool
var key APIKeyInfo
var description sql.NullString
err := rows.Scan(&id, &token, &expiresAt, &used, &createdAt, &createdBy)
err := rows.Scan(&key.ID, &key.Key, &key.Name, &description, &key.Scope, &key.IsActive, &key.CreatedAt, &key.CreatedBy)
if err != nil {
continue
}
tokens = append(tokens, map[string]interface{}{
"id": id.Int64,
"token": token,
"expires_at": expiresAt,
"used": used,
"created_at": createdAt,
"created_by": createdBy.Int64,
})
}
if description.Valid {
key.Description = &description.String
}
return tokens, nil
keys = append(keys, key)
}
return nil
})
if err != nil {
return nil, err
}
return keys, nil
}
// RevokeRegistrationToken revokes a registration token
func (s *Secrets) RevokeRegistrationToken(tokenID int64) error {
_, err := s.db.Exec("UPDATE registration_tokens SET used = 1 WHERE id = ?", tokenID)
// RevokeRunnerAPIKey revokes (deactivates) a runner API key
func (s *Secrets) RevokeRunnerAPIKey(keyID int64) error {
return s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec("UPDATE runner_api_keys SET is_active = false WHERE id = ?", keyID)
return err
})
}
// GenerateRunnerSecret generates a unique secret for a runner
func (s *Secrets) GenerateRunnerSecret() (string, error) {
return generateSecret(32)
// DeleteRunnerAPIKey deletes a runner API key
func (s *Secrets) DeleteRunnerAPIKey(keyID int64) error {
return s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec("DELETE FROM runner_api_keys WHERE id = ?", keyID)
return err
})
}
// SignRequest signs a request with the given secret
func SignRequest(method, path, body, secret string, timestamp time.Time) string {
message := fmt.Sprintf("%s\n%s\n%s\n%d", method, path, body, timestamp.Unix())
h := hmac.New(sha256.New, []byte(secret))
h.Write([]byte(message))
return hex.EncodeToString(h.Sum(nil))
}
// VerifyRequest verifies a signed request
func VerifyRequest(r *http.Request, secret string, maxAge time.Duration) (bool, error) {
signature := r.Header.Get("X-Runner-Signature")
if signature == "" {
return false, fmt.Errorf("missing signature")
}
timestampStr := r.Header.Get("X-Runner-Timestamp")
if timestampStr == "" {
return false, fmt.Errorf("missing timestamp")
}
var timestampUnix int64
_, err := fmt.Sscanf(timestampStr, "%d", &timestampUnix)
if err != nil {
return false, fmt.Errorf("invalid timestamp: %w", err)
}
timestamp := time.Unix(timestampUnix, 0)
// Check timestamp is not too old
if time.Since(timestamp) > maxAge {
return false, fmt.Errorf("request too old")
}
// Check timestamp is not in the future (allow 1 minute clock skew)
if timestamp.After(time.Now().Add(1 * time.Minute)) {
return false, fmt.Errorf("timestamp in future")
}
// Read body
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
return false, fmt.Errorf("failed to read body: %w", err)
}
// Restore body for handler
r.Body = io.NopCloser(strings.NewReader(string(bodyBytes)))
// Verify signature - use path without query parameters (query params are not part of signature)
// The runner signs with the path including query params, but we verify with just the path
// This is intentional - query params are for identification, not part of the signature
path := r.URL.Path
expectedSig := SignRequest(r.Method, path, string(bodyBytes), secret, timestamp)
return hmac.Equal([]byte(signature), []byte(expectedSig)), nil
}
// GetRunnerSecret retrieves the runner secret for a runner ID
func (s *Secrets) GetRunnerSecret(runnerID int64) (string, error) {
var secret string
err := s.db.QueryRow("SELECT runner_secret FROM runners WHERE id = ?", runnerID).Scan(&secret)
if err == sql.ErrNoRows {
return "", fmt.Errorf("runner not found")
}
if err != nil {
return "", fmt.Errorf("failed to get runner secret: %w", err)
}
if secret == "" {
return "", fmt.Errorf("runner not verified")
}
return secret, nil
}
// generateSecret generates a random secret of the given length
func generateSecret(length int) (string, error) {

303
internal/config/config.go Normal file
View File

@@ -0,0 +1,303 @@
package config
import (
"database/sql"
"fmt"
"jiggablend/internal/database"
"log"
"os"
"strconv"
)
// Config keys stored in database
const (
KeyGoogleClientID = "google_client_id"
KeyGoogleClientSecret = "google_client_secret"
KeyGoogleRedirectURL = "google_redirect_url"
KeyDiscordClientID = "discord_client_id"
KeyDiscordClientSecret = "discord_client_secret"
KeyDiscordRedirectURL = "discord_redirect_url"
KeyEnableLocalAuth = "enable_local_auth"
KeyFixedAPIKey = "fixed_api_key"
KeyRegistrationEnabled = "registration_enabled"
KeyProductionMode = "production_mode"
KeyAllowedOrigins = "allowed_origins"
)
// Config manages application configuration stored in the database
type Config struct {
db *database.DB
}
// NewConfig creates a new config manager
func NewConfig(db *database.DB) *Config {
return &Config{db: db}
}
// InitializeFromEnv loads configuration from environment variables on first run
// Environment variables take precedence only if the config key doesn't exist in the database
// This allows first-run setup via env vars, then subsequent runs use database values
func (c *Config) InitializeFromEnv() error {
envMappings := []struct {
envKey string
configKey string
sensitive bool
}{
{"GOOGLE_CLIENT_ID", KeyGoogleClientID, false},
{"GOOGLE_CLIENT_SECRET", KeyGoogleClientSecret, true},
{"GOOGLE_REDIRECT_URL", KeyGoogleRedirectURL, false},
{"DISCORD_CLIENT_ID", KeyDiscordClientID, false},
{"DISCORD_CLIENT_SECRET", KeyDiscordClientSecret, true},
{"DISCORD_REDIRECT_URL", KeyDiscordRedirectURL, false},
{"ENABLE_LOCAL_AUTH", KeyEnableLocalAuth, false},
{"FIXED_API_KEY", KeyFixedAPIKey, true},
{"PRODUCTION", KeyProductionMode, false},
{"ALLOWED_ORIGINS", KeyAllowedOrigins, false},
}
for _, mapping := range envMappings {
envValue := os.Getenv(mapping.envKey)
if envValue == "" {
continue
}
// Check if config already exists in database
exists, err := c.Exists(mapping.configKey)
if err != nil {
return fmt.Errorf("failed to check config %s: %w", mapping.configKey, err)
}
if !exists {
// Store env value in database
if err := c.Set(mapping.configKey, envValue); err != nil {
return fmt.Errorf("failed to store config %s: %w", mapping.configKey, err)
}
if mapping.sensitive {
log.Printf("Stored config from env: %s = [REDACTED]", mapping.configKey)
} else {
log.Printf("Stored config from env: %s = %s", mapping.configKey, envValue)
}
}
}
return nil
}
// Get retrieves a config value from the database
func (c *Config) Get(key string) (string, error) {
var value string
err := c.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT value FROM settings WHERE key = ?", key).Scan(&value)
})
if err == sql.ErrNoRows {
return "", nil
}
if err != nil {
return "", fmt.Errorf("failed to get config %s: %w", key, err)
}
return value, nil
}
// GetWithDefault retrieves a config value or returns a default if not set
func (c *Config) GetWithDefault(key, defaultValue string) string {
value, err := c.Get(key)
if err != nil || value == "" {
return defaultValue
}
return value
}
// GetBool retrieves a boolean config value
func (c *Config) GetBool(key string) (bool, error) {
value, err := c.Get(key)
if err != nil {
return false, err
}
return value == "true" || value == "1", nil
}
// GetBoolWithDefault retrieves a boolean config value or returns a default
func (c *Config) GetBoolWithDefault(key string, defaultValue bool) bool {
value, err := c.GetBool(key)
if err != nil {
return defaultValue
}
// If the key doesn't exist, Get returns empty string which becomes false
// Check if key exists to distinguish between "false" and "not set"
exists, _ := c.Exists(key)
if !exists {
return defaultValue
}
return value
}
// GetInt retrieves an integer config value
func (c *Config) GetInt(key string) (int, error) {
value, err := c.Get(key)
if err != nil {
return 0, err
}
if value == "" {
return 0, nil
}
return strconv.Atoi(value)
}
// GetIntWithDefault retrieves an integer config value or returns a default
func (c *Config) GetIntWithDefault(key string, defaultValue int) int {
value, err := c.GetInt(key)
if err != nil {
return defaultValue
}
exists, _ := c.Exists(key)
if !exists {
return defaultValue
}
return value
}
// Set stores a config value in the database
func (c *Config) Set(key, value string) error {
// Use upsert pattern
exists, err := c.Exists(key)
if err != nil {
return err
}
err = c.db.With(func(conn *sql.DB) error {
if exists {
_, err = conn.Exec(
"UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?",
value, key,
)
} else {
_, err = conn.Exec(
"INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)",
key, value,
)
}
return err
})
if err != nil {
return fmt.Errorf("failed to set config %s: %w", key, err)
}
return nil
}
// SetBool stores a boolean config value
func (c *Config) SetBool(key string, value bool) error {
strValue := "false"
if value {
strValue = "true"
}
return c.Set(key, strValue)
}
// SetInt stores an integer config value
func (c *Config) SetInt(key string, value int) error {
return c.Set(key, strconv.Itoa(value))
}
// Delete removes a config value from the database
func (c *Config) Delete(key string) error {
err := c.db.With(func(conn *sql.DB) error {
_, err := conn.Exec("DELETE FROM settings WHERE key = ?", key)
return err
})
if err != nil {
return fmt.Errorf("failed to delete config %s: %w", key, err)
}
return nil
}
// Exists checks if a config key exists in the database
func (c *Config) Exists(key string) (bool, error) {
var exists bool
err := c.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM settings WHERE key = ?)", key).Scan(&exists)
})
if err != nil {
return false, fmt.Errorf("failed to check config existence %s: %w", key, err)
}
return exists, nil
}
// GetAll returns all config values (for debugging/admin purposes)
func (c *Config) GetAll() (map[string]string, error) {
var result map[string]string
err := c.db.With(func(conn *sql.DB) error {
rows, err := conn.Query("SELECT key, value FROM settings")
if err != nil {
return fmt.Errorf("failed to get all config: %w", err)
}
defer rows.Close()
result = make(map[string]string)
for rows.Next() {
var key, value string
if err := rows.Scan(&key, &value); err != nil {
return fmt.Errorf("failed to scan config row: %w", err)
}
result[key] = value
}
return nil
})
if err != nil {
return nil, err
}
return result, nil
}
// --- Convenience methods for specific config values ---
// GoogleClientID returns the Google OAuth client ID
func (c *Config) GoogleClientID() string {
return c.GetWithDefault(KeyGoogleClientID, "")
}
// GoogleClientSecret returns the Google OAuth client secret
func (c *Config) GoogleClientSecret() string {
return c.GetWithDefault(KeyGoogleClientSecret, "")
}
// GoogleRedirectURL returns the Google OAuth redirect URL
func (c *Config) GoogleRedirectURL() string {
return c.GetWithDefault(KeyGoogleRedirectURL, "")
}
// DiscordClientID returns the Discord OAuth client ID
func (c *Config) DiscordClientID() string {
return c.GetWithDefault(KeyDiscordClientID, "")
}
// DiscordClientSecret returns the Discord OAuth client secret
func (c *Config) DiscordClientSecret() string {
return c.GetWithDefault(KeyDiscordClientSecret, "")
}
// DiscordRedirectURL returns the Discord OAuth redirect URL
func (c *Config) DiscordRedirectURL() string {
return c.GetWithDefault(KeyDiscordRedirectURL, "")
}
// IsLocalAuthEnabled returns whether local authentication is enabled
func (c *Config) IsLocalAuthEnabled() bool {
return c.GetBoolWithDefault(KeyEnableLocalAuth, false)
}
// FixedAPIKey returns the fixed API key for testing
func (c *Config) FixedAPIKey() string {
return c.GetWithDefault(KeyFixedAPIKey, "")
}
// IsProductionMode returns whether production mode is enabled
func (c *Config) IsProductionMode() bool {
return c.GetBoolWithDefault(KeyProductionMode, false)
}
// AllowedOrigins returns the allowed CORS origins
func (c *Config) AllowedOrigins() string {
return c.GetWithDefault(KeyAllowedOrigins, "")
}

View File

@@ -0,0 +1,36 @@
-- Drop indexes
DROP INDEX IF EXISTS idx_sessions_expires_at;
DROP INDEX IF EXISTS idx_sessions_user_id;
DROP INDEX IF EXISTS idx_sessions_session_id;
DROP INDEX IF EXISTS idx_runners_last_heartbeat;
DROP INDEX IF EXISTS idx_task_steps_task_id;
DROP INDEX IF EXISTS idx_task_logs_runner_id;
DROP INDEX IF EXISTS idx_task_logs_task_id_id;
DROP INDEX IF EXISTS idx_task_logs_task_id_created_at;
DROP INDEX IF EXISTS idx_runners_api_key_id;
DROP INDEX IF EXISTS idx_runner_api_keys_created_by;
DROP INDEX IF EXISTS idx_runner_api_keys_active;
DROP INDEX IF EXISTS idx_runner_api_keys_prefix;
DROP INDEX IF EXISTS idx_job_files_job_id;
DROP INDEX IF EXISTS idx_tasks_started_at;
DROP INDEX IF EXISTS idx_tasks_job_status;
DROP INDEX IF EXISTS idx_tasks_status;
DROP INDEX IF EXISTS idx_tasks_runner_id;
DROP INDEX IF EXISTS idx_tasks_job_id;
DROP INDEX IF EXISTS idx_jobs_user_status_created;
DROP INDEX IF EXISTS idx_jobs_status;
DROP INDEX IF EXISTS idx_jobs_user_id;
-- Drop tables (order matters due to foreign keys)
DROP TABLE IF EXISTS sessions;
DROP TABLE IF EXISTS settings;
DROP TABLE IF EXISTS task_steps;
DROP TABLE IF EXISTS task_logs;
DROP TABLE IF EXISTS manager_secrets;
DROP TABLE IF EXISTS job_files;
DROP TABLE IF EXISTS tasks;
DROP TABLE IF EXISTS runners;
DROP TABLE IF EXISTS jobs;
DROP TABLE IF EXISTS runner_api_keys;
DROP TABLE IF EXISTS users;

View File

@@ -0,0 +1,184 @@
-- Enable foreign keys for SQLite
PRAGMA foreign_keys = ON;
-- Users table
CREATE TABLE users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
email TEXT UNIQUE NOT NULL,
name TEXT NOT NULL,
oauth_provider TEXT NOT NULL,
oauth_id TEXT NOT NULL,
password_hash TEXT,
is_admin INTEGER NOT NULL DEFAULT 0,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
UNIQUE(oauth_provider, oauth_id)
);
-- Runner API keys table
CREATE TABLE runner_api_keys (
id INTEGER PRIMARY KEY AUTOINCREMENT,
key_prefix TEXT NOT NULL,
key_hash TEXT NOT NULL,
name TEXT NOT NULL,
description TEXT,
scope TEXT NOT NULL DEFAULT 'user',
is_active INTEGER NOT NULL DEFAULT 1,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
created_by INTEGER,
FOREIGN KEY (created_by) REFERENCES users(id),
UNIQUE(key_prefix)
);
-- Jobs table
CREATE TABLE jobs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
job_type TEXT NOT NULL DEFAULT 'render',
name TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
progress REAL NOT NULL DEFAULT 0.0,
frame_start INTEGER,
frame_end INTEGER,
output_format TEXT,
blend_metadata TEXT,
retry_count INTEGER NOT NULL DEFAULT 0,
max_retries INTEGER NOT NULL DEFAULT 3,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
started_at TIMESTAMP,
completed_at TIMESTAMP,
error_message TEXT,
assigned_runner_id INTEGER,
FOREIGN KEY (user_id) REFERENCES users(id)
);
-- Runners table
CREATE TABLE runners (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
hostname TEXT NOT NULL,
ip_address TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'offline',
last_heartbeat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
capabilities TEXT,
api_key_id INTEGER,
api_key_scope TEXT NOT NULL DEFAULT 'user',
priority INTEGER NOT NULL DEFAULT 100,
fingerprint TEXT,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (api_key_id) REFERENCES runner_api_keys(id)
);
-- Tasks table
CREATE TABLE tasks (
id INTEGER PRIMARY KEY AUTOINCREMENT,
job_id INTEGER NOT NULL,
runner_id INTEGER,
frame INTEGER NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
output_path TEXT,
task_type TEXT NOT NULL DEFAULT 'render',
current_step TEXT,
retry_count INTEGER NOT NULL DEFAULT 0,
max_retries INTEGER NOT NULL DEFAULT 3,
runner_failure_count INTEGER NOT NULL DEFAULT 0,
timeout_seconds INTEGER,
condition TEXT,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
started_at TIMESTAMP,
completed_at TIMESTAMP,
error_message TEXT,
FOREIGN KEY (job_id) REFERENCES jobs(id),
FOREIGN KEY (runner_id) REFERENCES runners(id)
);
-- Job files table
CREATE TABLE job_files (
id INTEGER PRIMARY KEY AUTOINCREMENT,
job_id INTEGER NOT NULL,
file_type TEXT NOT NULL,
file_path TEXT NOT NULL,
file_name TEXT NOT NULL,
file_size INTEGER NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (job_id) REFERENCES jobs(id)
);
-- Manager secrets table
CREATE TABLE manager_secrets (
id INTEGER PRIMARY KEY AUTOINCREMENT,
secret TEXT UNIQUE NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
-- Task logs table
CREATE TABLE task_logs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
task_id INTEGER NOT NULL,
runner_id INTEGER,
log_level TEXT NOT NULL,
message TEXT NOT NULL,
step_name TEXT,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (task_id) REFERENCES tasks(id),
FOREIGN KEY (runner_id) REFERENCES runners(id)
);
-- Task steps table
CREATE TABLE task_steps (
id INTEGER PRIMARY KEY AUTOINCREMENT,
task_id INTEGER NOT NULL,
step_name TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
started_at TIMESTAMP,
completed_at TIMESTAMP,
duration_ms INTEGER,
error_message TEXT,
FOREIGN KEY (task_id) REFERENCES tasks(id)
);
-- Settings table
CREATE TABLE settings (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
-- Sessions table
CREATE TABLE sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id TEXT UNIQUE NOT NULL,
user_id INTEGER NOT NULL,
email TEXT NOT NULL,
name TEXT NOT NULL,
is_admin INTEGER NOT NULL DEFAULT 0,
expires_at TIMESTAMP NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users(id)
);
-- Indexes
CREATE INDEX idx_jobs_user_id ON jobs(user_id);
CREATE INDEX idx_jobs_status ON jobs(status);
CREATE INDEX idx_jobs_user_status_created ON jobs(user_id, status, created_at DESC);
CREATE INDEX idx_tasks_job_id ON tasks(job_id);
CREATE INDEX idx_tasks_runner_id ON tasks(runner_id);
CREATE INDEX idx_tasks_status ON tasks(status);
CREATE INDEX idx_tasks_job_status ON tasks(job_id, status);
CREATE INDEX idx_tasks_started_at ON tasks(started_at);
CREATE INDEX idx_job_files_job_id ON job_files(job_id);
CREATE INDEX idx_runner_api_keys_prefix ON runner_api_keys(key_prefix);
CREATE INDEX idx_runner_api_keys_active ON runner_api_keys(is_active);
CREATE INDEX idx_runner_api_keys_created_by ON runner_api_keys(created_by);
CREATE INDEX idx_runners_api_key_id ON runners(api_key_id);
CREATE INDEX idx_task_logs_task_id_created_at ON task_logs(task_id, created_at);
CREATE INDEX idx_task_logs_task_id_id ON task_logs(task_id, id DESC);
CREATE INDEX idx_task_logs_runner_id ON task_logs(runner_id);
CREATE INDEX idx_task_steps_task_id ON task_steps(task_id);
CREATE INDEX idx_runners_last_heartbeat ON runners(last_heartbeat);
CREATE INDEX idx_sessions_session_id ON sessions(session_id);
CREATE INDEX idx_sessions_user_id ON sessions(user_id);
CREATE INDEX idx_sessions_expires_at ON sessions(expires_at);
-- Initialize registration_enabled setting
INSERT INTO settings (key, value, updated_at) VALUES ('registration_enabled', 'true', CURRENT_TIMESTAMP);

View File

@@ -2,252 +2,158 @@ package database
import (
"database/sql"
"embed"
"fmt"
"io/fs"
"log"
_ "github.com/marcboeker/go-duckdb/v2"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/sqlite3"
"github.com/golang-migrate/migrate/v4/source/iofs"
_ "github.com/mattn/go-sqlite3"
)
//go:embed migrations/*.sql
var migrationsFS embed.FS
// DB wraps the database connection
// Note: No mutex needed - we only have one connection per process and SQLite with WAL mode
// handles concurrent access safely
type DB struct {
*sql.DB
db *sql.DB
}
// NewDB creates a new database connection
func NewDB(dbPath string) (*DB, error) {
db, err := sql.Open("duckdb", dbPath)
// Use WAL mode for better concurrency (allows readers and writers simultaneously)
// Add timeout and busy handler for better concurrent access
db, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_busy_timeout=5000")
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
// Configure connection pool for better concurrency
// SQLite with WAL mode supports multiple concurrent readers and one writer
// Increasing pool size allows multiple HTTP requests to query the database simultaneously
// This prevents blocking when multiple requests come in (e.g., on page refresh)
db.SetMaxOpenConns(10) // Allow up to 10 concurrent connections
db.SetMaxIdleConns(5) // Keep 5 idle connections ready
db.SetConnMaxLifetime(0) // Connections don't expire
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("failed to ping database: %w", err)
}
database := &DB{DB: db}
// Enable foreign keys for SQLite
if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
return nil, fmt.Errorf("failed to enable foreign keys: %w", err)
}
// Enable WAL mode explicitly (in case the connection string didn't work)
if _, err := db.Exec("PRAGMA journal_mode = WAL"); err != nil {
log.Printf("Warning: Failed to enable WAL mode: %v", err)
}
database := &DB{db: db}
if err := database.migrate(); err != nil {
return nil, fmt.Errorf("failed to migrate database: %w", err)
}
// Verify connection is still open after migration
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("database connection closed after migration: %w", err)
}
return database, nil
}
// migrate runs database migrations
// With executes a function with access to the database
// The function receives the underlying *sql.DB connection
// No mutex needed - single connection + WAL mode handles concurrency
func (db *DB) With(fn func(*sql.DB) error) error {
return fn(db.db)
}
// WithTx executes a function within a transaction
// The function receives a *sql.Tx transaction
// If the function returns an error, the transaction is rolled back
// If the function returns nil, the transaction is committed
// No mutex needed - single connection + WAL mode handles concurrency
func (db *DB) WithTx(fn func(*sql.Tx) error) error {
tx, err := db.db.Begin()
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
if err := fn(tx); err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
return fmt.Errorf("transaction error: %w, rollback error: %v", err, rbErr)
}
return err
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}
// migrate runs database migrations using golang-migrate
func (db *DB) migrate() error {
// Create sequences for auto-incrementing primary keys
sequences := []string{
`CREATE SEQUENCE IF NOT EXISTS seq_users_id START 1`,
`CREATE SEQUENCE IF NOT EXISTS seq_jobs_id START 1`,
`CREATE SEQUENCE IF NOT EXISTS seq_runners_id START 1`,
`CREATE SEQUENCE IF NOT EXISTS seq_tasks_id START 1`,
`CREATE SEQUENCE IF NOT EXISTS seq_job_files_id START 1`,
`CREATE SEQUENCE IF NOT EXISTS seq_manager_secrets_id START 1`,
`CREATE SEQUENCE IF NOT EXISTS seq_registration_tokens_id START 1`,
`CREATE SEQUENCE IF NOT EXISTS seq_task_logs_id START 1`,
`CREATE SEQUENCE IF NOT EXISTS seq_task_steps_id START 1`,
// Create SQLite driver instance
// Note: We use db.db directly since we're in the same package and this is called during initialization
driver, err := sqlite3.WithInstance(db.db, &sqlite3.Config{})
if err != nil {
return fmt.Errorf("failed to create sqlite3 driver: %w", err)
}
for _, seq := range sequences {
if _, err := db.Exec(seq); err != nil {
return fmt.Errorf("failed to create sequence: %w", err)
// Create embedded filesystem source
migrationFS, err := fs.Sub(migrationsFS, "migrations")
if err != nil {
return fmt.Errorf("failed to create migration filesystem: %w", err)
}
sourceDriver, err := iofs.New(migrationFS, ".")
if err != nil {
return fmt.Errorf("failed to create iofs source driver: %w", err)
}
// Create migrate instance
m, err := migrate.NewWithInstance("iofs", sourceDriver, "sqlite3", driver)
if err != nil {
return fmt.Errorf("failed to create migrate instance: %w", err)
}
// Run migrations
if err := m.Up(); err != nil {
// If the error is "no change", that's fine - database is already up to date
if err == migrate.ErrNoChange {
log.Printf("Database is already up to date")
// Don't close migrate instance - it may close the database connection
// The migrate instance will be garbage collected
return nil
}
// Don't close migrate instance on error either - it may close the DB
return fmt.Errorf("failed to run migrations: %w", err)
}
schema := `
CREATE TABLE IF NOT EXISTS users (
id BIGINT PRIMARY KEY DEFAULT nextval('seq_users_id'),
email TEXT UNIQUE NOT NULL,
name TEXT NOT NULL,
oauth_provider TEXT NOT NULL,
oauth_id TEXT NOT NULL,
password_hash TEXT,
is_admin BOOLEAN NOT NULL DEFAULT false,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
UNIQUE(oauth_provider, oauth_id)
);
CREATE TABLE IF NOT EXISTS jobs (
id BIGINT PRIMARY KEY DEFAULT nextval('seq_jobs_id'),
user_id BIGINT NOT NULL,
job_type TEXT NOT NULL DEFAULT 'render',
name TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
progress REAL NOT NULL DEFAULT 0.0,
frame_start INTEGER,
frame_end INTEGER,
output_format TEXT,
allow_parallel_runners BOOLEAN,
timeout_seconds INTEGER DEFAULT 86400,
blend_metadata TEXT,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
started_at TIMESTAMP,
completed_at TIMESTAMP,
error_message TEXT,
FOREIGN KEY (user_id) REFERENCES users(id)
);
CREATE TABLE IF NOT EXISTS runners (
id BIGINT PRIMARY KEY DEFAULT nextval('seq_runners_id'),
name TEXT NOT NULL,
hostname TEXT NOT NULL,
ip_address TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'offline',
last_heartbeat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
capabilities TEXT,
registration_token TEXT,
runner_secret TEXT,
manager_secret TEXT,
verified BOOLEAN NOT NULL DEFAULT false,
priority INTEGER NOT NULL DEFAULT 100,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS tasks (
id BIGINT PRIMARY KEY DEFAULT nextval('seq_tasks_id'),
job_id BIGINT NOT NULL,
runner_id BIGINT,
frame_start INTEGER NOT NULL,
frame_end INTEGER NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
output_path TEXT,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
started_at TIMESTAMP,
completed_at TIMESTAMP,
error_message TEXT
);
CREATE TABLE IF NOT EXISTS job_files (
id BIGINT PRIMARY KEY DEFAULT nextval('seq_job_files_id'),
job_id BIGINT NOT NULL,
file_type TEXT NOT NULL,
file_path TEXT NOT NULL,
file_name TEXT NOT NULL,
file_size INTEGER NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS manager_secrets (
id BIGINT PRIMARY KEY DEFAULT nextval('seq_manager_secrets_id'),
secret TEXT UNIQUE NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS registration_tokens (
id BIGINT PRIMARY KEY DEFAULT nextval('seq_registration_tokens_id'),
token TEXT UNIQUE NOT NULL,
expires_at TIMESTAMP NOT NULL,
used BOOLEAN NOT NULL DEFAULT false,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
created_by BIGINT,
FOREIGN KEY (created_by) REFERENCES users(id)
);
CREATE TABLE IF NOT EXISTS task_logs (
id BIGINT PRIMARY KEY DEFAULT nextval('seq_task_logs_id'),
task_id BIGINT NOT NULL,
runner_id BIGINT,
log_level TEXT NOT NULL,
message TEXT NOT NULL,
step_name TEXT,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS task_steps (
id BIGINT PRIMARY KEY DEFAULT nextval('seq_task_steps_id'),
task_id BIGINT NOT NULL,
step_name TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
started_at TIMESTAMP,
completed_at TIMESTAMP,
duration_ms INTEGER,
error_message TEXT
);
CREATE INDEX IF NOT EXISTS idx_jobs_user_id ON jobs(user_id);
CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status);
CREATE INDEX IF NOT EXISTS idx_tasks_job_id ON tasks(job_id);
CREATE INDEX IF NOT EXISTS idx_tasks_runner_id ON tasks(runner_id);
CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status);
CREATE INDEX IF NOT EXISTS idx_tasks_started_at ON tasks(started_at);
CREATE INDEX IF NOT EXISTS idx_job_files_job_id ON job_files(job_id);
CREATE INDEX IF NOT EXISTS idx_registration_tokens_token ON registration_tokens(token);
CREATE INDEX IF NOT EXISTS idx_registration_tokens_expires_at ON registration_tokens(expires_at);
CREATE INDEX IF NOT EXISTS idx_task_logs_task_id_created_at ON task_logs(task_id, created_at);
CREATE INDEX IF NOT EXISTS idx_task_logs_runner_id ON task_logs(runner_id);
CREATE INDEX IF NOT EXISTS idx_task_steps_task_id ON task_steps(task_id);
CREATE INDEX IF NOT EXISTS idx_runners_last_heartbeat ON runners(last_heartbeat);
CREATE TABLE IF NOT EXISTS settings (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
`
if _, err := db.Exec(schema); err != nil {
return fmt.Errorf("failed to create schema: %w", err)
}
// Migrate existing tables to add new columns
migrations := []string{
// Add is_admin to users if it doesn't exist
`ALTER TABLE users ADD COLUMN IF NOT EXISTS is_admin BOOLEAN NOT NULL DEFAULT false`,
// Add new columns to runners if they don't exist
`ALTER TABLE runners ADD COLUMN IF NOT EXISTS registration_token TEXT`,
`ALTER TABLE runners ADD COLUMN IF NOT EXISTS runner_secret TEXT`,
`ALTER TABLE runners ADD COLUMN IF NOT EXISTS manager_secret TEXT`,
`ALTER TABLE runners ADD COLUMN IF NOT EXISTS verified BOOLEAN NOT NULL DEFAULT false`,
`ALTER TABLE runners ADD COLUMN IF NOT EXISTS priority INTEGER NOT NULL DEFAULT 100`,
// Add allow_parallel_runners to jobs if it doesn't exist
`ALTER TABLE jobs ADD COLUMN IF NOT EXISTS allow_parallel_runners BOOLEAN NOT NULL DEFAULT true`,
// Add timeout_seconds to jobs if it doesn't exist
`ALTER TABLE jobs ADD COLUMN IF NOT EXISTS timeout_seconds INTEGER DEFAULT 86400`,
// Add blend_metadata to jobs if it doesn't exist
`ALTER TABLE jobs ADD COLUMN IF NOT EXISTS blend_metadata TEXT`,
// Add job_type to jobs if it doesn't exist
`ALTER TABLE jobs ADD COLUMN IF NOT EXISTS job_type TEXT DEFAULT 'render'`,
// Add task_type to tasks if it doesn't exist
`ALTER TABLE tasks ADD COLUMN IF NOT EXISTS task_type TEXT DEFAULT 'render'`,
// Add new columns to tasks if they don't exist
`ALTER TABLE tasks ADD COLUMN IF NOT EXISTS current_step TEXT`,
`ALTER TABLE tasks ADD COLUMN IF NOT EXISTS retry_count INTEGER DEFAULT 0`,
`ALTER TABLE tasks ADD COLUMN IF NOT EXISTS max_retries INTEGER DEFAULT 3`,
`ALTER TABLE tasks ADD COLUMN IF NOT EXISTS timeout_seconds INTEGER`,
}
for _, migration := range migrations {
// DuckDB supports IF NOT EXISTS for ALTER TABLE, so we can safely execute
if _, err := db.Exec(migration); err != nil {
// Log but don't fail - column might already exist or table might not exist yet
// This is fine for migrations that run after schema creation
}
}
// Initialize registration_enabled setting (default: true) if it doesn't exist
var settingCount int
err := db.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount)
if err == nil && settingCount == 0 {
_, err = db.Exec("INSERT INTO settings (key, value) VALUES (?, ?)", "registration_enabled", "true")
if err != nil {
// Log but don't fail - setting might have been created by another process
log.Printf("Note: Could not initialize registration_enabled setting: %v", err)
}
}
// Don't close the migrate instance - with sqlite3.WithInstance, closing it
// may close the underlying database connection. The migrate instance will
// be garbage collected when it goes out of scope.
// If we need to close it later, we can store it in the DB struct and close
// it when DB.Close() is called, but for now we'll let it be GC'd.
log.Printf("Database migrations completed successfully")
return nil
}
for _, migration := range migrations {
// DuckDB supports IF NOT EXISTS for ALTER TABLE, so we can safely execute
if _, err := db.Exec(migration); err != nil {
// Log but don't fail - column might already exist or table might not exist yet
// This is fine for migrations that run after schema creation
}
}
return nil
// Ping checks the database connection
func (db *DB) Ping() error {
return db.db.Ping()
}
// Close closes the database connection
func (db *DB) Close() error {
return db.DB.Close()
return db.db.Close()
}

223
internal/logger/logger.go Normal file
View File

@@ -0,0 +1,223 @@
package logger
import (
"fmt"
"io"
"log"
"os"
"path/filepath"
"sync"
)
// Level represents log severity
type Level int
const (
LevelDebug Level = iota
LevelInfo
LevelWarn
LevelError
)
var levelNames = map[Level]string{
LevelDebug: "DEBUG",
LevelInfo: "INFO",
LevelWarn: "WARN",
LevelError: "ERROR",
}
// ParseLevel parses a level string into a Level
func ParseLevel(s string) Level {
switch s {
case "debug", "DEBUG":
return LevelDebug
case "info", "INFO":
return LevelInfo
case "warn", "WARN", "warning", "WARNING":
return LevelWarn
case "error", "ERROR":
return LevelError
default:
return LevelInfo
}
}
var (
defaultLogger *Logger
once sync.Once
currentLevel Level = LevelInfo
)
// Logger wraps the standard log.Logger with optional file output and levels
type Logger struct {
*log.Logger
fileWriter io.WriteCloser
}
// SetLevel sets the global log level
func SetLevel(level Level) {
currentLevel = level
}
// GetLevel returns the current log level
func GetLevel() Level {
return currentLevel
}
// InitStdout initializes the logger to only write to stdout
func InitStdout() {
once.Do(func() {
log.SetOutput(os.Stdout)
log.SetFlags(log.LstdFlags | log.Lshortfile)
defaultLogger = &Logger{
Logger: log.Default(),
fileWriter: nil,
}
})
}
// InitWithFile initializes the logger with both file and stdout output
// The file is truncated on each start
func InitWithFile(logPath string) error {
var err error
once.Do(func() {
defaultLogger, err = NewWithFile(logPath)
if err != nil {
return
}
// Replace standard log output with the multi-writer
multiWriter := io.MultiWriter(os.Stdout, defaultLogger.fileWriter)
log.SetOutput(multiWriter)
log.SetFlags(log.LstdFlags | log.Lshortfile)
})
return err
}
// NewWithFile creates a new logger that writes to both stdout and a log file
// The file is truncated on each start
func NewWithFile(logPath string) (*Logger, error) {
// Ensure log directory exists
logDir := filepath.Dir(logPath)
if err := os.MkdirAll(logDir, 0755); err != nil {
return nil, err
}
// Create/truncate the log file
fileWriter, err := os.Create(logPath)
if err != nil {
return nil, err
}
// Create multi-writer that writes to both stdout and file
multiWriter := io.MultiWriter(os.Stdout, fileWriter)
// Create logger with standard flags
logger := log.New(multiWriter, "", log.LstdFlags|log.Lshortfile)
return &Logger{
Logger: logger,
fileWriter: fileWriter,
}, nil
}
// Close closes the file writer
func (l *Logger) Close() error {
if l.fileWriter != nil {
return l.fileWriter.Close()
}
return nil
}
// GetDefault returns the default logger instance
func GetDefault() *Logger {
return defaultLogger
}
// logf logs a formatted message at the given level
func logf(level Level, format string, v ...interface{}) {
if level < currentLevel {
return
}
prefix := fmt.Sprintf("[%s] ", levelNames[level])
msg := fmt.Sprintf(format, v...)
log.Print(prefix + msg)
}
// logln logs a message at the given level
func logln(level Level, v ...interface{}) {
if level < currentLevel {
return
}
prefix := fmt.Sprintf("[%s] ", levelNames[level])
msg := fmt.Sprint(v...)
log.Print(prefix + msg)
}
// Debug logs a debug message
func Debug(v ...interface{}) {
logln(LevelDebug, v...)
}
// Debugf logs a formatted debug message
func Debugf(format string, v ...interface{}) {
logf(LevelDebug, format, v...)
}
// Info logs an info message
func Info(v ...interface{}) {
logln(LevelInfo, v...)
}
// Infof logs a formatted info message
func Infof(format string, v ...interface{}) {
logf(LevelInfo, format, v...)
}
// Warn logs a warning message
func Warn(v ...interface{}) {
logln(LevelWarn, v...)
}
// Warnf logs a formatted warning message
func Warnf(format string, v ...interface{}) {
logf(LevelWarn, format, v...)
}
// Error logs an error message
func Error(v ...interface{}) {
logln(LevelError, v...)
}
// Errorf logs a formatted error message
func Errorf(format string, v ...interface{}) {
logf(LevelError, format, v...)
}
// Fatal logs an error message and exits
func Fatal(v ...interface{}) {
logln(LevelError, v...)
os.Exit(1)
}
// Fatalf logs a formatted error message and exits
func Fatalf(format string, v ...interface{}) {
logf(LevelError, format, v...)
os.Exit(1)
}
// --- Backwards compatibility (maps to Info level) ---
// Printf logs a formatted message at Info level
func Printf(format string, v ...interface{}) {
logf(LevelInfo, format, v...)
}
// Print logs a message at Info level
func Print(v ...interface{}) {
logln(LevelInfo, v...)
}
// Println logs a message at Info level
func Println(v ...interface{}) {
logln(LevelInfo, v...)
}

View File

@@ -10,68 +10,119 @@ import (
"jiggablend/pkg/types"
)
// handleGenerateRegistrationToken generates a new registration token
func (s *Server) handleGenerateRegistrationToken(w http.ResponseWriter, r *http.Request) {
// handleGenerateRunnerAPIKey generates a new runner API key
func (s *Manager) handleGenerateRunnerAPIKey(w http.ResponseWriter, r *http.Request) {
userID, err := getUserID(r)
if err != nil {
s.respondError(w, http.StatusUnauthorized, err.Error())
return
}
// Default expiration: 24 hours
expiresIn := 24 * time.Hour
var req struct {
ExpiresInHours int `json:"expires_in_hours,omitempty"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
Scope string `json:"scope,omitempty"` // 'manager' or 'user'
}
if r.Body != nil && r.ContentLength > 0 {
if err := json.NewDecoder(r.Body).Decode(&req); err == nil && req.ExpiresInHours > 0 {
expiresIn = time.Duration(req.ExpiresInHours) * time.Hour
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
return
}
if req.Name == "" {
s.respondError(w, http.StatusBadRequest, "API key name is required")
return
}
// Default scope to 'user' if not specified
scope := req.Scope
if scope == "" {
scope = "user"
}
if scope != "manager" && scope != "user" {
s.respondError(w, http.StatusBadRequest, "Scope must be 'manager' or 'user'")
return
}
keyInfo, err := s.secrets.GenerateRunnerAPIKey(userID, req.Name, req.Description, scope)
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to generate API key: %v", err))
return
}
response := map[string]interface{}{
"id": keyInfo.ID,
"key": keyInfo.Key,
"name": keyInfo.Name,
"description": keyInfo.Description,
"is_active": keyInfo.IsActive,
"created_at": keyInfo.CreatedAt,
}
s.respondJSON(w, http.StatusCreated, response)
}
// handleListRunnerAPIKeys lists all runner API keys
func (s *Manager) handleListRunnerAPIKeys(w http.ResponseWriter, r *http.Request) {
keys, err := s.secrets.ListRunnerAPIKeys()
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to list API keys: %v", err))
return
}
// Convert to response format (hide sensitive hash data)
var response []map[string]interface{}
for _, key := range keys {
item := map[string]interface{}{
"id": key.ID,
"key_prefix": key.Key, // Only show prefix, not full key
"name": key.Name,
"is_active": key.IsActive,
"created_at": key.CreatedAt,
"created_by": key.CreatedBy,
}
if key.Description != nil {
item["description"] = *key.Description
}
response = append(response, item)
}
token, err := s.secrets.GenerateRegistrationToken(userID, expiresIn)
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to generate token: %v", err))
return
}
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
"token": token,
"expires_in": expiresIn.String(),
"expires_at": time.Now().Add(expiresIn),
})
s.respondJSON(w, http.StatusOK, response)
}
// handleListRegistrationTokens lists all registration tokens
func (s *Server) handleListRegistrationTokens(w http.ResponseWriter, r *http.Request) {
tokens, err := s.secrets.ListRegistrationTokens()
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to list tokens: %v", err))
return
}
s.respondJSON(w, http.StatusOK, tokens)
}
// handleRevokeRegistrationToken revokes a registration token
func (s *Server) handleRevokeRegistrationToken(w http.ResponseWriter, r *http.Request) {
tokenID, err := parseID(r, "id")
// handleRevokeRunnerAPIKey revokes a runner API key
func (s *Manager) handleRevokeRunnerAPIKey(w http.ResponseWriter, r *http.Request) {
keyID, err := parseID(r, "id")
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
if err := s.secrets.RevokeRegistrationToken(tokenID); err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to revoke token: %v", err))
if err := s.secrets.RevokeRunnerAPIKey(keyID); err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to revoke API key: %v", err))
return
}
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Token revoked"})
s.respondJSON(w, http.StatusOK, map[string]string{"message": "API key revoked"})
}
// handleDeleteRunnerAPIKey deletes a runner API key
func (s *Manager) handleDeleteRunnerAPIKey(w http.ResponseWriter, r *http.Request) {
keyID, err := parseID(r, "id")
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
if err := s.secrets.DeleteRunnerAPIKey(keyID); err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to delete API key: %v", err))
return
}
s.respondJSON(w, http.StatusOK, map[string]string{"message": "API key deleted"})
}
// handleVerifyRunner manually verifies a runner
func (s *Server) handleVerifyRunner(w http.ResponseWriter, r *http.Request) {
func (s *Manager) handleVerifyRunner(w http.ResponseWriter, r *http.Request) {
runnerID, err := parseID(r, "id")
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
@@ -80,14 +131,19 @@ func (s *Server) handleVerifyRunner(w http.ResponseWriter, r *http.Request) {
// Check if runner exists
var exists bool
err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM runners WHERE id = ?)", runnerID).Scan(&exists)
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM runners WHERE id = ?)", runnerID).Scan(&exists)
})
if err != nil || !exists {
s.respondError(w, http.StatusNotFound, "Runner not found")
return
}
// Mark runner as verified
_, err = s.db.Exec("UPDATE runners SET verified = 1 WHERE id = ?", runnerID)
err = s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec("UPDATE runners SET verified = 1 WHERE id = ?", runnerID)
return err
})
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify runner: %v", err))
return
@@ -97,7 +153,7 @@ func (s *Server) handleVerifyRunner(w http.ResponseWriter, r *http.Request) {
}
// handleDeleteRunner removes a runner
func (s *Server) handleDeleteRunner(w http.ResponseWriter, r *http.Request) {
func (s *Manager) handleDeleteRunner(w http.ResponseWriter, r *http.Request) {
runnerID, err := parseID(r, "id")
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
@@ -106,14 +162,19 @@ func (s *Server) handleDeleteRunner(w http.ResponseWriter, r *http.Request) {
// Check if runner exists
var exists bool
err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM runners WHERE id = ?)", runnerID).Scan(&exists)
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM runners WHERE id = ?)", runnerID).Scan(&exists)
})
if err != nil || !exists {
s.respondError(w, http.StatusNotFound, "Runner not found")
return
}
// Delete runner
_, err = s.db.Exec("DELETE FROM runners WHERE id = ?", runnerID)
err = s.db.With(func(conn *sql.DB) error {
_, err := conn.Exec("DELETE FROM runners WHERE id = ?", runnerID)
return err
})
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to delete runner: %v", err))
return
@@ -123,12 +184,17 @@ func (s *Server) handleDeleteRunner(w http.ResponseWriter, r *http.Request) {
}
// handleListRunnersAdmin lists all runners with admin details
func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request) {
rows, err := s.db.Query(
`SELECT id, name, hostname, ip_address, status, last_heartbeat, capabilities,
registration_token, verified, priority, created_at
func (s *Manager) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request) {
var rows *sql.Rows
err := s.db.With(func(conn *sql.DB) error {
var err error
rows, err = conn.Query(
`SELECT id, name, hostname, status, last_heartbeat, capabilities,
api_key_id, api_key_scope, priority, created_at
FROM runners ORDER BY created_at DESC`,
)
)
return err
})
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query runners: %v", err))
return
@@ -138,31 +204,32 @@ func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request)
runners := []map[string]interface{}{}
for rows.Next() {
var runner types.Runner
var registrationToken sql.NullString
var verified bool
var apiKeyID sql.NullInt64
var apiKeyScope string
err := rows.Scan(
&runner.ID, &runner.Name, &runner.Hostname, &runner.IPAddress,
&runner.ID, &runner.Name, &runner.Hostname,
&runner.Status, &runner.LastHeartbeat, &runner.Capabilities,
&registrationToken, &verified, &runner.Priority, &runner.CreatedAt,
&apiKeyID, &apiKeyScope, &runner.Priority, &runner.CreatedAt,
)
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan runner: %v", err))
return
}
// In polling model, database status is the source of truth
// Runners update their status when they poll for jobs
runners = append(runners, map[string]interface{}{
"id": runner.ID,
"name": runner.Name,
"hostname": runner.Hostname,
"ip_address": runner.IPAddress,
"status": runner.Status,
"last_heartbeat": runner.LastHeartbeat,
"capabilities": runner.Capabilities,
"registration_token": registrationToken.String,
"verified": verified,
"priority": runner.Priority,
"created_at": runner.CreatedAt,
"id": runner.ID,
"name": runner.Name,
"hostname": runner.Hostname,
"status": runner.Status,
"last_heartbeat": runner.LastHeartbeat,
"capabilities": runner.Capabilities,
"api_key_id": apiKeyID.Int64,
"api_key_scope": apiKeyScope,
"priority": runner.Priority,
"created_at": runner.CreatedAt,
})
}
@@ -170,7 +237,7 @@ func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request)
}
// handleListUsers lists all users
func (s *Server) handleListUsers(w http.ResponseWriter, r *http.Request) {
func (s *Manager) handleListUsers(w http.ResponseWriter, r *http.Request) {
// Get first user ID to mark it in the response
firstUserID, err := s.auth.GetFirstUserID()
if err != nil {
@@ -178,10 +245,15 @@ func (s *Server) handleListUsers(w http.ResponseWriter, r *http.Request) {
firstUserID = 0
}
rows, err := s.db.Query(
`SELECT id, email, name, oauth_provider, is_admin, created_at
var rows *sql.Rows
err = s.db.With(func(conn *sql.DB) error {
var err error
rows, err = conn.Query(
`SELECT id, email, name, oauth_provider, is_admin, created_at
FROM users ORDER BY created_at DESC`,
)
)
return err
})
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query users: %v", err))
return
@@ -203,7 +275,9 @@ func (s *Server) handleListUsers(w http.ResponseWriter, r *http.Request) {
// Get job count for this user
var jobCount int
err = s.db.QueryRow("SELECT COUNT(*) FROM jobs WHERE user_id = ?", userID).Scan(&jobCount)
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT COUNT(*) FROM jobs WHERE user_id = ?", userID).Scan(&jobCount)
})
if err != nil {
jobCount = 0 // Default to 0 if query fails
}
@@ -224,7 +298,7 @@ func (s *Server) handleListUsers(w http.ResponseWriter, r *http.Request) {
}
// handleGetUserJobs gets all jobs for a specific user
func (s *Server) handleGetUserJobs(w http.ResponseWriter, r *http.Request) {
func (s *Manager) handleGetUserJobs(w http.ResponseWriter, r *http.Request) {
userID, err := parseID(r, "id")
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
@@ -233,18 +307,25 @@ func (s *Server) handleGetUserJobs(w http.ResponseWriter, r *http.Request) {
// Verify user exists
var exists bool
err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", userID).Scan(&exists)
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", userID).Scan(&exists)
})
if err != nil || !exists {
s.respondError(w, http.StatusNotFound, "User not found")
return
}
rows, err := s.db.Query(
`SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format,
allow_parallel_runners, timeout_seconds, blend_metadata, created_at, started_at, completed_at, error_message
var rows *sql.Rows
err = s.db.With(func(conn *sql.DB) error {
var err error
rows, err = conn.Query(
`SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format,
blend_metadata, created_at, started_at, completed_at, error_message
FROM jobs WHERE user_id = ? ORDER BY created_at DESC`,
userID,
)
userID,
)
return err
})
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query jobs: %v", err))
return
@@ -260,11 +341,9 @@ func (s *Server) handleGetUserJobs(w http.ResponseWriter, r *http.Request) {
var errorMessage sql.NullString
var frameStart, frameEnd sql.NullInt64
var outputFormat sql.NullString
var allowParallelRunners sql.NullBool
err := rows.Scan(
&job.ID, &job.UserID, &jobType, &job.Name, &job.Status, &job.Progress,
&frameStart, &frameEnd, &outputFormat, &allowParallelRunners, &job.TimeoutSeconds,
&frameStart, &frameEnd, &outputFormat,
&blendMetadataJSON, &job.CreatedAt, &startedAt, &completedAt, &errorMessage,
)
if err != nil {
@@ -284,9 +363,6 @@ func (s *Server) handleGetUserJobs(w http.ResponseWriter, r *http.Request) {
if outputFormat.Valid {
job.OutputFormat = &outputFormat.String
}
if allowParallelRunners.Valid {
job.AllowParallelRunners = &allowParallelRunners.Bool
}
if startedAt.Valid {
job.StartedAt = &startedAt.Time
}
@@ -310,7 +386,7 @@ func (s *Server) handleGetUserJobs(w http.ResponseWriter, r *http.Request) {
}
// handleGetRegistrationEnabled gets the registration enabled setting
func (s *Server) handleGetRegistrationEnabled(w http.ResponseWriter, r *http.Request) {
func (s *Manager) handleGetRegistrationEnabled(w http.ResponseWriter, r *http.Request) {
enabled, err := s.auth.IsRegistrationEnabled()
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to get registration setting: %v", err))
@@ -320,12 +396,12 @@ func (s *Server) handleGetRegistrationEnabled(w http.ResponseWriter, r *http.Req
}
// handleSetRegistrationEnabled sets the registration enabled setting
func (s *Server) handleSetRegistrationEnabled(w http.ResponseWriter, r *http.Request) {
func (s *Manager) handleSetRegistrationEnabled(w http.ResponseWriter, r *http.Request) {
var req struct {
Enabled bool `json:"enabled"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
return
}
@@ -338,7 +414,7 @@ func (s *Server) handleSetRegistrationEnabled(w http.ResponseWriter, r *http.Req
}
// handleSetUserAdminStatus sets a user's admin status (admin only)
func (s *Server) handleSetUserAdminStatus(w http.ResponseWriter, r *http.Request) {
func (s *Manager) handleSetUserAdminStatus(w http.ResponseWriter, r *http.Request) {
targetUserID, err := parseID(r, "id")
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
@@ -349,7 +425,7 @@ func (s *Server) handleSetUserAdminStatus(w http.ResponseWriter, r *http.Request
IsAdmin bool `json:"is_admin"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
return
}

831
internal/manager/blender.go Normal file
View File

@@ -0,0 +1,831 @@
package api
import (
"archive/tar"
"compress/bzip2"
"compress/gzip"
"fmt"
"io"
"log"
"net/http"
"os"
"os/exec"
"path/filepath"
"regexp"
"sort"
"strings"
"sync"
"time"
)
const (
BlenderDownloadBaseURL = "https://download.blender.org/release/"
BlenderVersionCacheTTL = 1 * time.Hour
)
// BlenderVersion represents a parsed Blender version
type BlenderVersion struct {
Major int `json:"major"`
Minor int `json:"minor"`
Patch int `json:"patch"`
Full string `json:"full"` // e.g., "4.2.3"
DirName string `json:"dir_name"` // e.g., "Blender4.2"
Filename string `json:"filename"` // e.g., "blender-4.2.3-linux-x64.tar.xz"
URL string `json:"url"` // Full download URL
}
// BlenderVersionCache caches available Blender versions
type BlenderVersionCache struct {
versions []BlenderVersion
fetchedAt time.Time
mu sync.RWMutex
}
var blenderVersionCache = &BlenderVersionCache{}
// FetchBlenderVersions fetches available Blender versions from download.blender.org
// Returns versions sorted by version number (newest first)
func (s *Manager) FetchBlenderVersions() ([]BlenderVersion, error) {
// Check cache first
blenderVersionCache.mu.RLock()
if time.Since(blenderVersionCache.fetchedAt) < BlenderVersionCacheTTL && len(blenderVersionCache.versions) > 0 {
versions := make([]BlenderVersion, len(blenderVersionCache.versions))
copy(versions, blenderVersionCache.versions)
blenderVersionCache.mu.RUnlock()
return versions, nil
}
blenderVersionCache.mu.RUnlock()
// Fetch from website with timeout
client := &http.Client{
Timeout: WSWriteDeadline,
}
resp, err := client.Get(BlenderDownloadBaseURL)
if err != nil {
return nil, fmt.Errorf("failed to fetch blender releases: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to fetch blender releases: status %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
// Parse directory listing for Blender version folders
// Looking for patterns like href="Blender4.2/" or href="Blender3.6/"
dirPattern := regexp.MustCompile(`href="Blender(\d+)\.(\d+)/"`)
log.Printf("Fetching Blender versions from %s", BlenderDownloadBaseURL)
matches := dirPattern.FindAllStringSubmatch(string(body), -1)
// Fetch sub-versions concurrently to speed up the process
type versionResult struct {
versions []BlenderVersion
err error
}
results := make(chan versionResult, len(matches))
var wg sync.WaitGroup
for _, match := range matches {
if len(match) < 3 {
continue
}
major := 0
minor := 0
fmt.Sscanf(match[1], "%d", &major)
fmt.Sscanf(match[2], "%d", &minor)
// Skip very old versions (pre-2.80)
if major < 2 || (major == 2 && minor < 80) {
continue
}
dirName := fmt.Sprintf("Blender%d.%d", major, minor)
// Fetch the specific version directory concurrently
wg.Add(1)
go func(dn string, maj, min int) {
defer wg.Done()
subVersions, err := fetchSubVersions(dn, maj, min)
results <- versionResult{versions: subVersions, err: err}
}(dirName, major, minor)
}
// Close results channel when all goroutines complete
go func() {
wg.Wait()
close(results)
}()
var versions []BlenderVersion
for result := range results {
if result.err != nil {
log.Printf("Warning: failed to fetch sub-versions: %v", result.err)
continue
}
versions = append(versions, result.versions...)
}
// Sort by version (newest first)
sort.Slice(versions, func(i, j int) bool {
if versions[i].Major != versions[j].Major {
return versions[i].Major > versions[j].Major
}
if versions[i].Minor != versions[j].Minor {
return versions[i].Minor > versions[j].Minor
}
return versions[i].Patch > versions[j].Patch
})
// Update cache
blenderVersionCache.mu.Lock()
blenderVersionCache.versions = versions
blenderVersionCache.fetchedAt = time.Now()
blenderVersionCache.mu.Unlock()
return versions, nil
}
// fetchSubVersions fetches specific version files from a Blender release directory
func fetchSubVersions(dirName string, major, minor int) ([]BlenderVersion, error) {
url := BlenderDownloadBaseURL + dirName + "/"
client := &http.Client{
Timeout: WSWriteDeadline,
}
resp, err := client.Get(url)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("status %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
// Look for linux 64-bit tar.xz/bz2 files
// Various naming conventions across versions:
// - Modern (2.93+): blender-4.2.3-linux-x64.tar.xz
// - 2.83 early: blender-2.83.0-linux64.tar.xz
// - 2.80-2.82: blender-2.80-linux-glibc217-x86_64.tar.bz2
// Skip: rc versions, alpha/beta, i686 (32-bit)
filePatterns := []*regexp.Regexp{
// Modern format: blender-X.Y.Z-linux-x64.tar.xz
regexp.MustCompile(`blender-(\d+)\.(\d+)\.(\d+)-linux-x64\.tar\.(xz|bz2)`),
// Older format: blender-X.Y.Z-linux64.tar.xz
regexp.MustCompile(`blender-(\d+)\.(\d+)\.(\d+)-linux64\.tar\.(xz|bz2)`),
// glibc format: blender-X.Y.Z-linux-glibc217-x86_64.tar.bz2 (prefer glibc217 for compatibility)
regexp.MustCompile(`blender-(\d+)\.(\d+)\.(\d+)-linux-glibc217-x86_64\.tar\.(xz|bz2)`),
}
var versions []BlenderVersion
seen := make(map[string]bool)
for _, filePattern := range filePatterns {
matches := filePattern.FindAllStringSubmatch(string(body), -1)
for _, match := range matches {
if len(match) < 5 {
continue
}
patch := 0
fmt.Sscanf(match[3], "%d", &patch)
full := fmt.Sprintf("%d.%d.%d", major, minor, patch)
if seen[full] {
continue
}
seen[full] = true
filename := match[0]
versions = append(versions, BlenderVersion{
Major: major,
Minor: minor,
Patch: patch,
Full: full,
DirName: dirName,
Filename: filename,
URL: url + filename,
})
}
}
return versions, nil
}
// GetLatestBlenderForMajorMinor returns the latest patch version for a given major.minor
// If exact match not found, uses fuzzy matching to find the closest available version
func (s *Manager) GetLatestBlenderForMajorMinor(major, minor int) (*BlenderVersion, error) {
versions, err := s.FetchBlenderVersions()
if err != nil {
return nil, err
}
if len(versions) == 0 {
return nil, fmt.Errorf("no blender versions available")
}
// Try exact match first - find the highest patch for this major.minor
var exactMatch *BlenderVersion
for i := range versions {
v := &versions[i]
if v.Major == major && v.Minor == minor {
if exactMatch == nil || v.Patch > exactMatch.Patch {
exactMatch = v
}
}
}
if exactMatch != nil {
log.Printf("Found Blender %d.%d.%d for requested %d.%d", exactMatch.Major, exactMatch.Minor, exactMatch.Patch, major, minor)
return exactMatch, nil
}
// Fuzzy matching: find closest version
// Priority: same major with closest minor > closest major
log.Printf("No exact match for Blender %d.%d, using fuzzy matching", major, minor)
var bestMatch *BlenderVersion
bestScore := -1000000 // Large negative number
for i := range versions {
v := &versions[i]
score := 0
if v.Major == major {
// Same major version - prefer this
score = 10000
// Prefer lower minor versions (more stable/compatible)
// but not too far back
minorDiff := minor - v.Minor
if minorDiff >= 0 {
// v.Minor <= minor (older or same) - prefer closer
score += 1000 - minorDiff*10
} else {
// v.Minor > minor (newer) - less preferred but acceptable
score += 500 + minorDiff*10
}
// Higher patch is better
score += v.Patch
} else {
// Different major - less preferred
majorDiff := major - v.Major
if majorDiff > 0 {
// v.Major < major (older major) - acceptable fallback
score = 5000 - majorDiff*1000 + v.Minor*10 + v.Patch
} else {
// v.Major > major (newer major) - avoid if possible
score = -majorDiff * 1000
}
}
if score > bestScore {
bestScore = score
bestMatch = v
}
}
if bestMatch != nil {
log.Printf("Fuzzy match: requested %d.%d, using %d.%d.%d", major, minor, bestMatch.Major, bestMatch.Minor, bestMatch.Patch)
return bestMatch, nil
}
return nil, fmt.Errorf("no blender version found for %d.%d", major, minor)
}
// GetBlenderArchivePath returns the path to the cached blender archive for a specific version
// Downloads from blender.org and decompresses to .tar if not already cached
// The manager caches as uncompressed .tar to save decompression time on runners
func (s *Manager) GetBlenderArchivePath(version *BlenderVersion) (string, error) {
// Base directory for blender archives
blenderDir := filepath.Join(s.storage.BasePath(), "blender-versions")
if err := os.MkdirAll(blenderDir, 0755); err != nil {
return "", fmt.Errorf("failed to create blender directory: %w", err)
}
// Cache as uncompressed .tar for faster runner downloads
// Convert filename like "blender-4.2.3-linux-x64.tar.xz" to "blender-4.2.3-linux-x64.tar"
tarFilename := version.Filename
tarFilename = strings.TrimSuffix(tarFilename, ".xz")
tarFilename = strings.TrimSuffix(tarFilename, ".bz2")
archivePath := filepath.Join(blenderDir, tarFilename)
// Check if already cached as .tar
if _, err := os.Stat(archivePath); err == nil {
log.Printf("Using cached Blender %s at %s", version.Full, archivePath)
// Clean up any extracted folders that might exist
s.cleanupExtractedBlenderFolders(blenderDir, version)
return archivePath, nil
}
// Need to download and decompress
log.Printf("Downloading Blender %s from %s", version.Full, version.URL)
client := &http.Client{
Timeout: 0, // No timeout for large downloads
}
resp, err := client.Get(version.URL)
if err != nil {
return "", fmt.Errorf("failed to download blender: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("failed to download blender: status %d", resp.StatusCode)
}
// Download to temp file first
compressedPath := filepath.Join(blenderDir, "download-"+version.Filename)
compressedFile, err := os.Create(compressedPath)
if err != nil {
return "", fmt.Errorf("failed to create temp file: %w", err)
}
if _, err := io.Copy(compressedFile, resp.Body); err != nil {
compressedFile.Close()
os.Remove(compressedPath)
return "", fmt.Errorf("failed to download blender: %w", err)
}
compressedFile.Close()
log.Printf("Downloaded Blender %s, decompressing to .tar...", version.Full)
// Decompress to .tar
if err := decompressToTar(compressedPath, archivePath); err != nil {
os.Remove(compressedPath)
os.Remove(archivePath)
return "", fmt.Errorf("failed to decompress blender archive: %w", err)
}
// Remove compressed file
os.Remove(compressedPath)
// Clean up any extracted folders for this version (if they exist)
s.cleanupExtractedBlenderFolders(blenderDir, version)
log.Printf("Blender %s cached at %s", version.Full, archivePath)
return archivePath, nil
}
// decompressToTar decompresses a .tar.xz or .tar.bz2 file to a plain .tar file
func decompressToTar(compressedPath, tarPath string) error {
if strings.HasSuffix(compressedPath, ".tar.xz") {
// Use xz command for decompression
cmd := exec.Command("xz", "-d", "-k", "-c", compressedPath)
outFile, err := os.Create(tarPath)
if err != nil {
return err
}
defer outFile.Close()
cmd.Stdout = outFile
if err := cmd.Run(); err != nil {
return fmt.Errorf("xz decompression failed: %w", err)
}
return nil
} else if strings.HasSuffix(compressedPath, ".tar.bz2") {
// Use bzip2 for decompression
inFile, err := os.Open(compressedPath)
if err != nil {
return err
}
defer inFile.Close()
bzReader := bzip2.NewReader(inFile)
outFile, err := os.Create(tarPath)
if err != nil {
return err
}
defer outFile.Close()
if _, err := io.Copy(outFile, bzReader); err != nil {
return fmt.Errorf("bzip2 decompression failed: %w", err)
}
return nil
}
return fmt.Errorf("unsupported compression format: %s", compressedPath)
}
// cleanupExtractedBlenderFolders removes any extracted Blender folders for the given version
// This ensures we only keep the .tar file and not extracted folders
func (s *Manager) cleanupExtractedBlenderFolders(blenderDir string, version *BlenderVersion) {
// Look for folders matching the version (e.g., "4.2.3", "2.83.20")
versionDirs := []string{
filepath.Join(blenderDir, version.Full), // e.g., "4.2.3"
filepath.Join(blenderDir, fmt.Sprintf("%d.%d", version.Major, version.Minor)), // e.g., "4.2"
}
for _, dir := range versionDirs {
if info, err := os.Stat(dir); err == nil && info.IsDir() {
log.Printf("Removing extracted Blender folder: %s", dir)
if err := os.RemoveAll(dir); err != nil {
log.Printf("Warning: failed to remove extracted folder %s: %v", dir, err)
} else {
log.Printf("Removed extracted Blender folder: %s", dir)
}
}
}
}
// ParseBlenderVersionFromFile parses the Blender version that a .blend file was saved with
// This reads the file header to determine the version
func ParseBlenderVersionFromFile(blendPath string) (major, minor int, err error) {
file, err := os.Open(blendPath)
if err != nil {
return 0, 0, fmt.Errorf("failed to open blend file: %w", err)
}
defer file.Close()
return ParseBlenderVersionFromReader(file)
}
// ParseBlenderVersionFromReader parses the Blender version from a reader
// Useful for reading from uploaded files without saving to disk first
func ParseBlenderVersionFromReader(r io.ReadSeeker) (major, minor int, err error) {
// Read the first 12 bytes of the blend file header
// Format: BLENDER-v<major><minor><patch> or BLENDER_v<major><minor><patch>
// The header is: "BLENDER" (7 bytes) + pointer size (1 byte: '-' for 64-bit, '_' for 32-bit)
// + endianness (1 byte: 'v' for little-endian, 'V' for big-endian)
// + version (3 bytes: e.g., "402" for 4.02)
header := make([]byte, 12)
n, err := r.Read(header)
if err != nil || n < 12 {
return 0, 0, fmt.Errorf("failed to read blend file header: %w", err)
}
// Check for BLENDER magic
if string(header[:7]) != "BLENDER" {
// Might be compressed - try to decompress
r.Seek(0, 0)
return parseCompressedBlendVersion(r)
}
// Parse version from bytes 9-11 (3 digits)
versionStr := string(header[9:12])
var vMajor, vMinor int
// Version format changed in Blender 3.0
// Pre-3.0: "279" = 2.79, "280" = 2.80
// 3.0+: "300" = 3.0, "402" = 4.02, "410" = 4.10
if len(versionStr) == 3 {
// First digit is major version
fmt.Sscanf(string(versionStr[0]), "%d", &vMajor)
// Next two digits are minor version
fmt.Sscanf(versionStr[1:3], "%d", &vMinor)
}
return vMajor, vMinor, nil
}
// parseCompressedBlendVersion handles gzip and zstd compressed blend files
func parseCompressedBlendVersion(r io.ReadSeeker) (major, minor int, err error) {
// Check for compression magic bytes
magic := make([]byte, 4)
if _, err := r.Read(magic); err != nil {
return 0, 0, err
}
r.Seek(0, 0)
if magic[0] == 0x1f && magic[1] == 0x8b {
// gzip compressed
gzReader, err := gzip.NewReader(r)
if err != nil {
return 0, 0, fmt.Errorf("failed to create gzip reader: %w", err)
}
defer gzReader.Close()
header := make([]byte, 12)
n, err := gzReader.Read(header)
if err != nil || n < 12 {
return 0, 0, fmt.Errorf("failed to read compressed blend header: %w", err)
}
if string(header[:7]) != "BLENDER" {
return 0, 0, fmt.Errorf("invalid blend file format")
}
versionStr := string(header[9:12])
var vMajor, vMinor int
if len(versionStr) == 3 {
fmt.Sscanf(string(versionStr[0]), "%d", &vMajor)
fmt.Sscanf(versionStr[1:3], "%d", &vMinor)
}
return vMajor, vMinor, nil
}
// Check for zstd magic (Blender 3.0+): 0x28 0xB5 0x2F 0xFD
if magic[0] == 0x28 && magic[1] == 0xb5 && magic[2] == 0x2f && magic[3] == 0xfd {
return parseZstdBlendVersion(r)
}
return 0, 0, fmt.Errorf("unknown blend file format")
}
// parseZstdBlendVersion handles zstd-compressed blend files (Blender 3.0+)
// Uses zstd command line tool since Go doesn't have native zstd support
func parseZstdBlendVersion(r io.ReadSeeker) (major, minor int, err error) {
r.Seek(0, 0)
// We need to decompress just enough to read the header
// Use zstd command to decompress from stdin
cmd := exec.Command("zstd", "-d", "-c")
cmd.Stdin = r
stdout, err := cmd.StdoutPipe()
if err != nil {
return 0, 0, fmt.Errorf("failed to create zstd stdout pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return 0, 0, fmt.Errorf("failed to start zstd decompression: %w", err)
}
// Read just the header (12 bytes)
header := make([]byte, 12)
n, readErr := io.ReadFull(stdout, header)
// Kill the process early - we only need the header
cmd.Process.Kill()
cmd.Wait()
if readErr != nil || n < 12 {
return 0, 0, fmt.Errorf("failed to read zstd compressed blend header: %v", readErr)
}
if string(header[:7]) != "BLENDER" {
return 0, 0, fmt.Errorf("invalid blend file format in zstd archive")
}
versionStr := string(header[9:12])
var vMajor, vMinor int
if len(versionStr) == 3 {
fmt.Sscanf(string(versionStr[0]), "%d", &vMajor)
fmt.Sscanf(versionStr[1:3], "%d", &vMinor)
}
return vMajor, vMinor, nil
}
// handleGetBlenderVersions returns available Blender versions
func (s *Manager) handleGetBlenderVersions(w http.ResponseWriter, r *http.Request) {
versions, err := s.FetchBlenderVersions()
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to fetch blender versions: %v", err))
return
}
// Group by major.minor for easier frontend display
type VersionGroup struct {
MajorMinor string `json:"major_minor"`
Latest BlenderVersion `json:"latest"`
All []BlenderVersion `json:"all"`
}
groups := make(map[string]*VersionGroup)
for _, v := range versions {
key := fmt.Sprintf("%d.%d", v.Major, v.Minor)
if groups[key] == nil {
groups[key] = &VersionGroup{
MajorMinor: key,
Latest: v, // First one is latest due to sorting
All: []BlenderVersion{v},
}
} else {
groups[key].All = append(groups[key].All, v)
}
}
// Convert to slice and sort by version
var groupedResult []VersionGroup
for _, g := range groups {
groupedResult = append(groupedResult, *g)
}
sort.Slice(groupedResult, func(i, j int) bool {
// Parse major.minor for comparison
var iMaj, iMin, jMaj, jMin int
fmt.Sscanf(groupedResult[i].MajorMinor, "%d.%d", &iMaj, &iMin)
fmt.Sscanf(groupedResult[j].MajorMinor, "%d.%d", &jMaj, &jMin)
if iMaj != jMaj {
return iMaj > jMaj
}
return iMin > jMin
})
// Return both flat list and grouped for flexibility
response := map[string]interface{}{
"versions": versions, // Flat list of all versions (newest first)
"grouped": groupedResult, // Grouped by major.minor
}
s.respondJSON(w, http.StatusOK, response)
}
// handleDownloadBlender serves a cached Blender archive to runners
func (s *Manager) handleDownloadBlender(w http.ResponseWriter, r *http.Request) {
version := r.URL.Query().Get("version")
if version == "" {
s.respondError(w, http.StatusBadRequest, "version parameter required")
return
}
// Parse version string (e.g., "4.2.3" or "4.2")
var major, minor, patch int
parts := strings.Split(version, ".")
if len(parts) < 2 {
s.respondError(w, http.StatusBadRequest, "invalid version format, expected major.minor or major.minor.patch")
return
}
fmt.Sscanf(parts[0], "%d", &major)
fmt.Sscanf(parts[1], "%d", &minor)
if len(parts) >= 3 {
fmt.Sscanf(parts[2], "%d", &patch)
}
// Find the version
var blenderVersion *BlenderVersion
if len(parts) >= 3 {
// Exact patch version requested - find it
versions, err := s.FetchBlenderVersions()
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to fetch versions: %v", err))
return
}
for _, v := range versions {
if v.Major == major && v.Minor == minor && v.Patch == patch {
blenderVersion = &v
break
}
}
if blenderVersion == nil {
s.respondError(w, http.StatusNotFound, fmt.Sprintf("blender version %s not found", version))
return
}
} else {
// Major.minor only - use helper to get latest patch version
var err error
blenderVersion, err = s.GetLatestBlenderForMajorMinor(major, minor)
if err != nil {
s.respondError(w, http.StatusNotFound, fmt.Sprintf("blender version %s not found: %v", version, err))
return
}
}
// Get or download the archive
archivePath, err := s.GetBlenderArchivePath(blenderVersion)
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to get blender archive: %v", err))
return
}
// Serve the file
file, err := os.Open(archivePath)
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to open archive: %v", err))
return
}
defer file.Close()
stat, err := file.Stat()
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to stat archive: %v", err))
return
}
// Filename is now .tar (decompressed)
tarFilename := blenderVersion.Filename
tarFilename = strings.TrimSuffix(tarFilename, ".xz")
tarFilename = strings.TrimSuffix(tarFilename, ".bz2")
w.Header().Set("Content-Type", "application/x-tar")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", tarFilename))
w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
w.Header().Set("X-Blender-Version", blenderVersion.Full)
io.Copy(w, file)
}
// Unused functions from extraction - keeping for reference but not needed on manager
var _ = extractBlenderArchive
var _ = extractTarXz
var _ = extractTar
// extractBlenderArchive extracts a blender archive (already decompressed to .tar by GetBlenderArchivePath)
func extractBlenderArchive(archivePath string, version *BlenderVersion, destDir string) error {
file, err := os.Open(archivePath)
if err != nil {
return err
}
defer file.Close()
// The archive is already decompressed to .tar by GetBlenderArchivePath
// Just extract it directly
if strings.HasSuffix(archivePath, ".tar") {
tarReader := tar.NewReader(file)
return extractTar(tarReader, version, destDir)
}
// Fallback for any other format (shouldn't happen with current flow)
if strings.HasSuffix(archivePath, ".tar.xz") {
return extractTarXz(archivePath, version, destDir)
} else if strings.HasSuffix(archivePath, ".tar.bz2") {
bzReader := bzip2.NewReader(file)
tarReader := tar.NewReader(bzReader)
return extractTar(tarReader, version, destDir)
}
return fmt.Errorf("unsupported archive format: %s", archivePath)
}
// extractTarXz extracts a tar.xz archive using the xz command
func extractTarXz(archivePath string, version *BlenderVersion, destDir string) error {
versionDir := filepath.Join(destDir, version.Full)
if err := os.MkdirAll(versionDir, 0755); err != nil {
return err
}
cmd := exec.Command("tar", "-xJf", archivePath, "-C", versionDir, "--strip-components=1")
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("tar extraction failed: %v, output: %s", err, string(output))
}
return nil
}
// extractTar extracts files from a tar reader
func extractTar(tarReader *tar.Reader, version *BlenderVersion, destDir string) error {
versionDir := filepath.Join(destDir, version.Full)
if err := os.MkdirAll(versionDir, 0755); err != nil {
return err
}
stripPrefix := ""
for {
header, err := tarReader.Next()
if err == io.EOF {
break
}
if err != nil {
return err
}
if stripPrefix == "" {
parts := strings.SplitN(header.Name, "/", 2)
if len(parts) > 0 {
stripPrefix = parts[0] + "/"
}
}
name := strings.TrimPrefix(header.Name, stripPrefix)
if name == "" {
continue
}
targetPath := filepath.Join(versionDir, name)
switch header.Typeflag {
case tar.TypeDir:
if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil {
return err
}
case tar.TypeReg:
if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
return err
}
outFile, err := os.OpenFile(targetPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode))
if err != nil {
return err
}
if _, err := io.Copy(outFile, tarReader); err != nil {
outFile.Close()
return err
}
outFile.Close()
case tar.TypeSymlink:
if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
return err
}
if err := os.Symlink(header.Linkname, targetPath); err != nil {
return err
}
}
}
return nil
}

5216
internal/manager/jobs.go Normal file

File diff suppressed because it is too large Load Diff

1302
internal/manager/manager.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,265 @@
package api
import (
"archive/tar"
"database/sql"
_ "embed"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"strings"
"jiggablend/pkg/executils"
"jiggablend/pkg/scripts"
"jiggablend/pkg/types"
)
// handleGetJobMetadata retrieves metadata for a job
func (s *Manager) handleGetJobMetadata(w http.ResponseWriter, r *http.Request) {
userID, err := getUserID(r)
if err != nil {
s.respondError(w, http.StatusUnauthorized, err.Error())
return
}
jobID, err := parseID(r, "id")
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
// Verify job belongs to user
var jobUserID int64
var blendMetadataJSON sql.NullString
err = s.db.With(func(conn *sql.DB) error {
return conn.QueryRow(
`SELECT user_id, blend_metadata FROM jobs WHERE id = ?`,
jobID,
).Scan(&jobUserID, &blendMetadataJSON)
})
if err == sql.ErrNoRows {
s.respondError(w, http.StatusNotFound, "Job not found")
return
}
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query job: %v", err))
return
}
if jobUserID != userID {
s.respondError(w, http.StatusForbidden, "Access denied")
return
}
if !blendMetadataJSON.Valid || blendMetadataJSON.String == "" {
s.respondJSON(w, http.StatusOK, nil)
return
}
var metadata types.BlendMetadata
if err := json.Unmarshal([]byte(blendMetadataJSON.String), &metadata); err != nil {
s.respondError(w, http.StatusInternalServerError, "Failed to parse metadata")
return
}
s.respondJSON(w, http.StatusOK, metadata)
}
// extractMetadataFromContext extracts metadata from the blend file in a context archive
// Returns the extracted metadata or an error
func (s *Manager) extractMetadataFromContext(jobID int64) (*types.BlendMetadata, error) {
contextPath := filepath.Join(s.storage.JobPath(jobID), "context.tar")
// Check if context exists
if _, err := os.Stat(contextPath); err != nil {
return nil, fmt.Errorf("context archive not found: %w", err)
}
// Create temporary directory for extraction under storage base path
tmpDir, err := s.storage.TempDir(fmt.Sprintf("jiggablend-metadata-%d-*", jobID))
if err != nil {
return nil, fmt.Errorf("failed to create temporary directory: %w", err)
}
defer func() {
if err := os.RemoveAll(tmpDir); err != nil {
log.Printf("Warning: Failed to clean up temp directory %s: %v", tmpDir, err)
}
}()
// Extract context archive
if err := s.extractTar(contextPath, tmpDir); err != nil {
return nil, fmt.Errorf("failed to extract context: %w", err)
}
// Find .blend file in extracted contents
blendFile := ""
err = filepath.Walk(tmpDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".blend") {
// Check it's not a Blender save file (.blend1, .blend2, etc.)
lower := strings.ToLower(info.Name())
idx := strings.LastIndex(lower, ".blend")
if idx != -1 {
suffix := lower[idx+len(".blend"):]
// If there are digits after .blend, it's a save file
isSaveFile := false
if len(suffix) > 0 {
isSaveFile = true
for _, r := range suffix {
if r < '0' || r > '9' {
isSaveFile = false
break
}
}
}
if !isSaveFile {
blendFile = path
return filepath.SkipAll // Stop walking once we find a blend file
}
}
}
return nil
})
if err != nil {
return nil, fmt.Errorf("failed to find blend file: %w", err)
}
if blendFile == "" {
return nil, fmt.Errorf("no .blend file found in context - the uploaded context archive must contain at least one .blend file for metadata extraction")
}
// Use embedded Python script
scriptPath := filepath.Join(tmpDir, "extract_metadata.py")
if err := os.WriteFile(scriptPath, []byte(scripts.ExtractMetadata), 0644); err != nil {
return nil, fmt.Errorf("failed to create extraction script: %w", err)
}
// Make blend file path relative to tmpDir to avoid path resolution issues
blendFileRel, err := filepath.Rel(tmpDir, blendFile)
if err != nil {
return nil, fmt.Errorf("failed to get relative path for blend file: %w", err)
}
// Execute Blender with Python script using executils
result, err := executils.RunCommand(
"blender",
[]string{"-b", blendFileRel, "--python", "extract_metadata.py"},
tmpDir,
nil, // inherit environment
jobID,
nil, // no process tracker needed for metadata extraction
)
if err != nil {
stderrOutput := ""
stdoutOutput := ""
if result != nil {
stderrOutput = strings.TrimSpace(result.Stderr)
stdoutOutput = strings.TrimSpace(result.Stdout)
}
log.Printf("Blender metadata extraction failed for job %d:", jobID)
if stderrOutput != "" {
log.Printf("Blender stderr: %s", stderrOutput)
}
if stdoutOutput != "" {
log.Printf("Blender stdout (last 500 chars): %s", truncateString(stdoutOutput, 500))
}
if stderrOutput != "" {
return nil, fmt.Errorf("blender metadata extraction failed: %w (stderr: %s)", err, truncateString(stderrOutput, 200))
}
return nil, fmt.Errorf("blender metadata extraction failed: %w", err)
}
// Parse output (metadata is printed to stdout)
metadataJSON := strings.TrimSpace(result.Stdout)
// Extract JSON from output (Blender may print other stuff)
jsonStart := strings.Index(metadataJSON, "{")
jsonEnd := strings.LastIndex(metadataJSON, "}")
if jsonStart == -1 || jsonEnd == -1 || jsonEnd <= jsonStart {
return nil, errors.New("failed to extract JSON from Blender output")
}
metadataJSON = metadataJSON[jsonStart : jsonEnd+1]
var metadata types.BlendMetadata
if err := json.Unmarshal([]byte(metadataJSON), &metadata); err != nil {
return nil, fmt.Errorf("failed to parse metadata JSON: %w", err)
}
log.Printf("Metadata extracted for job %d: frame_start=%d, frame_end=%d", jobID, metadata.FrameStart, metadata.FrameEnd)
return &metadata, nil
}
// extractTar extracts a tar archive to a destination directory
func (s *Manager) extractTar(tarPath, destDir string) error {
log.Printf("Extracting tar archive: %s -> %s", tarPath, destDir)
// Ensure destination directory exists
if err := os.MkdirAll(destDir, 0755); err != nil {
return fmt.Errorf("failed to create destination directory: %w", err)
}
file, err := os.Open(tarPath)
if err != nil {
return fmt.Errorf("failed to open archive: %w", err)
}
defer file.Close()
tr := tar.NewReader(file)
fileCount := 0
dirCount := 0
for {
header, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
return fmt.Errorf("failed to read tar header: %w", err)
}
// Sanitize path to prevent directory traversal
target := filepath.Join(destDir, header.Name)
// Ensure target is within destDir
cleanTarget := filepath.Clean(target)
cleanDestDir := filepath.Clean(destDir)
if !strings.HasPrefix(cleanTarget, cleanDestDir+string(os.PathSeparator)) && cleanTarget != cleanDestDir {
log.Printf("ERROR: Invalid file path in TAR - target: %s, destDir: %s", cleanTarget, cleanDestDir)
return fmt.Errorf("invalid file path in archive: %s (target: %s, destDir: %s)", header.Name, cleanTarget, cleanDestDir)
}
// Create parent directories
if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
// Write file
switch header.Typeflag {
case tar.TypeReg:
outFile, err := os.Create(target)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
_, err = io.Copy(outFile, tr)
if err != nil {
outFile.Close()
return fmt.Errorf("failed to write file: %w", err)
}
outFile.Close()
fileCount++
case tar.TypeDir:
dirCount++
}
}
log.Printf("Extraction complete: %d files, %d directories extracted to %s", fileCount, dirCount, destDir)
return nil
}

2501
internal/manager/runners.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,333 @@
package api
import (
"fmt"
"log"
"strings"
"sync"
"time"
"jiggablend/pkg/types"
"github.com/gorilla/websocket"
)
// JobConnection wraps a WebSocket connection for job communication.
type JobConnection struct {
conn *websocket.Conn
writeMu sync.Mutex
stopPing chan struct{}
stopHeartbeat chan struct{}
isConnected bool
connMu sync.RWMutex
}
// NewJobConnection creates a new job connection wrapper.
func NewJobConnection() *JobConnection {
return &JobConnection{}
}
// Connect establishes a WebSocket connection for a job (no runnerID needed).
func (j *JobConnection) Connect(managerURL, jobPath, jobToken string) error {
wsPath := jobPath + "/ws"
wsURL := strings.Replace(managerURL, "http://", "ws://", 1)
wsURL = strings.Replace(wsURL, "https://", "wss://", 1)
wsURL += wsPath
log.Printf("Connecting to job WebSocket: %s", wsPath)
dialer := websocket.Dialer{
HandshakeTimeout: 10 * time.Second,
}
conn, _, err := dialer.Dial(wsURL, nil)
if err != nil {
return fmt.Errorf("failed to connect job WebSocket: %w", err)
}
j.conn = conn
// Send auth message
authMsg := map[string]interface{}{
"type": "auth",
"job_token": jobToken,
}
if err := conn.WriteJSON(authMsg); err != nil {
conn.Close()
return fmt.Errorf("failed to send auth: %w", err)
}
// Wait for auth_ok
conn.SetReadDeadline(time.Now().Add(30 * time.Second))
var authResp map[string]string
if err := conn.ReadJSON(&authResp); err != nil {
conn.Close()
return fmt.Errorf("failed to read auth response: %w", err)
}
if authResp["type"] == "error" {
conn.Close()
return fmt.Errorf("auth failed: %s", authResp["message"])
}
if authResp["type"] != "auth_ok" {
conn.Close()
return fmt.Errorf("unexpected auth response: %s", authResp["type"])
}
// Clear read deadline after auth
conn.SetReadDeadline(time.Time{})
// Set up ping/pong handler for keepalive
conn.SetPongHandler(func(string) error {
conn.SetReadDeadline(time.Now().Add(90 * time.Second))
return nil
})
// Start ping goroutine
j.stopPing = make(chan struct{})
j.connMu.Lock()
j.isConnected = true
j.connMu.Unlock()
go j.pingLoop()
// Start WebSocket heartbeat goroutine
j.stopHeartbeat = make(chan struct{})
go j.heartbeatLoop()
return nil
}
// pingLoop sends periodic pings to keep the WebSocket connection alive.
func (j *JobConnection) pingLoop() {
defer func() {
if rec := recover(); rec != nil {
log.Printf("Ping loop panicked: %v", rec)
}
}()
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-j.stopPing:
return
case <-ticker.C:
j.writeMu.Lock()
if j.conn != nil {
deadline := time.Now().Add(10 * time.Second)
if err := j.conn.WriteControl(websocket.PingMessage, []byte{}, deadline); err != nil {
log.Printf("Failed to send ping, closing connection: %v", err)
j.connMu.Lock()
j.isConnected = false
if j.conn != nil {
j.conn.Close()
j.conn = nil
}
j.connMu.Unlock()
}
}
j.writeMu.Unlock()
}
}
}
// Heartbeat sends a heartbeat message over WebSocket to keep runner online.
func (j *JobConnection) Heartbeat() {
if j.conn == nil {
return
}
j.writeMu.Lock()
defer j.writeMu.Unlock()
msg := map[string]interface{}{
"type": "runner_heartbeat",
"timestamp": time.Now().Unix(),
}
if err := j.conn.WriteJSON(msg); err != nil {
log.Printf("Failed to send WebSocket heartbeat: %v", err)
// Handle connection failure
j.connMu.Lock()
j.isConnected = false
if j.conn != nil {
j.conn.Close()
j.conn = nil
}
j.connMu.Unlock()
}
}
// heartbeatLoop sends periodic heartbeat messages over WebSocket.
func (j *JobConnection) heartbeatLoop() {
defer func() {
if rec := recover(); rec != nil {
log.Printf("WebSocket heartbeat loop panicked: %v", rec)
}
}()
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-j.stopHeartbeat:
return
case <-ticker.C:
j.Heartbeat()
}
}
}
// Close closes the WebSocket connection.
func (j *JobConnection) Close() {
j.connMu.Lock()
j.isConnected = false
j.connMu.Unlock()
// Stop heartbeat goroutine
if j.stopHeartbeat != nil {
close(j.stopHeartbeat)
j.stopHeartbeat = nil
}
// Stop ping goroutine
if j.stopPing != nil {
close(j.stopPing)
j.stopPing = nil
}
if j.conn != nil {
j.conn.Close()
j.conn = nil
}
}
// IsConnected returns true if the connection is established.
func (j *JobConnection) IsConnected() bool {
j.connMu.RLock()
defer j.connMu.RUnlock()
return j.isConnected && j.conn != nil
}
// Log sends a log entry to the manager.
func (j *JobConnection) Log(taskID int64, level types.LogLevel, message string) {
if j.conn == nil {
return
}
j.writeMu.Lock()
defer j.writeMu.Unlock()
msg := map[string]interface{}{
"type": "log_entry",
"data": map[string]interface{}{
"task_id": taskID,
"log_level": string(level),
"message": message,
},
"timestamp": time.Now().Unix(),
}
if err := j.conn.WriteJSON(msg); err != nil {
log.Printf("Failed to send job log, connection may be broken: %v", err)
// Close the connection on write error
j.connMu.Lock()
j.isConnected = false
if j.conn != nil {
j.conn.Close()
j.conn = nil
}
j.connMu.Unlock()
}
}
// Progress sends a progress update to the manager.
func (j *JobConnection) Progress(taskID int64, progress float64) {
if j.conn == nil {
return
}
j.writeMu.Lock()
defer j.writeMu.Unlock()
msg := map[string]interface{}{
"type": "progress",
"data": map[string]interface{}{
"task_id": taskID,
"progress": progress,
},
"timestamp": time.Now().Unix(),
}
if err := j.conn.WriteJSON(msg); err != nil {
log.Printf("Failed to send job progress, connection may be broken: %v", err)
// Close the connection on write error
j.connMu.Lock()
j.isConnected = false
if j.conn != nil {
j.conn.Close()
j.conn = nil
}
j.connMu.Unlock()
}
}
// OutputUploaded notifies that an output file was uploaded.
func (j *JobConnection) OutputUploaded(taskID int64, fileName string) {
if j.conn == nil {
return
}
j.writeMu.Lock()
defer j.writeMu.Unlock()
msg := map[string]interface{}{
"type": "output_uploaded",
"data": map[string]interface{}{
"task_id": taskID,
"file_name": fileName,
},
"timestamp": time.Now().Unix(),
}
if err := j.conn.WriteJSON(msg); err != nil {
log.Printf("Failed to send output uploaded, connection may be broken: %v", err)
// Close the connection on write error
j.connMu.Lock()
j.isConnected = false
if j.conn != nil {
j.conn.Close()
j.conn = nil
}
j.connMu.Unlock()
}
}
// Complete sends task completion to the manager.
func (j *JobConnection) Complete(taskID int64, success bool, errorMsg error) {
if j.conn == nil {
log.Printf("Cannot send task complete: WebSocket connection is nil")
return
}
j.writeMu.Lock()
defer j.writeMu.Unlock()
msg := map[string]interface{}{
"type": "task_complete",
"data": map[string]interface{}{
"task_id": taskID,
"success": success,
"error": errorMsg,
},
"timestamp": time.Now().Unix(),
}
if err := j.conn.WriteJSON(msg); err != nil {
log.Printf("Failed to send task complete, connection may be broken: %v", err)
// Close the connection on write error
j.connMu.Lock()
j.isConnected = false
if j.conn != nil {
j.conn.Close()
j.conn = nil
}
j.connMu.Unlock()
}
}

View File

@@ -0,0 +1,421 @@
// Package api provides HTTP and WebSocket communication with the manager server.
package api
import (
"bytes"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"jiggablend/pkg/types"
)
// ManagerClient handles all HTTP communication with the manager server.
type ManagerClient struct {
baseURL string
apiKey string
runnerID int64
httpClient *http.Client // Standard timeout for quick requests
longClient *http.Client // No timeout for large file transfers
}
// NewManagerClient creates a new manager client.
func NewManagerClient(baseURL string) *ManagerClient {
return &ManagerClient{
baseURL: strings.TrimSuffix(baseURL, "/"),
httpClient: &http.Client{Timeout: 30 * time.Second},
longClient: &http.Client{Timeout: 0}, // No timeout for large transfers
}
}
// SetCredentials sets the API key and runner ID after registration.
func (m *ManagerClient) SetCredentials(runnerID int64, apiKey string) {
m.runnerID = runnerID
m.apiKey = apiKey
}
// GetRunnerID returns the registered runner ID.
func (m *ManagerClient) GetRunnerID() int64 {
return m.runnerID
}
// GetAPIKey returns the API key.
func (m *ManagerClient) GetAPIKey() string {
return m.apiKey
}
// GetBaseURL returns the base URL.
func (m *ManagerClient) GetBaseURL() string {
return m.baseURL
}
// Request performs an authenticated HTTP request with standard timeout.
func (m *ManagerClient) Request(method, path string, body []byte) (*http.Response, error) {
return m.doRequest(method, path, body, m.httpClient)
}
// RequestLong performs an authenticated HTTP request with no timeout.
// Use for large file uploads/downloads.
func (m *ManagerClient) RequestLong(method, path string, body []byte) (*http.Response, error) {
return m.doRequest(method, path, body, m.longClient)
}
func (m *ManagerClient) doRequest(method, path string, body []byte, client *http.Client) (*http.Response, error) {
if m.apiKey == "" {
return nil, fmt.Errorf("not authenticated")
}
fullURL := m.baseURL + path
req, err := http.NewRequest(method, fullURL, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+m.apiKey)
if len(body) > 0 {
req.Header.Set("Content-Type", "application/json")
}
return client.Do(req)
}
// RequestWithToken performs an authenticated HTTP request using a specific token.
func (m *ManagerClient) RequestWithToken(method, path, token string, body []byte) (*http.Response, error) {
return m.doRequestWithToken(method, path, token, body, m.httpClient)
}
// RequestLongWithToken performs a long-running request with a specific token.
func (m *ManagerClient) RequestLongWithToken(method, path, token string, body []byte) (*http.Response, error) {
return m.doRequestWithToken(method, path, token, body, m.longClient)
}
func (m *ManagerClient) doRequestWithToken(method, path, token string, body []byte, client *http.Client) (*http.Response, error) {
fullURL := m.baseURL + path
req, err := http.NewRequest(method, fullURL, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+token)
if len(body) > 0 {
req.Header.Set("Content-Type", "application/json")
}
return client.Do(req)
}
// RegisterRequest is the request body for runner registration.
type RegisterRequest struct {
Name string `json:"name"`
Hostname string `json:"hostname"`
Capabilities string `json:"capabilities"`
APIKey string `json:"api_key"`
Fingerprint string `json:"fingerprint,omitempty"`
}
// RegisterResponse is the response from runner registration.
type RegisterResponse struct {
ID int64 `json:"id"`
}
// Register registers the runner with the manager.
func (m *ManagerClient) Register(name, hostname string, capabilities map[string]interface{}, registrationToken, fingerprint string) (int64, error) {
capsJSON, err := json.Marshal(capabilities)
if err != nil {
return 0, fmt.Errorf("failed to marshal capabilities: %w", err)
}
reqBody := RegisterRequest{
Name: name,
Hostname: hostname,
Capabilities: string(capsJSON),
APIKey: registrationToken,
}
// Only send fingerprint for non-fixed API keys
if !strings.HasPrefix(registrationToken, "jk_r0_") {
reqBody.Fingerprint = fingerprint
}
body, _ := json.Marshal(reqBody)
resp, err := m.httpClient.Post(
m.baseURL+"/api/runner/register",
"application/json",
bytes.NewReader(body),
)
if err != nil {
return 0, fmt.Errorf("connection error: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
bodyBytes, _ := io.ReadAll(resp.Body)
errorBody := string(bodyBytes)
// Check for token-related errors (should not retry)
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusBadRequest {
errorLower := strings.ToLower(errorBody)
if strings.Contains(errorLower, "invalid") ||
strings.Contains(errorLower, "expired") ||
strings.Contains(errorLower, "already used") ||
strings.Contains(errorLower, "token") {
return 0, fmt.Errorf("token error: %s", errorBody)
}
}
return 0, fmt.Errorf("registration failed (status %d): %s", resp.StatusCode, errorBody)
}
var result RegisterResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return 0, fmt.Errorf("failed to decode response: %w", err)
}
m.runnerID = result.ID
m.apiKey = registrationToken
return result.ID, nil
}
// NextJobResponse represents the response from the next-job endpoint.
type NextJobResponse struct {
JobToken string `json:"job_token"`
JobPath string `json:"job_path"`
Task NextJobTaskInfo `json:"task"`
}
// NextJobTaskInfo contains task information from the next-job response.
type NextJobTaskInfo struct {
TaskID int64 `json:"task_id"`
JobID int64 `json:"job_id"`
JobName string `json:"job_name"`
Frame int `json:"frame"`
TaskType string `json:"task_type"`
Metadata *types.BlendMetadata `json:"metadata,omitempty"`
}
// PollNextJob polls the manager for the next available job.
// Returns nil, nil if no job is available.
func (m *ManagerClient) PollNextJob() (*NextJobResponse, error) {
if m.runnerID == 0 || m.apiKey == "" {
return nil, fmt.Errorf("runner not authenticated")
}
path := fmt.Sprintf("/api/runner/workers/%d/next-job", m.runnerID)
resp, err := m.Request("GET", path, nil)
if err != nil {
return nil, fmt.Errorf("failed to poll for job: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNoContent {
return nil, nil // No job available
}
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("unexpected status %d: %s", resp.StatusCode, string(body))
}
var job NextJobResponse
if err := json.NewDecoder(resp.Body).Decode(&job); err != nil {
return nil, fmt.Errorf("failed to decode job response: %w", err)
}
return &job, nil
}
// DownloadContext downloads the job context tar file.
func (m *ManagerClient) DownloadContext(contextPath, jobToken string) (io.ReadCloser, error) {
resp, err := m.RequestLongWithToken("GET", contextPath, jobToken, nil)
if err != nil {
return nil, fmt.Errorf("failed to download context: %w", err)
}
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
return nil, fmt.Errorf("context download failed with status %d: %s", resp.StatusCode, string(body))
}
return resp.Body, nil
}
// UploadFile uploads a file to the manager.
func (m *ManagerClient) UploadFile(uploadPath, jobToken, filePath string) error {
file, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("failed to open file: %w", err)
}
defer file.Close()
// Create multipart form
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
part, err := writer.CreateFormFile("file", filepath.Base(filePath))
if err != nil {
return fmt.Errorf("failed to create form file: %w", err)
}
if _, err := io.Copy(part, file); err != nil {
return fmt.Errorf("failed to copy file to form: %w", err)
}
writer.Close()
fullURL := m.baseURL + uploadPath
req, err := http.NewRequest("POST", fullURL, body)
if err != nil {
return err
}
req.Header.Set("Authorization", "Bearer "+jobToken)
req.Header.Set("Content-Type", writer.FormDataContentType())
resp, err := m.longClient.Do(req)
if err != nil {
return fmt.Errorf("failed to upload file: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
return fmt.Errorf("upload failed with status %d: %s", resp.StatusCode, string(respBody))
}
return nil
}
// GetJobMetadata retrieves job metadata from the manager.
func (m *ManagerClient) GetJobMetadata(jobID int64) (*types.BlendMetadata, error) {
path := fmt.Sprintf("/api/runner/jobs/%d/metadata?runner_id=%d", jobID, m.runnerID)
resp, err := m.Request("GET", path, nil)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return nil, nil // No metadata found
}
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("failed to get job metadata: %s", string(body))
}
var metadata types.BlendMetadata
if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
return nil, err
}
return &metadata, nil
}
// JobFile represents a file associated with a job.
type JobFile struct {
ID int64 `json:"id"`
JobID int64 `json:"job_id"`
FileType string `json:"file_type"`
FilePath string `json:"file_path"`
FileName string `json:"file_name"`
FileSize int64 `json:"file_size"`
}
// GetJobFiles retrieves the list of files for a job.
func (m *ManagerClient) GetJobFiles(jobID int64) ([]JobFile, error) {
path := fmt.Sprintf("/api/runner/jobs/%d/files?runner_id=%d", jobID, m.runnerID)
resp, err := m.Request("GET", path, nil)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("failed to get job files: %s", string(body))
}
var files []JobFile
if err := json.NewDecoder(resp.Body).Decode(&files); err != nil {
return nil, err
}
return files, nil
}
// DownloadFrame downloads a frame file from the manager.
func (m *ManagerClient) DownloadFrame(jobID int64, fileName, destPath string) error {
encodedFileName := url.PathEscape(fileName)
path := fmt.Sprintf("/api/runner/files/%d/%s?runner_id=%d", jobID, encodedFileName, m.runnerID)
resp, err := m.RequestLong("GET", path, nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("download failed: %s", string(body))
}
file, err := os.Create(destPath)
if err != nil {
return err
}
defer file.Close()
_, err = io.Copy(file, resp.Body)
return err
}
// SubmitMetadata submits extracted metadata to the manager.
func (m *ManagerClient) SubmitMetadata(jobID int64, metadata types.BlendMetadata) error {
metadataJSON, err := json.Marshal(metadata)
if err != nil {
return fmt.Errorf("failed to marshal metadata: %w", err)
}
path := fmt.Sprintf("/api/runner/jobs/%d/metadata?runner_id=%d", jobID, m.runnerID)
fullURL := m.baseURL + path
req, err := http.NewRequest("POST", fullURL, bytes.NewReader(metadataJSON))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+m.apiKey)
resp, err := m.httpClient.Do(req)
if err != nil {
return fmt.Errorf("failed to submit metadata: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("metadata submission failed: %s", string(body))
}
return nil
}
// DownloadBlender downloads a Blender version from the manager.
func (m *ManagerClient) DownloadBlender(version string) (io.ReadCloser, error) {
path := fmt.Sprintf("/api/runner/blender/download?version=%s&runner_id=%d", version, m.runnerID)
resp, err := m.RequestLong("GET", path, nil)
if err != nil {
return nil, fmt.Errorf("failed to download blender from manager: %w", err)
}
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
return nil, fmt.Errorf("failed to download blender: status %d, body: %s", resp.StatusCode, string(body))
}
return resp.Body, nil
}

View File

@@ -0,0 +1,87 @@
// Package blender handles Blender binary management and execution.
package blender
import (
"fmt"
"log"
"os"
"path/filepath"
"jiggablend/internal/runner/api"
"jiggablend/internal/runner/workspace"
)
// Manager handles Blender binary downloads and management.
type Manager struct {
manager *api.ManagerClient
workspaceDir string
}
// NewManager creates a new Blender manager.
func NewManager(managerClient *api.ManagerClient, workspaceDir string) *Manager {
return &Manager{
manager: managerClient,
workspaceDir: workspaceDir,
}
}
// GetBinaryPath returns the path to the Blender binary for a specific version.
// Downloads from manager and extracts if not already present.
func (m *Manager) GetBinaryPath(version string) (string, error) {
blenderDir := filepath.Join(m.workspaceDir, "blender-versions")
if err := os.MkdirAll(blenderDir, 0755); err != nil {
return "", fmt.Errorf("failed to create blender directory: %w", err)
}
// Check if already installed - look for version folder first
versionDir := filepath.Join(blenderDir, version)
binaryPath := filepath.Join(versionDir, "blender")
// Check if version folder exists and contains the binary
if versionInfo, err := os.Stat(versionDir); err == nil && versionInfo.IsDir() {
// Version folder exists, check if binary is present
if binaryInfo, err := os.Stat(binaryPath); err == nil {
// Verify it's actually a file (not a directory)
if !binaryInfo.IsDir() {
log.Printf("Found existing Blender %s installation at %s", version, binaryPath)
return binaryPath, nil
}
}
// Version folder exists but binary is missing - might be incomplete installation
log.Printf("Version folder %s exists but binary not found, will re-download", versionDir)
}
// Download from manager
log.Printf("Downloading Blender %s from manager", version)
reader, err := m.manager.DownloadBlender(version)
if err != nil {
return "", err
}
defer reader.Close()
// Manager serves pre-decompressed .tar files - extract directly
log.Printf("Extracting Blender %s...", version)
if err := workspace.ExtractTarStripPrefix(reader, versionDir); err != nil {
return "", fmt.Errorf("failed to extract blender: %w", err)
}
// Verify binary exists
if _, err := os.Stat(binaryPath); err != nil {
return "", fmt.Errorf("blender binary not found after extraction")
}
log.Printf("Blender %s installed at %s", version, binaryPath)
return binaryPath, nil
}
// GetBinaryForJob returns the Blender binary path for a job.
// Uses the version from metadata or falls back to system blender.
func (m *Manager) GetBinaryForJob(version string) (string, error) {
if version == "" {
return "blender", nil // System blender
}
return m.GetBinaryPath(version)
}

View File

@@ -0,0 +1,100 @@
package blender
import (
"regexp"
"strings"
"jiggablend/pkg/types"
)
// FilterLog checks if a Blender log line should be filtered or downgraded.
// Returns (shouldFilter, logLevel) - if shouldFilter is true, the log should be skipped.
func FilterLog(line string) (shouldFilter bool, logLevel types.LogLevel) {
trimmed := strings.TrimSpace(line)
// Filter out empty lines
if trimmed == "" {
return true, types.LogLevelInfo
}
// Filter out separator lines
if trimmed == "--------------------------------------------------------------------" ||
(strings.HasPrefix(trimmed, "-----") && strings.Contains(trimmed, "----")) {
return true, types.LogLevelInfo
}
// Filter out trace headers
upperLine := strings.ToUpper(trimmed)
upperOriginal := strings.ToUpper(line)
if trimmed == "Trace:" ||
trimmed == "Depth Type Name" ||
trimmed == "----- ---- ----" ||
line == "Depth Type Name" ||
line == "----- ---- ----" ||
(strings.Contains(upperLine, "DEPTH") && strings.Contains(upperLine, "TYPE") && strings.Contains(upperLine, "NAME")) ||
(strings.Contains(upperOriginal, "DEPTH") && strings.Contains(upperOriginal, "TYPE") && strings.Contains(upperOriginal, "NAME")) ||
strings.Contains(line, "Depth Type Name") ||
strings.Contains(line, "----- ---- ----") ||
strings.HasPrefix(trimmed, "-----") ||
regexp.MustCompile(`^[-]+\s+[-]+\s+[-]+$`).MatchString(trimmed) {
return true, types.LogLevelInfo
}
// Completely filter out dependency graph messages (they're just noise)
dependencyGraphPatterns := []string{
"Failed to add relation",
"Could not find op_from",
"OperationKey",
"find_node_operation: Failed for",
"BONE_DONE",
"component name:",
"operation code:",
"rope_ctrl_rot_",
}
for _, pattern := range dependencyGraphPatterns {
if strings.Contains(line, pattern) {
return true, types.LogLevelInfo
}
}
// Filter out animation system warnings (invalid drivers are common and harmless)
animationSystemPatterns := []string{
"BKE_animsys_eval_driver: invalid driver",
"bke.anim_sys",
"rotation_quaternion[",
"constraints[",
".influence[0]",
"pose.bones[",
}
for _, pattern := range animationSystemPatterns {
if strings.Contains(line, pattern) {
return true, types.LogLevelInfo
}
}
// Filter out modifier warnings (common when vertices change)
modifierPatterns := []string{
"BKE_modifier_set_error",
"bke.modifier",
"Vertices changed from",
"Modifier:",
}
for _, pattern := range modifierPatterns {
if strings.Contains(line, pattern) {
return true, types.LogLevelInfo
}
}
// Filter out lines that are just numbers or trace depth indicators
// Pattern: number, word, word (e.g., "1 Object timer_box_franck")
if matched, _ := regexp.MatchString(`^\d+\s+\w+\s+\w+`, trimmed); matched {
return true, types.LogLevelInfo
}
return false, types.LogLevelInfo
}

View File

@@ -0,0 +1,143 @@
package blender
import (
"compress/gzip"
"fmt"
"io"
"os"
"os/exec"
)
// ParseVersionFromFile parses the Blender version that a .blend file was saved with.
// Returns major and minor version numbers.
func ParseVersionFromFile(blendPath string) (major, minor int, err error) {
file, err := os.Open(blendPath)
if err != nil {
return 0, 0, fmt.Errorf("failed to open blend file: %w", err)
}
defer file.Close()
// Read the first 12 bytes of the blend file header
// Format: BLENDER-v<major><minor><patch> or BLENDER_v<major><minor><patch>
// The header is: "BLENDER" (7 bytes) + pointer size (1 byte: '-' for 64-bit, '_' for 32-bit)
// + endianness (1 byte: 'v' for little-endian, 'V' for big-endian)
// + version (3 bytes: e.g., "402" for 4.02)
header := make([]byte, 12)
n, err := file.Read(header)
if err != nil || n < 12 {
return 0, 0, fmt.Errorf("failed to read blend file header: %w", err)
}
// Check for BLENDER magic
if string(header[:7]) != "BLENDER" {
// Might be compressed - try to decompress
file.Seek(0, 0)
return parseCompressedVersion(file)
}
// Parse version from bytes 9-11 (3 digits)
versionStr := string(header[9:12])
// Version format changed in Blender 3.0
// Pre-3.0: "279" = 2.79, "280" = 2.80
// 3.0+: "300" = 3.0, "402" = 4.02, "410" = 4.10
if len(versionStr) == 3 {
// First digit is major version
fmt.Sscanf(string(versionStr[0]), "%d", &major)
// Next two digits are minor version
fmt.Sscanf(versionStr[1:3], "%d", &minor)
}
return major, minor, nil
}
// parseCompressedVersion handles gzip and zstd compressed blend files.
func parseCompressedVersion(file *os.File) (major, minor int, err error) {
magic := make([]byte, 4)
if _, err := file.Read(magic); err != nil {
return 0, 0, err
}
file.Seek(0, 0)
if magic[0] == 0x1f && magic[1] == 0x8b {
// gzip compressed
gzReader, err := gzip.NewReader(file)
if err != nil {
return 0, 0, fmt.Errorf("failed to create gzip reader: %w", err)
}
defer gzReader.Close()
header := make([]byte, 12)
n, err := gzReader.Read(header)
if err != nil || n < 12 {
return 0, 0, fmt.Errorf("failed to read compressed blend header: %w", err)
}
if string(header[:7]) != "BLENDER" {
return 0, 0, fmt.Errorf("invalid blend file format")
}
versionStr := string(header[9:12])
if len(versionStr) == 3 {
fmt.Sscanf(string(versionStr[0]), "%d", &major)
fmt.Sscanf(versionStr[1:3], "%d", &minor)
}
return major, minor, nil
}
// Check for zstd magic (Blender 3.0+): 0x28 0xB5 0x2F 0xFD
if magic[0] == 0x28 && magic[1] == 0xb5 && magic[2] == 0x2f && magic[3] == 0xfd {
return parseZstdVersion(file)
}
return 0, 0, fmt.Errorf("unknown blend file format")
}
// parseZstdVersion handles zstd-compressed blend files (Blender 3.0+).
// Uses zstd command line tool since Go doesn't have native zstd support.
func parseZstdVersion(file *os.File) (major, minor int, err error) {
file.Seek(0, 0)
cmd := exec.Command("zstd", "-d", "-c")
cmd.Stdin = file
stdout, err := cmd.StdoutPipe()
if err != nil {
return 0, 0, fmt.Errorf("failed to create zstd stdout pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return 0, 0, fmt.Errorf("failed to start zstd decompression: %w", err)
}
// Read just the header (12 bytes)
header := make([]byte, 12)
n, readErr := io.ReadFull(stdout, header)
// Kill the process early - we only need the header
cmd.Process.Kill()
cmd.Wait()
if readErr != nil || n < 12 {
return 0, 0, fmt.Errorf("failed to read zstd compressed blend header: %v", readErr)
}
if string(header[:7]) != "BLENDER" {
return 0, 0, fmt.Errorf("invalid blend file format in zstd archive")
}
versionStr := string(header[9:12])
if len(versionStr) == 3 {
fmt.Sscanf(string(versionStr[0]), "%d", &major)
fmt.Sscanf(versionStr[1:3], "%d", &minor)
}
return major, minor, nil
}
// VersionString returns a formatted version string like "4.2".
func VersionString(major, minor int) string {
return fmt.Sprintf("%d.%d", major, minor)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,71 @@
// Package encoding handles video encoding with software encoders.
package encoding
import (
"os/exec"
)
// Encoder represents a video encoder.
type Encoder interface {
Name() string
Codec() string
Available() bool
BuildCommand(config *EncodeConfig) *exec.Cmd
}
// EncodeConfig holds configuration for video encoding.
type EncodeConfig struct {
InputPattern string // Input file pattern (e.g., "frame_%04d.exr")
OutputPath string // Output file path
StartFrame int // Starting frame number
FrameRate float64 // Frame rate
WorkDir string // Working directory
UseAlpha bool // Whether to preserve alpha channel
TwoPass bool // Whether to use 2-pass encoding
SourceFormat string // Source format: "exr" or "png" (defaults to "exr")
PreserveHDR bool // Whether to preserve HDR range for EXR (uses HLG with bt709 primaries)
}
// Selector selects the software encoder.
type Selector struct {
h264Encoders []Encoder
av1Encoders []Encoder
vp9Encoders []Encoder
}
// NewSelector creates a new encoder selector with software encoders.
func NewSelector() *Selector {
s := &Selector{}
s.detectEncoders()
return s
}
func (s *Selector) detectEncoders() {
// Use software encoding only - reliable and avoids hardware-specific colorspace issues
s.h264Encoders = []Encoder{
&SoftwareEncoder{codec: "libx264"},
}
s.av1Encoders = []Encoder{
&SoftwareEncoder{codec: "libaom-av1"},
}
s.vp9Encoders = []Encoder{
&SoftwareEncoder{codec: "libvpx-vp9"},
}
}
// SelectH264 returns the software H.264 encoder.
func (s *Selector) SelectH264() Encoder {
return &SoftwareEncoder{codec: "libx264"}
}
// SelectAV1 returns the software AV1 encoder.
func (s *Selector) SelectAV1() Encoder {
return &SoftwareEncoder{codec: "libaom-av1"}
}
// SelectVP9 returns the software VP9 encoder.
func (s *Selector) SelectVP9() Encoder {
return &SoftwareEncoder{codec: "libvpx-vp9"}
}

View File

@@ -0,0 +1,270 @@
package encoding
import (
"fmt"
"log"
"os/exec"
"strconv"
"strings"
)
const (
// CRFH264 is the Constant Rate Factor for H.264 encoding (lower = higher quality, range 0-51)
CRFH264 = 15
// CRFAV1 is the Constant Rate Factor for AV1 encoding (lower = higher quality, range 0-63)
CRFAV1 = 30
// CRFVP9 is the Constant Rate Factor for VP9 encoding (lower = higher quality, range 0-63)
CRFVP9 = 30
)
// tonemapFilter returns the appropriate filter for EXR input.
// For HDR preservation: converts linear RGB (EXR) to bt2020 YUV with HLG transfer function
// Uses zscale to properly convert colorspace from linear RGB to bt2020 YUV while preserving HDR range
// Step 1: Ensure format is gbrpf32le (linear RGB)
// Step 2: Convert transfer function from linear to HLG (arib-std-b67) with bt2020 primaries/matrix
// Step 3: Convert to YUV format
func tonemapFilter(useAlpha bool) string {
// Convert from linear RGB (gbrpf32le) to HLG with bt709 primaries to match PNG appearance
// Based on best practices: convert linear RGB directly to HLG with bt709 primaries
// This matches PNG color appearance (bt709 primaries) while preserving HDR range (HLG transfer)
// zscale uses numeric values:
// primaries: 1=bt709 (matches PNG), 9=bt2020
// matrix: 1=bt709, 9=bt2020nc, 0=gbr (RGB input)
// transfer: 8=linear, 18=arib-std-b67 (HLG)
// Direct conversion: linear RGB -> HLG with bt709 primaries -> bt2020 YUV (for wider gamut metadata)
// The bt709 primaries in the conversion match PNG, but we set bt2020 in metadata for HDR displays
// Convert linear RGB to sRGB first, then convert to HLG
// This approach: linear -> sRGB -> HLG -> bt2020
// Fixes red tint by using sRGB conversion, preserves HDR range with HLG
filter := "format=gbrpf32le,zscale=transferin=8:transfer=13:primariesin=1:primaries=1:matrixin=0:matrix=1:rangein=full:range=full,zscale=transferin=13:transfer=18:primariesin=1:primaries=9:matrixin=1:matrix=9:rangein=full:range=full"
if useAlpha {
return filter + ",format=yuva420p10le"
}
return filter + ",format=yuv420p10le"
}
// SoftwareEncoder implements software encoding (libx264, libaom-av1, libvpx-vp9).
type SoftwareEncoder struct {
codec string
}
func (e *SoftwareEncoder) Name() string { return "software" }
func (e *SoftwareEncoder) Codec() string { return e.codec }
func (e *SoftwareEncoder) Available() bool {
return true // Software encoding is always available
}
func (e *SoftwareEncoder) BuildCommand(config *EncodeConfig) *exec.Cmd {
// Use HDR pixel formats for EXR, SDR for PNG
var pixFmt string
var colorPrimaries, colorTrc, colorspace string
if config.SourceFormat == "png" {
// PNG: SDR format
pixFmt = "yuv420p"
if config.UseAlpha {
pixFmt = "yuva420p"
}
colorPrimaries = "bt709"
colorTrc = "bt709"
colorspace = "bt709"
} else {
// EXR: Use HDR encoding if PreserveHDR is true, otherwise SDR (like PNG)
if config.PreserveHDR {
// HDR: Use HLG transfer with bt709 primaries to preserve HDR range while matching PNG color
pixFmt = "yuv420p10le" // 10-bit to preserve HDR range
if config.UseAlpha {
pixFmt = "yuva420p10le"
}
colorPrimaries = "bt709" // bt709 primaries to match PNG color appearance
colorTrc = "arib-std-b67" // HLG transfer function - preserves HDR range, works on SDR displays
colorspace = "bt709" // bt709 colorspace to match PNG
} else {
// SDR: Treat as SDR (like PNG) - encode as bt709
pixFmt = "yuv420p"
if config.UseAlpha {
pixFmt = "yuva420p"
}
colorPrimaries = "bt709"
colorTrc = "bt709"
colorspace = "bt709"
}
}
var codecArgs []string
switch e.codec {
case "libaom-av1":
codecArgs = []string{"-crf", strconv.Itoa(CRFAV1), "-b:v", "0", "-tiles", "2x2", "-g", "240"}
case "libvpx-vp9":
// VP9 supports alpha and HDR, use good quality settings
codecArgs = []string{"-crf", strconv.Itoa(CRFVP9), "-b:v", "0", "-row-mt", "1", "-g", "240"}
default:
// H.264: Use High 10 profile for HDR EXR (10-bit), High profile for SDR
if config.SourceFormat != "png" && config.PreserveHDR {
codecArgs = []string{"-preset", "veryslow", "-crf", strconv.Itoa(CRFH264), "-profile:v", "high10", "-level", "5.2", "-tune", "film", "-keyint_min", "24", "-g", "240", "-bf", "2", "-refs", "4"}
} else {
codecArgs = []string{"-preset", "veryslow", "-crf", strconv.Itoa(CRFH264), "-profile:v", "high", "-level", "5.2", "-tune", "film", "-keyint_min", "24", "-g", "240", "-bf", "2", "-refs", "4"}
}
}
args := []string{
"-y",
"-f", "image2",
"-start_number", fmt.Sprintf("%d", config.StartFrame),
"-framerate", fmt.Sprintf("%.2f", config.FrameRate),
"-i", config.InputPattern,
"-c:v", e.codec,
"-pix_fmt", pixFmt,
"-r", fmt.Sprintf("%.2f", config.FrameRate),
"-color_primaries", colorPrimaries,
"-color_trc", colorTrc,
"-colorspace", colorspace,
"-color_range", "tv",
}
// Add video filter for EXR: convert linear RGB based on HDR setting
// PNG doesn't need any filter as it's already in sRGB
if config.SourceFormat != "png" {
var vf string
if config.PreserveHDR {
// HDR: Convert linear RGB -> sRGB -> HLG with bt709 primaries
// This preserves HDR range while matching PNG color appearance
vf = "format=gbrpf32le,zscale=transferin=8:transfer=13:primariesin=1:primaries=1:matrixin=0:matrix=1:rangein=full:range=full,zscale=transferin=13:transfer=18:primariesin=1:primaries=1:matrixin=1:matrix=1:rangein=full:range=full"
if config.UseAlpha {
vf += ",format=yuva420p10le"
} else {
vf += ",format=yuv420p10le"
}
} else {
// SDR: Convert linear RGB (EXR) to sRGB (bt709) - simple conversion like Krita does
// zscale: linear (8) -> sRGB (13) with bt709 primaries/matrix
vf = "format=gbrpf32le,zscale=transferin=8:transfer=13:primariesin=1:primaries=1:matrixin=0:matrix=1:rangein=full:range=full"
if config.UseAlpha {
vf += ",format=yuva420p"
} else {
vf += ",format=yuv420p"
}
}
args = append(args, "-vf", vf)
}
args = append(args, codecArgs...)
if config.TwoPass {
// For 2-pass, this builds pass 2 command
args = append(args, "-pass", "2")
}
args = append(args, config.OutputPath)
if config.TwoPass {
log.Printf("Build Software Pass 2 command: ffmpeg %s", strings.Join(args, " "))
} else {
log.Printf("Build Software command: ffmpeg %s", strings.Join(args, " "))
}
cmd := exec.Command("ffmpeg", args...)
cmd.Dir = config.WorkDir
return cmd
}
// BuildPass1Command builds the first pass command for 2-pass encoding.
func (e *SoftwareEncoder) BuildPass1Command(config *EncodeConfig) *exec.Cmd {
// Use HDR pixel formats for EXR, SDR for PNG
var pixFmt string
var colorPrimaries, colorTrc, colorspace string
if config.SourceFormat == "png" {
// PNG: SDR format
pixFmt = "yuv420p"
if config.UseAlpha {
pixFmt = "yuva420p"
}
colorPrimaries = "bt709"
colorTrc = "bt709"
colorspace = "bt709"
} else {
// EXR: Use HDR encoding if PreserveHDR is true, otherwise SDR (like PNG)
if config.PreserveHDR {
// HDR: Use HLG transfer with bt709 primaries to preserve HDR range while matching PNG color
pixFmt = "yuv420p10le" // 10-bit to preserve HDR range
if config.UseAlpha {
pixFmt = "yuva420p10le"
}
colorPrimaries = "bt709" // bt709 primaries to match PNG color appearance
colorTrc = "arib-std-b67" // HLG transfer function - preserves HDR range, works on SDR displays
colorspace = "bt709" // bt709 colorspace to match PNG
} else {
// SDR: Treat as SDR (like PNG) - encode as bt709
pixFmt = "yuv420p"
if config.UseAlpha {
pixFmt = "yuva420p"
}
colorPrimaries = "bt709"
colorTrc = "bt709"
colorspace = "bt709"
}
}
var codecArgs []string
switch e.codec {
case "libaom-av1":
codecArgs = []string{"-crf", strconv.Itoa(CRFAV1), "-b:v", "0", "-tiles", "2x2", "-g", "240"}
case "libvpx-vp9":
// VP9 supports alpha and HDR, use good quality settings
codecArgs = []string{"-crf", strconv.Itoa(CRFVP9), "-b:v", "0", "-row-mt", "1", "-g", "240"}
default:
// H.264: Use High 10 profile for HDR EXR (10-bit), High profile for SDR
if config.SourceFormat != "png" && config.PreserveHDR {
codecArgs = []string{"-preset", "veryslow", "-crf", strconv.Itoa(CRFH264), "-profile:v", "high10", "-level", "5.2", "-tune", "film", "-keyint_min", "24", "-g", "240", "-bf", "2", "-refs", "4"}
} else {
codecArgs = []string{"-preset", "veryslow", "-crf", strconv.Itoa(CRFH264), "-profile:v", "high", "-level", "5.2", "-tune", "film", "-keyint_min", "24", "-g", "240", "-bf", "2", "-refs", "4"}
}
}
args := []string{
"-y",
"-f", "image2",
"-start_number", fmt.Sprintf("%d", config.StartFrame),
"-framerate", fmt.Sprintf("%.2f", config.FrameRate),
"-i", config.InputPattern,
"-c:v", e.codec,
"-pix_fmt", pixFmt,
"-r", fmt.Sprintf("%.2f", config.FrameRate),
"-color_primaries", colorPrimaries,
"-color_trc", colorTrc,
"-colorspace", colorspace,
"-color_range", "tv",
}
// Add video filter for EXR: convert linear RGB based on HDR setting
// PNG doesn't need any filter as it's already in sRGB
if config.SourceFormat != "png" {
var vf string
if config.PreserveHDR {
// HDR: Convert linear RGB -> sRGB -> HLG with bt709 primaries
// This preserves HDR range while matching PNG color appearance
vf = "format=gbrpf32le,zscale=transferin=8:transfer=13:primariesin=1:primaries=1:matrixin=0:matrix=1:rangein=full:range=full,zscale=transferin=13:transfer=18:primariesin=1:primaries=1:matrixin=1:matrix=1:rangein=full:range=full"
if config.UseAlpha {
vf += ",format=yuva420p10le"
} else {
vf += ",format=yuv420p10le"
}
} else {
// SDR: Convert linear RGB (EXR) to sRGB (bt709) - simple conversion like Krita does
// zscale: linear (8) -> sRGB (13) with bt709 primaries/matrix
vf = "format=gbrpf32le,zscale=transferin=8:transfer=13:primariesin=1:primaries=1:matrixin=0:matrix=1:rangein=full:range=full"
if config.UseAlpha {
vf += ",format=yuva420p"
} else {
vf += ",format=yuv420p"
}
}
args = append(args, "-vf", vf)
}
args = append(args, codecArgs...)
args = append(args, "-pass", "1", "-f", "null", "/dev/null")
log.Printf("Build Software Pass 1 command: ffmpeg %s", strings.Join(args, " "))
cmd := exec.Command("ffmpeg", args...)
cmd.Dir = config.WorkDir
return cmd
}

View File

@@ -0,0 +1,980 @@
package encoding
import (
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
)
func TestSoftwareEncoder_BuildCommand_H264_EXR(t *testing.T) {
encoder := &SoftwareEncoder{codec: "libx264"}
config := &EncodeConfig{
InputPattern: "frame_%04d.exr",
OutputPath: "output.mp4",
StartFrame: 1,
FrameRate: 24.0,
WorkDir: "/tmp",
UseAlpha: false,
TwoPass: true,
SourceFormat: "exr",
}
cmd := encoder.BuildCommand(config)
if cmd == nil {
t.Fatal("BuildCommand returned nil")
}
if !strings.Contains(cmd.Path, "ffmpeg") {
t.Errorf("Expected command path to contain 'ffmpeg', got '%s'", cmd.Path)
}
if cmd.Dir != "/tmp" {
t.Errorf("Expected work dir '/tmp', got '%s'", cmd.Dir)
}
args := cmd.Args[1:] // Skip "ffmpeg"
argsStr := strings.Join(args, " ")
// Check required arguments
checks := []struct {
name string
expected string
}{
{"-y flag", "-y"},
{"image2 format", "-f image2"},
{"start number", "-start_number 1"},
{"framerate", "-framerate 24.00"},
{"input pattern", "-i frame_%04d.exr"},
{"codec", "-c:v libx264"},
{"pixel format", "-pix_fmt yuv420p"}, // EXR now treated as SDR (like PNG)
{"frame rate", "-r 24.00"},
{"color primaries", "-color_primaries bt709"}, // EXR now uses bt709 (SDR)
{"color trc", "-color_trc bt709"}, // EXR now uses bt709 (SDR)
{"colorspace", "-colorspace bt709"},
{"color range", "-color_range tv"},
{"video filter", "-vf"},
{"preset", "-preset veryslow"},
{"crf", "-crf 15"},
{"profile", "-profile:v high"}, // EXR now uses high profile (SDR)
{"pass 2", "-pass 2"},
{"output path", "output.mp4"},
}
for _, check := range checks {
if !strings.Contains(argsStr, check.expected) {
t.Errorf("Missing expected argument: %s", check.expected)
}
}
// Verify filter is present for EXR (linear RGB to sRGB conversion, like Krita does)
if !strings.Contains(argsStr, "format=gbrpf32le") {
t.Error("Expected format conversion filter for EXR source, but not found")
}
if !strings.Contains(argsStr, "zscale=transferin=8:transfer=13") {
t.Error("Expected linear to sRGB conversion for EXR source, but not found")
}
}
func TestSoftwareEncoder_BuildCommand_H264_PNG(t *testing.T) {
encoder := &SoftwareEncoder{codec: "libx264"}
config := &EncodeConfig{
InputPattern: "frame_%04d.png",
OutputPath: "output.mp4",
StartFrame: 1,
FrameRate: 24.0,
WorkDir: "/tmp",
UseAlpha: false,
TwoPass: true,
SourceFormat: "png",
}
cmd := encoder.BuildCommand(config)
args := cmd.Args[1:]
argsStr := strings.Join(args, " ")
// PNG should NOT have video filter
if strings.Contains(argsStr, "-vf") {
t.Error("PNG source should not have video filter, but -vf was found")
}
// Should still have all other required args
if !strings.Contains(argsStr, "-c:v libx264") {
t.Error("Missing codec argument")
}
}
func TestSoftwareEncoder_BuildCommand_AV1_WithAlpha(t *testing.T) {
encoder := &SoftwareEncoder{codec: "libaom-av1"}
config := &EncodeConfig{
InputPattern: "frame_%04d.exr",
OutputPath: "output.mp4",
StartFrame: 100,
FrameRate: 30.0,
WorkDir: "/tmp",
UseAlpha: true,
TwoPass: true,
SourceFormat: "exr",
}
cmd := encoder.BuildCommand(config)
args := cmd.Args[1:]
argsStr := strings.Join(args, " ")
// Check alpha-specific settings
if !strings.Contains(argsStr, "-pix_fmt yuva420p") {
t.Error("Expected yuva420p pixel format for alpha, but not found")
}
// Check AV1-specific arguments
av1Checks := []string{
"-c:v libaom-av1",
"-crf 30",
"-b:v 0",
"-tiles 2x2",
"-g 240",
}
for _, check := range av1Checks {
if !strings.Contains(argsStr, check) {
t.Errorf("Missing AV1 argument: %s", check)
}
}
// Check tonemap filter includes alpha format
if !strings.Contains(argsStr, "format=yuva420p") {
t.Error("Expected tonemap filter to output yuva420p for alpha, but not found")
}
}
func TestSoftwareEncoder_BuildCommand_VP9(t *testing.T) {
encoder := &SoftwareEncoder{codec: "libvpx-vp9"}
config := &EncodeConfig{
InputPattern: "frame_%04d.exr",
OutputPath: "output.webm",
StartFrame: 1,
FrameRate: 24.0,
WorkDir: "/tmp",
UseAlpha: true,
TwoPass: true,
SourceFormat: "exr",
}
cmd := encoder.BuildCommand(config)
args := cmd.Args[1:]
argsStr := strings.Join(args, " ")
// Check VP9-specific arguments
vp9Checks := []string{
"-c:v libvpx-vp9",
"-crf 30",
"-b:v 0",
"-row-mt 1",
"-g 240",
}
for _, check := range vp9Checks {
if !strings.Contains(argsStr, check) {
t.Errorf("Missing VP9 argument: %s", check)
}
}
}
func TestSoftwareEncoder_BuildPass1Command(t *testing.T) {
encoder := &SoftwareEncoder{codec: "libx264"}
config := &EncodeConfig{
InputPattern: "frame_%04d.exr",
OutputPath: "output.mp4",
StartFrame: 1,
FrameRate: 24.0,
WorkDir: "/tmp",
UseAlpha: false,
TwoPass: true,
SourceFormat: "exr",
}
cmd := encoder.BuildPass1Command(config)
args := cmd.Args[1:]
argsStr := strings.Join(args, " ")
// Pass 1 should have -pass 1 and output to null
if !strings.Contains(argsStr, "-pass 1") {
t.Error("Pass 1 command should include '-pass 1'")
}
if !strings.Contains(argsStr, "-f null") {
t.Error("Pass 1 command should include '-f null'")
}
if !strings.Contains(argsStr, "/dev/null") {
t.Error("Pass 1 command should output to /dev/null")
}
// Should NOT have output path
if strings.Contains(argsStr, "output.mp4") {
t.Error("Pass 1 command should not include output path")
}
}
func TestSoftwareEncoder_BuildPass1Command_AV1(t *testing.T) {
encoder := &SoftwareEncoder{codec: "libaom-av1"}
config := &EncodeConfig{
InputPattern: "frame_%04d.exr",
OutputPath: "output.mp4",
StartFrame: 1,
FrameRate: 24.0,
WorkDir: "/tmp",
UseAlpha: false,
TwoPass: true,
SourceFormat: "exr",
}
cmd := encoder.BuildPass1Command(config)
args := cmd.Args[1:]
argsStr := strings.Join(args, " ")
// Pass 1 should have -pass 1 and output to null
if !strings.Contains(argsStr, "-pass 1") {
t.Error("Pass 1 command should include '-pass 1'")
}
if !strings.Contains(argsStr, "-f null") {
t.Error("Pass 1 command should include '-f null'")
}
if !strings.Contains(argsStr, "/dev/null") {
t.Error("Pass 1 command should output to /dev/null")
}
// Check AV1-specific arguments in pass 1
av1Checks := []string{
"-c:v libaom-av1",
"-crf 30",
"-b:v 0",
"-tiles 2x2",
"-g 240",
}
for _, check := range av1Checks {
if !strings.Contains(argsStr, check) {
t.Errorf("Missing AV1 argument in pass 1: %s", check)
}
}
}
func TestSoftwareEncoder_BuildPass1Command_VP9(t *testing.T) {
encoder := &SoftwareEncoder{codec: "libvpx-vp9"}
config := &EncodeConfig{
InputPattern: "frame_%04d.exr",
OutputPath: "output.webm",
StartFrame: 1,
FrameRate: 24.0,
WorkDir: "/tmp",
UseAlpha: false,
TwoPass: true,
SourceFormat: "exr",
}
cmd := encoder.BuildPass1Command(config)
args := cmd.Args[1:]
argsStr := strings.Join(args, " ")
// Pass 1 should have -pass 1 and output to null
if !strings.Contains(argsStr, "-pass 1") {
t.Error("Pass 1 command should include '-pass 1'")
}
if !strings.Contains(argsStr, "-f null") {
t.Error("Pass 1 command should include '-f null'")
}
if !strings.Contains(argsStr, "/dev/null") {
t.Error("Pass 1 command should output to /dev/null")
}
// Check VP9-specific arguments in pass 1
vp9Checks := []string{
"-c:v libvpx-vp9",
"-crf 30",
"-b:v 0",
"-row-mt 1",
"-g 240",
}
for _, check := range vp9Checks {
if !strings.Contains(argsStr, check) {
t.Errorf("Missing VP9 argument in pass 1: %s", check)
}
}
}
func TestSoftwareEncoder_BuildCommand_NoTwoPass(t *testing.T) {
encoder := &SoftwareEncoder{codec: "libx264"}
config := &EncodeConfig{
InputPattern: "frame_%04d.exr",
OutputPath: "output.mp4",
StartFrame: 1,
FrameRate: 24.0,
WorkDir: "/tmp",
UseAlpha: false,
TwoPass: false,
SourceFormat: "exr",
}
cmd := encoder.BuildCommand(config)
args := cmd.Args[1:]
argsStr := strings.Join(args, " ")
// Should NOT have -pass flag when TwoPass is false
if strings.Contains(argsStr, "-pass") {
t.Error("Command should not include -pass flag when TwoPass is false")
}
}
func TestSelector_SelectH264(t *testing.T) {
selector := NewSelector()
encoder := selector.SelectH264()
if encoder == nil {
t.Fatal("SelectH264 returned nil")
}
if encoder.Codec() != "libx264" {
t.Errorf("Expected codec 'libx264', got '%s'", encoder.Codec())
}
if encoder.Name() != "software" {
t.Errorf("Expected name 'software', got '%s'", encoder.Name())
}
}
func TestSelector_SelectAV1(t *testing.T) {
selector := NewSelector()
encoder := selector.SelectAV1()
if encoder == nil {
t.Fatal("SelectAV1 returned nil")
}
if encoder.Codec() != "libaom-av1" {
t.Errorf("Expected codec 'libaom-av1', got '%s'", encoder.Codec())
}
}
func TestSelector_SelectVP9(t *testing.T) {
selector := NewSelector()
encoder := selector.SelectVP9()
if encoder == nil {
t.Fatal("SelectVP9 returned nil")
}
if encoder.Codec() != "libvpx-vp9" {
t.Errorf("Expected codec 'libvpx-vp9', got '%s'", encoder.Codec())
}
}
func TestTonemapFilter_WithAlpha(t *testing.T) {
filter := tonemapFilter(true)
// Filter should convert from gbrpf32le to yuva420p10le with proper colorspace conversion
if !strings.Contains(filter, "yuva420p10le") {
t.Error("Tonemap filter with alpha should output yuva420p10le format for HDR")
}
if !strings.Contains(filter, "gbrpf32le") {
t.Error("Tonemap filter should start with gbrpf32le format")
}
// Should use zscale for colorspace conversion from linear RGB to bt2020 YUV
if !strings.Contains(filter, "zscale") {
t.Error("Tonemap filter should use zscale for colorspace conversion")
}
// Check for HLG transfer function (numeric value 18 or string arib-std-b67)
if !strings.Contains(filter, "transfer=18") && !strings.Contains(filter, "transfer=arib-std-b67") {
t.Error("Tonemap filter should use HLG transfer function (18 or arib-std-b67)")
}
}
func TestTonemapFilter_WithoutAlpha(t *testing.T) {
filter := tonemapFilter(false)
// Filter should convert from gbrpf32le to yuv420p10le with proper colorspace conversion
if !strings.Contains(filter, "yuv420p10le") {
t.Error("Tonemap filter without alpha should output yuv420p10le format for HDR")
}
if strings.Contains(filter, "yuva420p") {
t.Error("Tonemap filter without alpha should not output yuva420p format")
}
if !strings.Contains(filter, "gbrpf32le") {
t.Error("Tonemap filter should start with gbrpf32le format")
}
// Should use zscale for colorspace conversion from linear RGB to bt2020 YUV
if !strings.Contains(filter, "zscale") {
t.Error("Tonemap filter should use zscale for colorspace conversion")
}
// Check for HLG transfer function (numeric value 18 or string arib-std-b67)
if !strings.Contains(filter, "transfer=18") && !strings.Contains(filter, "transfer=arib-std-b67") {
t.Error("Tonemap filter should use HLG transfer function (18 or arib-std-b67)")
}
}
func TestSoftwareEncoder_Available(t *testing.T) {
encoder := &SoftwareEncoder{codec: "libx264"}
if !encoder.Available() {
t.Error("Software encoder should always be available")
}
}
func TestEncodeConfig_DefaultSourceFormat(t *testing.T) {
config := &EncodeConfig{
InputPattern: "frame_%04d.exr",
OutputPath: "output.mp4",
StartFrame: 1,
FrameRate: 24.0,
WorkDir: "/tmp",
UseAlpha: false,
TwoPass: false,
// SourceFormat not set, should default to empty string (treated as exr)
}
encoder := &SoftwareEncoder{codec: "libx264"}
cmd := encoder.BuildCommand(config)
args := strings.Join(cmd.Args[1:], " ")
// Should still have tonemap filter when SourceFormat is empty (defaults to exr behavior)
if !strings.Contains(args, "-vf") {
t.Error("Empty SourceFormat should default to EXR behavior with tonemap filter")
}
}
func TestCommandOrder(t *testing.T) {
encoder := &SoftwareEncoder{codec: "libx264"}
config := &EncodeConfig{
InputPattern: "frame_%04d.exr",
OutputPath: "output.mp4",
StartFrame: 1,
FrameRate: 24.0,
WorkDir: "/tmp",
UseAlpha: false,
TwoPass: true,
SourceFormat: "exr",
}
cmd := encoder.BuildCommand(config)
args := cmd.Args[1:]
// Verify argument order: input should come before codec
inputIdx := -1
codecIdx := -1
vfIdx := -1
for i, arg := range args {
if arg == "-i" && i+1 < len(args) && args[i+1] == "frame_%04d.exr" {
inputIdx = i
}
if arg == "-c:v" && i+1 < len(args) && args[i+1] == "libx264" {
codecIdx = i
}
if arg == "-vf" {
vfIdx = i
}
}
if inputIdx == -1 {
t.Fatal("Input pattern not found in command")
}
if codecIdx == -1 {
t.Fatal("Codec not found in command")
}
if vfIdx == -1 {
t.Fatal("Video filter not found in command")
}
// Input should come before codec
if inputIdx >= codecIdx {
t.Error("Input pattern should come before codec in command")
}
// Video filter should come after input (order: input -> codec -> colorspace -> filter -> codec args)
// In practice, the filter comes after codec and colorspace metadata but before codec-specific args
if vfIdx <= inputIdx {
t.Error("Video filter should come after input")
}
}
func TestCommand_ColorspaceMetadata(t *testing.T) {
encoder := &SoftwareEncoder{codec: "libx264"}
config := &EncodeConfig{
InputPattern: "frame_%04d.exr",
OutputPath: "output.mp4",
StartFrame: 1,
FrameRate: 24.0,
WorkDir: "/tmp",
UseAlpha: false,
TwoPass: false,
SourceFormat: "exr",
PreserveHDR: false, // SDR encoding
}
cmd := encoder.BuildCommand(config)
args := cmd.Args[1:]
argsStr := strings.Join(args, " ")
// Verify all SDR colorspace metadata is present for EXR (SDR encoding)
colorspaceArgs := []string{
"-color_primaries bt709", // EXR uses bt709 (SDR)
"-color_trc bt709", // EXR uses bt709 (SDR)
"-colorspace bt709",
"-color_range tv",
}
for _, arg := range colorspaceArgs {
if !strings.Contains(argsStr, arg) {
t.Errorf("Missing colorspace metadata: %s", arg)
}
}
// Verify SDR pixel format
if !strings.Contains(argsStr, "-pix_fmt yuv420p") {
t.Error("SDR encoding should use yuv420p pixel format")
}
// Verify H.264 high profile (not high10)
if !strings.Contains(argsStr, "-profile:v high") {
t.Error("SDR encoding should use high profile")
}
if strings.Contains(argsStr, "-profile:v high10") {
t.Error("SDR encoding should not use high10 profile")
}
}
func TestCommand_HDR_ColorspaceMetadata(t *testing.T) {
encoder := &SoftwareEncoder{codec: "libx264"}
config := &EncodeConfig{
InputPattern: "frame_%04d.exr",
OutputPath: "output.mp4",
StartFrame: 1,
FrameRate: 24.0,
WorkDir: "/tmp",
UseAlpha: false,
TwoPass: false,
SourceFormat: "exr",
PreserveHDR: true, // HDR encoding
}
cmd := encoder.BuildCommand(config)
args := cmd.Args[1:]
argsStr := strings.Join(args, " ")
// Verify all HDR colorspace metadata is present for EXR (HDR encoding)
colorspaceArgs := []string{
"-color_primaries bt709", // bt709 primaries to match PNG color appearance
"-color_trc arib-std-b67", // HLG transfer function for HDR/SDR compatibility
"-colorspace bt709", // bt709 colorspace to match PNG
"-color_range tv",
}
for _, arg := range colorspaceArgs {
if !strings.Contains(argsStr, arg) {
t.Errorf("Missing HDR colorspace metadata: %s", arg)
}
}
// Verify HDR pixel format (10-bit)
if !strings.Contains(argsStr, "-pix_fmt yuv420p10le") {
t.Error("HDR encoding should use yuv420p10le pixel format")
}
// Verify H.264 high10 profile (for 10-bit)
if !strings.Contains(argsStr, "-profile:v high10") {
t.Error("HDR encoding should use high10 profile")
}
// Verify HDR filter chain (linear -> sRGB -> HLG)
if !strings.Contains(argsStr, "-vf") {
t.Fatal("HDR encoding should have video filter")
}
vfIdx := -1
for i, arg := range args {
if arg == "-vf" && i+1 < len(args) {
vfIdx = i + 1
break
}
}
if vfIdx == -1 {
t.Fatal("Video filter not found")
}
filter := args[vfIdx]
if !strings.Contains(filter, "transfer=18") {
t.Error("HDR filter should convert to HLG (transfer=18)")
}
if !strings.Contains(filter, "yuv420p10le") {
t.Error("HDR filter should output yuv420p10le format")
}
}
// Integration tests using example files
func TestIntegration_Encode_EXR_H264(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
// Check if example file exists
exampleDir := filepath.Join("..", "..", "..", "examples")
exrFile := filepath.Join(exampleDir, "frame_0800.exr")
if _, err := os.Stat(exrFile); os.IsNotExist(err) {
t.Skipf("Example file not found: %s", exrFile)
}
// Get absolute paths
workspaceRoot, err := filepath.Abs(filepath.Join("..", "..", ".."))
if err != nil {
t.Fatalf("Failed to get workspace root: %v", err)
}
exampleDirAbs, err := filepath.Abs(exampleDir)
if err != nil {
t.Fatalf("Failed to get example directory: %v", err)
}
tmpDir := filepath.Join(workspaceRoot, "tmp")
if err := os.MkdirAll(tmpDir, 0755); err != nil {
t.Fatalf("Failed to create tmp directory: %v", err)
}
encoder := &SoftwareEncoder{codec: "libx264"}
config := &EncodeConfig{
InputPattern: filepath.Join(exampleDirAbs, "frame_%04d.exr"),
OutputPath: filepath.Join(tmpDir, "test_exr_h264.mp4"),
StartFrame: 800,
FrameRate: 24.0,
WorkDir: tmpDir,
UseAlpha: false,
TwoPass: false, // Use single pass for faster testing
SourceFormat: "exr",
}
// Build and run command
cmd := encoder.BuildCommand(config)
if cmd == nil {
t.Fatal("BuildCommand returned nil")
}
// Capture stderr to see what went wrong
output, err := cmd.CombinedOutput()
if err != nil {
t.Errorf("FFmpeg command failed: %v\nCommand output: %s", err, string(output))
return
}
// Verify output file was created
if _, err := os.Stat(config.OutputPath); os.IsNotExist(err) {
t.Errorf("Output file was not created: %s\nCommand output: %s", config.OutputPath, string(output))
} else {
t.Logf("Successfully created output file: %s", config.OutputPath)
// Verify file has content
info, _ := os.Stat(config.OutputPath)
if info.Size() == 0 {
t.Errorf("Output file was created but is empty\nCommand output: %s", string(output))
} else {
t.Logf("Output file size: %d bytes", info.Size())
}
}
}
func TestIntegration_Encode_PNG_H264(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
// Check if example file exists
exampleDir := filepath.Join("..", "..", "..", "examples")
pngFile := filepath.Join(exampleDir, "frame_0800.png")
if _, err := os.Stat(pngFile); os.IsNotExist(err) {
t.Skipf("Example file not found: %s", pngFile)
}
// Get absolute paths
workspaceRoot, err := filepath.Abs(filepath.Join("..", "..", ".."))
if err != nil {
t.Fatalf("Failed to get workspace root: %v", err)
}
exampleDirAbs, err := filepath.Abs(exampleDir)
if err != nil {
t.Fatalf("Failed to get example directory: %v", err)
}
tmpDir := filepath.Join(workspaceRoot, "tmp")
if err := os.MkdirAll(tmpDir, 0755); err != nil {
t.Fatalf("Failed to create tmp directory: %v", err)
}
encoder := &SoftwareEncoder{codec: "libx264"}
config := &EncodeConfig{
InputPattern: filepath.Join(exampleDirAbs, "frame_%04d.png"),
OutputPath: filepath.Join(tmpDir, "test_png_h264.mp4"),
StartFrame: 800,
FrameRate: 24.0,
WorkDir: tmpDir,
UseAlpha: false,
TwoPass: false, // Use single pass for faster testing
SourceFormat: "png",
}
// Build and run command
cmd := encoder.BuildCommand(config)
if cmd == nil {
t.Fatal("BuildCommand returned nil")
}
// Verify no video filter is used for PNG
argsStr := strings.Join(cmd.Args, " ")
if strings.Contains(argsStr, "-vf") {
t.Error("PNG encoding should not use video filter, but -vf was found in command")
}
// Run the command
cmdOutput, err := cmd.CombinedOutput()
if err != nil {
t.Errorf("FFmpeg command failed: %v\nCommand output: %s", err, string(cmdOutput))
return
}
// Verify output file was created
if _, err := os.Stat(config.OutputPath); os.IsNotExist(err) {
t.Errorf("Output file was not created: %s\nCommand output: %s", config.OutputPath, string(cmdOutput))
} else {
t.Logf("Successfully created output file: %s", config.OutputPath)
info, _ := os.Stat(config.OutputPath)
if info.Size() == 0 {
t.Error("Output file was created but is empty")
} else {
t.Logf("Output file size: %d bytes", info.Size())
}
}
}
func TestIntegration_Encode_EXR_VP9(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
// Check if example file exists
exampleDir := filepath.Join("..", "..", "..", "examples")
exrFile := filepath.Join(exampleDir, "frame_0800.exr")
if _, err := os.Stat(exrFile); os.IsNotExist(err) {
t.Skipf("Example file not found: %s", exrFile)
}
// Check if VP9 encoder is available
checkCmd := exec.Command("ffmpeg", "-hide_banner", "-encoders")
checkOutput, err := checkCmd.CombinedOutput()
if err != nil || !strings.Contains(string(checkOutput), "libvpx-vp9") {
t.Skip("VP9 encoder (libvpx-vp9) not available in ffmpeg")
}
// Get absolute paths
workspaceRoot, err := filepath.Abs(filepath.Join("..", "..", ".."))
if err != nil {
t.Fatalf("Failed to get workspace root: %v", err)
}
exampleDirAbs, err := filepath.Abs(exampleDir)
if err != nil {
t.Fatalf("Failed to get example directory: %v", err)
}
tmpDir := filepath.Join(workspaceRoot, "tmp")
if err := os.MkdirAll(tmpDir, 0755); err != nil {
t.Fatalf("Failed to create tmp directory: %v", err)
}
encoder := &SoftwareEncoder{codec: "libvpx-vp9"}
config := &EncodeConfig{
InputPattern: filepath.Join(exampleDirAbs, "frame_%04d.exr"),
OutputPath: filepath.Join(tmpDir, "test_exr_vp9.webm"),
StartFrame: 800,
FrameRate: 24.0,
WorkDir: tmpDir,
UseAlpha: false,
TwoPass: false, // Use single pass for faster testing
SourceFormat: "exr",
}
// Build and run command
cmd := encoder.BuildCommand(config)
if cmd == nil {
t.Fatal("BuildCommand returned nil")
}
// Capture stderr to see what went wrong
output, err := cmd.CombinedOutput()
if err != nil {
t.Errorf("FFmpeg command failed: %v\nCommand output: %s", err, string(output))
return
}
// Verify output file was created
if _, err := os.Stat(config.OutputPath); os.IsNotExist(err) {
t.Errorf("Output file was not created: %s\nCommand output: %s", config.OutputPath, string(output))
} else {
t.Logf("Successfully created output file: %s", config.OutputPath)
// Verify file has content
info, _ := os.Stat(config.OutputPath)
if info.Size() == 0 {
t.Errorf("Output file was created but is empty\nCommand output: %s", string(output))
} else {
t.Logf("Output file size: %d bytes", info.Size())
}
}
}
func TestIntegration_Encode_EXR_AV1(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
// Check if example file exists
exampleDir := filepath.Join("..", "..", "..", "examples")
exrFile := filepath.Join(exampleDir, "frame_0800.exr")
if _, err := os.Stat(exrFile); os.IsNotExist(err) {
t.Skipf("Example file not found: %s", exrFile)
}
// Check if AV1 encoder is available
checkCmd := exec.Command("ffmpeg", "-hide_banner", "-encoders")
output, err := checkCmd.CombinedOutput()
if err != nil || !strings.Contains(string(output), "libaom-av1") {
t.Skip("AV1 encoder (libaom-av1) not available in ffmpeg")
}
// Get absolute paths
workspaceRoot, err := filepath.Abs(filepath.Join("..", "..", ".."))
if err != nil {
t.Fatalf("Failed to get workspace root: %v", err)
}
exampleDirAbs, err := filepath.Abs(exampleDir)
if err != nil {
t.Fatalf("Failed to get example directory: %v", err)
}
tmpDir := filepath.Join(workspaceRoot, "tmp")
if err := os.MkdirAll(tmpDir, 0755); err != nil {
t.Fatalf("Failed to create tmp directory: %v", err)
}
encoder := &SoftwareEncoder{codec: "libaom-av1"}
config := &EncodeConfig{
InputPattern: filepath.Join(exampleDirAbs, "frame_%04d.exr"),
OutputPath: filepath.Join(tmpDir, "test_exr_av1.mp4"),
StartFrame: 800,
FrameRate: 24.0,
WorkDir: tmpDir,
UseAlpha: false,
TwoPass: false,
SourceFormat: "exr",
}
// Build and run command
cmd := encoder.BuildCommand(config)
cmdOutput, err := cmd.CombinedOutput()
if err != nil {
t.Errorf("FFmpeg command failed: %v\nCommand output: %s", err, string(cmdOutput))
return
}
// Verify output file was created
if _, err := os.Stat(config.OutputPath); os.IsNotExist(err) {
t.Errorf("Output file was not created: %s\nCommand output: %s", config.OutputPath, string(cmdOutput))
} else {
t.Logf("Successfully created AV1 output file: %s", config.OutputPath)
info, _ := os.Stat(config.OutputPath)
if info.Size() == 0 {
t.Errorf("Output file was created but is empty\nCommand output: %s", string(cmdOutput))
} else {
t.Logf("Output file size: %d bytes", info.Size())
}
}
}
func TestIntegration_Encode_EXR_VP9_WithAlpha(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
// Check if example file exists
exampleDir := filepath.Join("..", "..", "..", "examples")
exrFile := filepath.Join(exampleDir, "frame_0800.exr")
if _, err := os.Stat(exrFile); os.IsNotExist(err) {
t.Skipf("Example file not found: %s", exrFile)
}
// Check if VP9 encoder is available
checkCmd := exec.Command("ffmpeg", "-hide_banner", "-encoders")
output, err := checkCmd.CombinedOutput()
if err != nil || !strings.Contains(string(output), "libvpx-vp9") {
t.Skip("VP9 encoder (libvpx-vp9) not available in ffmpeg")
}
// Get absolute paths
workspaceRoot, err := filepath.Abs(filepath.Join("..", "..", ".."))
if err != nil {
t.Fatalf("Failed to get workspace root: %v", err)
}
exampleDirAbs, err := filepath.Abs(exampleDir)
if err != nil {
t.Fatalf("Failed to get example directory: %v", err)
}
tmpDir := filepath.Join(workspaceRoot, "tmp")
if err := os.MkdirAll(tmpDir, 0755); err != nil {
t.Fatalf("Failed to create tmp directory: %v", err)
}
encoder := &SoftwareEncoder{codec: "libvpx-vp9"}
config := &EncodeConfig{
InputPattern: filepath.Join(exampleDirAbs, "frame_%04d.exr"),
OutputPath: filepath.Join(tmpDir, "test_exr_vp9_alpha.webm"),
StartFrame: 800,
FrameRate: 24.0,
WorkDir: tmpDir,
UseAlpha: true, // Test with alpha
TwoPass: false, // Use single pass for faster testing
SourceFormat: "exr",
}
// Build and run command
cmd := encoder.BuildCommand(config)
if cmd == nil {
t.Fatal("BuildCommand returned nil")
}
// Capture stderr to see what went wrong
cmdOutput, err := cmd.CombinedOutput()
if err != nil {
t.Errorf("FFmpeg command failed: %v\nCommand output: %s", err, string(cmdOutput))
return
}
// Verify output file was created
if _, err := os.Stat(config.OutputPath); os.IsNotExist(err) {
t.Errorf("Output file was not created: %s\nCommand output: %s", config.OutputPath, string(cmdOutput))
} else {
t.Logf("Successfully created VP9 output file with alpha: %s", config.OutputPath)
info, _ := os.Stat(config.OutputPath)
if info.Size() == 0 {
t.Errorf("Output file was created but is empty\nCommand output: %s", string(cmdOutput))
} else {
t.Logf("Output file size: %d bytes", info.Size())
}
}
}
// Helper function to copy files
func copyFile(src, dst string) error {
data, err := os.ReadFile(src)
if err != nil {
return err
}
return os.WriteFile(dst, data, 0644)
}

361
internal/runner/runner.go Normal file
View File

@@ -0,0 +1,361 @@
// Package runner provides the Jiggablend render runner.
package runner
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"log"
"net"
"os"
"os/exec"
"strings"
"sync"
"time"
"jiggablend/internal/runner/api"
"jiggablend/internal/runner/blender"
"jiggablend/internal/runner/encoding"
"jiggablend/internal/runner/tasks"
"jiggablend/internal/runner/workspace"
"jiggablend/pkg/executils"
"jiggablend/pkg/types"
)
// Runner is the main render runner.
type Runner struct {
id int64
name string
hostname string
manager *api.ManagerClient
workspace *workspace.Manager
blender *blender.Manager
encoder *encoding.Selector
processes *executils.ProcessTracker
processors map[string]tasks.Processor
stopChan chan struct{}
fingerprint string
fingerprintMu sync.RWMutex
}
// New creates a new runner.
func New(managerURL, name, hostname string) *Runner {
manager := api.NewManagerClient(managerURL)
r := &Runner{
name: name,
hostname: hostname,
manager: manager,
processes: executils.NewProcessTracker(),
stopChan: make(chan struct{}),
processors: make(map[string]tasks.Processor),
}
// Generate fingerprint
r.generateFingerprint()
return r
}
// CheckRequiredTools verifies that required external tools are available.
func (r *Runner) CheckRequiredTools() error {
if err := exec.Command("zstd", "--version").Run(); err != nil {
return fmt.Errorf("zstd not found - required for compressed blend file support. Install with: apt install zstd")
}
log.Printf("Found zstd for compressed blend file support")
if err := exec.Command("xvfb-run", "--help").Run(); err != nil {
return fmt.Errorf("xvfb-run not found - required for headless Blender rendering. Install with: apt install xvfb")
}
log.Printf("Found xvfb-run for headless rendering without -b option")
return nil
}
var cachedCapabilities map[string]interface{} = nil
// ProbeCapabilities detects hardware capabilities.
func (r *Runner) ProbeCapabilities() map[string]interface{} {
if cachedCapabilities != nil {
return cachedCapabilities
}
caps := make(map[string]interface{})
// Check for ffmpeg and probe encoding capabilities
if err := exec.Command("ffmpeg", "-version").Run(); err == nil {
caps["ffmpeg"] = true
} else {
caps["ffmpeg"] = false
}
cachedCapabilities = caps
return caps
}
// Register registers the runner with the manager.
func (r *Runner) Register(apiKey string) (int64, error) {
caps := r.ProbeCapabilities()
id, err := r.manager.Register(r.name, r.hostname, caps, apiKey, r.GetFingerprint())
if err != nil {
return 0, err
}
r.id = id
// Initialize workspace after registration
r.workspace = workspace.NewManager(r.name)
// Initialize blender manager
r.blender = blender.NewManager(r.manager, r.workspace.BaseDir())
// Initialize encoder selector
r.encoder = encoding.NewSelector()
// Register task processors
r.processors["render"] = tasks.NewRenderProcessor()
r.processors["encode"] = tasks.NewEncodeProcessor()
return id, nil
}
// Start starts the job polling loop.
func (r *Runner) Start(pollInterval time.Duration) {
log.Printf("Starting job polling loop (interval: %v)", pollInterval)
for {
select {
case <-r.stopChan:
log.Printf("Stopping job polling loop")
return
default:
}
log.Printf("Polling for next job (runner ID: %d)", r.id)
job, err := r.manager.PollNextJob()
if err != nil {
log.Printf("Error polling for job: %v", err)
time.Sleep(pollInterval)
continue
}
if job == nil {
log.Printf("No job available, sleeping for %v", pollInterval)
time.Sleep(pollInterval)
continue
}
log.Printf("Received job assignment: task=%d, job=%d, type=%s",
job.Task.TaskID, job.Task.JobID, job.Task.TaskType)
if err := r.executeJob(job); err != nil {
log.Printf("Error processing job: %v", err)
}
}
}
// Stop stops the runner.
func (r *Runner) Stop() {
close(r.stopChan)
}
// KillAllProcesses kills all running processes.
func (r *Runner) KillAllProcesses() {
log.Printf("Killing all running processes...")
killedCount := r.processes.KillAll()
// Release all allocated devices
if r.encoder != nil {
// Device pool cleanup is handled internally
}
log.Printf("Killed %d process(es)", killedCount)
}
// Cleanup removes the workspace directory.
func (r *Runner) Cleanup() {
if r.workspace != nil {
r.workspace.Cleanup()
}
}
// executeJob handles a job using per-job WebSocket connection.
func (r *Runner) executeJob(job *api.NextJobResponse) (err error) {
// Recover from panics to prevent runner process crashes during task execution
defer func() {
if rec := recover(); rec != nil {
log.Printf("Task execution panicked: %v", rec)
err = fmt.Errorf("task execution panicked: %v", rec)
}
}()
// Connect to job WebSocket (no runnerID needed - authentication handles it)
jobConn := api.NewJobConnection()
if err := jobConn.Connect(r.manager.GetBaseURL(), job.JobPath, job.JobToken); err != nil {
return fmt.Errorf("failed to connect job WebSocket: %w", err)
}
defer jobConn.Close()
log.Printf("Job WebSocket authenticated for task %d", job.Task.TaskID)
// Create task context
workDir := r.workspace.JobDir(job.Task.JobID)
ctx := tasks.NewContext(
job.Task.TaskID,
job.Task.JobID,
job.Task.JobName,
job.Task.Frame,
job.Task.TaskType,
workDir,
job.JobToken,
job.Task.Metadata,
r.manager,
jobConn,
r.workspace,
r.blender,
r.encoder,
r.processes,
)
ctx.Info(fmt.Sprintf("Task assignment received (job: %d, type: %s)",
job.Task.JobID, job.Task.TaskType))
// Get processor for task type
processor, ok := r.processors[job.Task.TaskType]
if !ok {
return fmt.Errorf("unknown task type: %s", job.Task.TaskType)
}
// Process the task
var processErr error
switch job.Task.TaskType {
case "render": // this task has a upload outputs step because the frames are not uploaded by the render task directly we have to do it manually here TODO: maybe we should make it work like the encode task
// Download context
contextPath := job.JobPath + "/context.tar"
if err := r.downloadContext(job.Task.JobID, contextPath, job.JobToken); err != nil {
jobConn.Log(job.Task.TaskID, types.LogLevelError, fmt.Sprintf("Failed to download context: %v", err))
jobConn.Complete(job.Task.TaskID, false, fmt.Errorf("failed to download context: %v", err))
return fmt.Errorf("failed to download context: %w", err)
}
processErr = processor.Process(ctx)
if processErr == nil {
processErr = r.uploadOutputs(ctx, job)
}
case "encode": // this task doesn't have a upload outputs step because the video is already uploaded by the encode task
processErr = processor.Process(ctx)
default:
return fmt.Errorf("unknown task type: %s", job.Task.TaskType)
}
if processErr != nil {
ctx.Error(fmt.Sprintf("Task failed: %v", processErr))
ctx.Complete(false, processErr)
return processErr
}
ctx.Complete(true, nil)
return nil
}
func (r *Runner) downloadContext(jobID int64, contextPath, jobToken string) error {
reader, err := r.manager.DownloadContext(contextPath, jobToken)
if err != nil {
return err
}
defer reader.Close()
jobDir := r.workspace.JobDir(jobID)
return workspace.ExtractTar(reader, jobDir)
}
func (r *Runner) uploadOutputs(ctx *tasks.Context, job *api.NextJobResponse) error {
outputDir := ctx.WorkDir + "/output"
uploadPath := fmt.Sprintf("/api/runner/jobs/%d/upload", job.Task.JobID)
entries, err := os.ReadDir(outputDir)
if err != nil {
return fmt.Errorf("failed to read output directory: %w", err)
}
for _, entry := range entries {
if entry.IsDir() {
continue
}
filePath := outputDir + "/" + entry.Name()
if err := r.manager.UploadFile(uploadPath, job.JobToken, filePath); err != nil {
log.Printf("Failed to upload %s: %v", filePath, err)
} else {
ctx.OutputUploaded(entry.Name())
}
}
return nil
}
// generateFingerprint creates a unique hardware fingerprint.
func (r *Runner) generateFingerprint() {
r.fingerprintMu.Lock()
defer r.fingerprintMu.Unlock()
var components []string
components = append(components, r.hostname)
if machineID, err := os.ReadFile("/etc/machine-id"); err == nil {
components = append(components, strings.TrimSpace(string(machineID)))
}
if productUUID, err := os.ReadFile("/sys/class/dmi/id/product_uuid"); err == nil {
components = append(components, strings.TrimSpace(string(productUUID)))
}
if macAddr, err := r.getMACAddress(); err == nil {
components = append(components, macAddr)
}
if len(components) <= 1 {
components = append(components, fmt.Sprintf("%d", os.Getpid()))
components = append(components, fmt.Sprintf("%d", time.Now().Unix()))
}
h := sha256.New()
for _, comp := range components {
h.Write([]byte(comp))
h.Write([]byte{0})
}
r.fingerprint = hex.EncodeToString(h.Sum(nil))
}
func (r *Runner) getMACAddress() (string, error) {
interfaces, err := net.Interfaces()
if err != nil {
return "", err
}
for _, iface := range interfaces {
if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
continue
}
if len(iface.HardwareAddr) == 0 {
continue
}
return iface.HardwareAddr.String(), nil
}
return "", fmt.Errorf("no suitable network interface found")
}
// GetFingerprint returns the runner's hardware fingerprint.
func (r *Runner) GetFingerprint() string {
r.fingerprintMu.RLock()
defer r.fingerprintMu.RUnlock()
return r.fingerprint
}
// GetID returns the runner ID.
func (r *Runner) GetID() int64 {
return r.id
}

View File

@@ -0,0 +1,588 @@
package tasks
import (
"bufio"
"errors"
"fmt"
"log"
"math"
"os"
"os/exec"
"path/filepath"
"regexp"
"sort"
"strings"
"jiggablend/internal/runner/encoding"
)
// EncodeProcessor handles encode tasks.
type EncodeProcessor struct{}
// NewEncodeProcessor creates a new encode processor.
func NewEncodeProcessor() *EncodeProcessor {
return &EncodeProcessor{}
}
// Process executes an encode task.
func (p *EncodeProcessor) Process(ctx *Context) error {
ctx.Info(fmt.Sprintf("Starting encode task: job %d", ctx.JobID))
log.Printf("Processing encode task %d for job %d", ctx.TaskID, ctx.JobID)
// Create temporary work directory
workDir, err := ctx.Workspace.CreateVideoDir(ctx.JobID)
if err != nil {
return fmt.Errorf("failed to create work directory: %w", err)
}
defer func() {
if err := ctx.Workspace.CleanupVideoDir(ctx.JobID); err != nil {
log.Printf("Warning: Failed to cleanup encode work directory: %v", err)
}
}()
// Get output format and frame rate
outputFormat := ctx.GetOutputFormat()
if outputFormat == "" {
outputFormat = "EXR_264_MP4"
}
frameRate := ctx.GetFrameRate()
ctx.Info(fmt.Sprintf("Encode: detected output format '%s'", outputFormat))
ctx.Info(fmt.Sprintf("Encode: using frame rate %.2f fps", frameRate))
// Get job files
files, err := ctx.Manager.GetJobFiles(ctx.JobID)
if err != nil {
ctx.Error(fmt.Sprintf("Failed to get job files: %v", err))
return fmt.Errorf("failed to get job files: %w", err)
}
ctx.Info(fmt.Sprintf("GetJobFiles returned %d total files for job %d", len(files), ctx.JobID))
// Log all files for debugging
for _, file := range files {
ctx.Info(fmt.Sprintf("File: %s (type: %s, size: %d)", file.FileName, file.FileType, file.FileSize))
}
// Determine source format based on output format
sourceFormat := "exr"
fileExt := ".exr"
// Find and deduplicate frame files (EXR or PNG)
frameFileSet := make(map[string]bool)
var frameFilesList []string
for _, file := range files {
if file.FileType == "output" && strings.HasSuffix(strings.ToLower(file.FileName), fileExt) {
// Deduplicate by filename
if !frameFileSet[file.FileName] {
frameFileSet[file.FileName] = true
frameFilesList = append(frameFilesList, file.FileName)
}
}
}
if len(frameFilesList) == 0 {
// Log why no files matched (deduplicate for error reporting)
outputFileSet := make(map[string]bool)
frameFilesOtherTypeSet := make(map[string]bool)
var outputFiles []string
var frameFilesOtherType []string
for _, file := range files {
if file.FileType == "output" {
if !outputFileSet[file.FileName] {
outputFileSet[file.FileName] = true
outputFiles = append(outputFiles, file.FileName)
}
}
if strings.HasSuffix(strings.ToLower(file.FileName), fileExt) {
key := fmt.Sprintf("%s (type: %s)", file.FileName, file.FileType)
if !frameFilesOtherTypeSet[key] {
frameFilesOtherTypeSet[key] = true
frameFilesOtherType = append(frameFilesOtherType, key)
}
}
}
ctx.Error(fmt.Sprintf("no %s frame files found for encode: found %d total files, %d unique output files, %d unique %s files (with other types)", strings.ToUpper(fileExt[1:]), len(files), len(outputFiles), len(frameFilesOtherType), strings.ToUpper(fileExt[1:])))
if len(outputFiles) > 0 {
ctx.Error(fmt.Sprintf("Output files found: %v", outputFiles))
}
if len(frameFilesOtherType) > 0 {
ctx.Error(fmt.Sprintf("%s files with wrong type: %v", strings.ToUpper(fileExt[1:]), frameFilesOtherType))
}
err := fmt.Errorf("no %s frame files found for encode", strings.ToUpper(fileExt[1:]))
return err
}
ctx.Info(fmt.Sprintf("Found %d %s frames for encode", len(frameFilesList), strings.ToUpper(fileExt[1:])))
// Download frames
ctx.Info(fmt.Sprintf("Downloading %d %s frames for encode...", len(frameFilesList), strings.ToUpper(fileExt[1:])))
var frameFiles []string
for i, fileName := range frameFilesList {
ctx.Info(fmt.Sprintf("Downloading frame %d/%d: %s", i+1, len(frameFilesList), fileName))
framePath := filepath.Join(workDir, fileName)
if err := ctx.Manager.DownloadFrame(ctx.JobID, fileName, framePath); err != nil {
ctx.Error(fmt.Sprintf("Failed to download %s frame %s: %v", strings.ToUpper(fileExt[1:]), fileName, err))
log.Printf("Failed to download %s frame for encode %s: %v", strings.ToUpper(fileExt[1:]), fileName, err)
continue
}
ctx.Info(fmt.Sprintf("Successfully downloaded frame %d/%d: %s", i+1, len(frameFilesList), fileName))
frameFiles = append(frameFiles, framePath)
}
if len(frameFiles) == 0 {
err := fmt.Errorf("failed to download any %s frames for encode", strings.ToUpper(fileExt[1:]))
ctx.Error(err.Error())
return err
}
sort.Strings(frameFiles)
ctx.Info(fmt.Sprintf("Downloaded %d frames", len(frameFiles)))
// Check if EXR files have alpha channel and HDR content (only for EXR source format)
hasAlpha := false
hasHDR := false
if sourceFormat == "exr" {
// Check first frame for alpha channel and HDR using ffprobe
firstFrame := frameFiles[0]
hasAlpha = detectAlphaChannel(ctx, firstFrame)
if hasAlpha {
ctx.Info("Detected alpha channel in EXR files")
} else {
ctx.Info("No alpha channel detected in EXR files")
}
hasHDR = detectHDR(ctx, firstFrame)
if hasHDR {
ctx.Info("Detected HDR content in EXR files")
} else {
ctx.Info("No HDR content detected in EXR files (SDR range)")
}
}
// Generate video
// Use alpha if:
// 1. User explicitly enabled it OR source has alpha channel AND
// 2. Codec supports alpha (AV1 or VP9)
preserveAlpha := ctx.ShouldPreserveAlpha()
useAlpha := (preserveAlpha || hasAlpha) && (outputFormat == "EXR_AV1_MP4" || outputFormat == "EXR_VP9_WEBM")
if (preserveAlpha || hasAlpha) && outputFormat == "EXR_264_MP4" {
ctx.Warn("Alpha channel requested/detected but H.264 does not support alpha. Consider using EXR_AV1_MP4 or EXR_VP9_WEBM to preserve alpha.")
}
if preserveAlpha && !hasAlpha {
ctx.Warn("Alpha preservation requested but no alpha channel detected in EXR files.")
}
if useAlpha {
if preserveAlpha && hasAlpha {
ctx.Info("Alpha preservation enabled: Using alpha channel encoding")
} else if hasAlpha {
ctx.Info("Alpha channel detected - automatically enabling alpha encoding")
}
}
var outputExt string
switch outputFormat {
case "EXR_VP9_WEBM":
outputExt = "webm"
ctx.Info("Encoding WebM video with VP9 codec (with alpha channel and HDR support)...")
case "EXR_AV1_MP4":
outputExt = "mp4"
ctx.Info("Encoding MP4 video with AV1 codec (with alpha channel)...")
default:
outputExt = "mp4"
ctx.Info("Encoding MP4 video with H.264 codec...")
}
outputVideo := filepath.Join(workDir, fmt.Sprintf("output_%d.%s", ctx.JobID, outputExt))
// Build input pattern
firstFrame := frameFiles[0]
baseName := filepath.Base(firstFrame)
re := regexp.MustCompile(`_(\d+)\.`)
var pattern string
var startNumber int
frameNumStr := re.FindStringSubmatch(baseName)
if len(frameNumStr) > 1 {
pattern = re.ReplaceAllString(baseName, "_%04d.")
fmt.Sscanf(frameNumStr[1], "%d", &startNumber)
} else {
startNumber = extractFrameNumber(baseName)
pattern = strings.Replace(baseName, fmt.Sprintf("%d", startNumber), "%04d", 1)
}
patternPath := filepath.Join(workDir, pattern)
// Select encoder and build command (software encoding only)
var encoder encoding.Encoder
switch outputFormat {
case "EXR_AV1_MP4":
encoder = ctx.Encoder.SelectAV1()
case "EXR_VP9_WEBM":
encoder = ctx.Encoder.SelectVP9()
default:
encoder = ctx.Encoder.SelectH264()
}
ctx.Info(fmt.Sprintf("Using encoder: %s (%s)", encoder.Name(), encoder.Codec()))
// All software encoders use 2-pass for optimal quality
ctx.Info("Starting 2-pass encode for optimal quality...")
// Pass 1
ctx.Info("Pass 1/2: Analyzing content for optimal encode...")
softEncoder := encoder.(*encoding.SoftwareEncoder)
// Use HDR if: user explicitly enabled it OR HDR content was detected
preserveHDR := (ctx.ShouldPreserveHDR() || hasHDR) && sourceFormat == "exr"
if hasHDR && !ctx.ShouldPreserveHDR() {
ctx.Info("HDR content detected - automatically enabling HDR preservation")
}
pass1Cmd := softEncoder.BuildPass1Command(&encoding.EncodeConfig{
InputPattern: patternPath,
OutputPath: outputVideo,
StartFrame: startNumber,
FrameRate: frameRate,
WorkDir: workDir,
UseAlpha: useAlpha,
TwoPass: true,
SourceFormat: sourceFormat,
PreserveHDR: preserveHDR,
})
if err := pass1Cmd.Run(); err != nil {
ctx.Warn(fmt.Sprintf("Pass 1 completed (warnings expected): %v", err))
}
// Pass 2
ctx.Info("Pass 2/2: Encoding with optimal quality...")
preserveHDR = (ctx.ShouldPreserveHDR() || hasHDR) && sourceFormat == "exr"
if preserveHDR {
if hasHDR && !ctx.ShouldPreserveHDR() {
ctx.Info("HDR preservation enabled (auto-detected): Using HLG transfer with bt709 primaries")
} else {
ctx.Info("HDR preservation enabled: Using HLG transfer with bt709 primaries")
}
}
config := &encoding.EncodeConfig{
InputPattern: patternPath,
OutputPath: outputVideo,
StartFrame: startNumber,
FrameRate: frameRate,
WorkDir: workDir,
UseAlpha: useAlpha,
TwoPass: true, // Software encoding always uses 2-pass for quality
SourceFormat: sourceFormat,
PreserveHDR: preserveHDR,
}
cmd := encoder.BuildCommand(config)
if cmd == nil {
return errors.New("failed to build encode command")
}
// Set up pipes
stdoutPipe, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("failed to create stdout pipe: %w", err)
}
stderrPipe, err := cmd.StderrPipe()
if err != nil {
return fmt.Errorf("failed to create stderr pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return fmt.Errorf("failed to start encode command: %w", err)
}
ctx.Processes.Track(ctx.TaskID, cmd)
defer ctx.Processes.Untrack(ctx.TaskID)
// Stream stdout
stdoutDone := make(chan bool)
go func() {
defer close(stdoutDone)
scanner := bufio.NewScanner(stdoutPipe)
for scanner.Scan() {
line := scanner.Text()
if line != "" {
ctx.Info(line)
}
}
}()
// Stream stderr
stderrDone := make(chan bool)
go func() {
defer close(stderrDone)
scanner := bufio.NewScanner(stderrPipe)
for scanner.Scan() {
line := scanner.Text()
if line != "" {
ctx.Warn(line)
}
}
}()
err = cmd.Wait()
<-stdoutDone
<-stderrDone
if err != nil {
var errMsg string
if exitErr, ok := err.(*exec.ExitError); ok {
if exitErr.ExitCode() == 137 {
errMsg = "FFmpeg was killed due to excessive memory usage (OOM)"
} else {
errMsg = fmt.Sprintf("ffmpeg encoding failed: %v", err)
}
} else {
errMsg = fmt.Sprintf("ffmpeg encoding failed: %v", err)
}
if sizeErr := checkFFmpegSizeError(errMsg); sizeErr != nil {
ctx.Error(sizeErr.Error())
return sizeErr
}
ctx.Error(errMsg)
return errors.New(errMsg)
}
// Verify output
if _, err := os.Stat(outputVideo); os.IsNotExist(err) {
err := fmt.Errorf("video %s file not created: %s", outputExt, outputVideo)
ctx.Error(err.Error())
return err
}
// Clean up 2-pass log files
os.Remove(filepath.Join(workDir, "ffmpeg2pass-0.log"))
os.Remove(filepath.Join(workDir, "ffmpeg2pass-0.log.mbtree"))
ctx.Info(fmt.Sprintf("%s video encoded successfully", strings.ToUpper(outputExt)))
// Upload video
ctx.Info(fmt.Sprintf("Uploading encoded %s video...", strings.ToUpper(outputExt)))
uploadPath := fmt.Sprintf("/api/runner/jobs/%d/upload", ctx.JobID)
if err := ctx.Manager.UploadFile(uploadPath, ctx.JobToken, outputVideo); err != nil {
ctx.Error(fmt.Sprintf("Failed to upload %s: %v", strings.ToUpper(outputExt), err))
return fmt.Errorf("failed to upload %s: %w", strings.ToUpper(outputExt), err)
}
ctx.Info(fmt.Sprintf("Successfully uploaded %s: %s", strings.ToUpper(outputExt), filepath.Base(outputVideo)))
log.Printf("Successfully generated and uploaded %s for job %d: %s", strings.ToUpper(outputExt), ctx.JobID, filepath.Base(outputVideo))
return nil
}
// detectAlphaChannel checks if an EXR file has an alpha channel using ffprobe
func detectAlphaChannel(ctx *Context, filePath string) bool {
// Use ffprobe to check pixel format and stream properties
// EXR files with alpha will have formats like gbrapf32le (RGBA) vs gbrpf32le (RGB)
cmd := exec.Command("ffprobe",
"-v", "error",
"-select_streams", "v:0",
"-show_entries", "stream=pix_fmt:stream=codec_name",
"-of", "default=noprint_wrappers=1",
filePath,
)
output, err := cmd.Output()
if err != nil {
// If ffprobe fails, assume no alpha (conservative approach)
ctx.Warn(fmt.Sprintf("Failed to detect alpha channel in %s: %v", filepath.Base(filePath), err))
return false
}
outputStr := string(output)
// Check pixel format - EXR with alpha typically has 'a' in the format name (e.g., gbrapf32le)
// Also check for formats that explicitly indicate alpha
hasAlpha := strings.Contains(outputStr, "pix_fmt=gbrap") ||
strings.Contains(outputStr, "pix_fmt=rgba") ||
strings.Contains(outputStr, "pix_fmt=yuva") ||
strings.Contains(outputStr, "pix_fmt=abgr")
if hasAlpha {
ctx.Info(fmt.Sprintf("Detected alpha channel in EXR file: %s", filepath.Base(filePath)))
}
return hasAlpha
}
// detectHDR checks if an EXR file contains HDR content using ffprobe
func detectHDR(ctx *Context, filePath string) bool {
// First, check if the pixel format supports HDR (32-bit float)
cmd := exec.Command("ffprobe",
"-v", "error",
"-select_streams", "v:0",
"-show_entries", "stream=pix_fmt",
"-of", "default=noprint_wrappers=1:nokey=1",
filePath,
)
output, err := cmd.Output()
if err != nil {
// If ffprobe fails, assume no HDR (conservative approach)
ctx.Warn(fmt.Sprintf("Failed to detect HDR in %s: %v", filepath.Base(filePath), err))
return false
}
pixFmt := strings.TrimSpace(string(output))
// EXR files with 32-bit float format (gbrpf32le, gbrapf32le) can contain HDR
// Check if it's a 32-bit float format
isFloat32 := strings.Contains(pixFmt, "f32") || strings.Contains(pixFmt, "f32le")
if !isFloat32 {
// Not a float format, definitely not HDR
return false
}
// For 32-bit float EXR, sample pixels to check if values exceed SDR range (> 1.0)
// Use ffmpeg to extract pixel statistics - check max pixel values
// This is more efficient than sampling individual pixels
cmd = exec.Command("ffmpeg",
"-v", "error",
"-i", filePath,
"-vf", "signalstats",
"-f", "null",
"-",
)
output, err = cmd.CombinedOutput()
if err != nil {
// If stats extraction fails, try sampling a few pixels directly
return detectHDRBySampling(ctx, filePath)
}
// Check output for max pixel values
outputStr := string(output)
// Look for max values in the signalstats output
// If we find values > 1.0, it's HDR
if strings.Contains(outputStr, "MAX") {
// Try to extract max values from signalstats output
// Format is typically like: YMAX:1.234 UMAX:0.567 VMAX:0.890
// For EXR (RGB), we need to check R, G, B channels
// Since signalstats works on YUV, we'll use a different approach
return detectHDRBySampling(ctx, filePath)
}
// Fallback to pixel sampling
return detectHDRBySampling(ctx, filePath)
}
// detectHDRBySampling samples pixels from multiple regions to detect HDR content
func detectHDRBySampling(ctx *Context, filePath string) bool {
// Sample multiple 10x10 regions from different parts of the image
// This gives us better coverage than a single sample
sampleRegions := []string{
"crop=10:10:iw/4:ih/4", // Top-left quadrant
"crop=10:10:iw*3/4:ih/4", // Top-right quadrant
"crop=10:10:iw/4:ih*3/4", // Bottom-left quadrant
"crop=10:10:iw*3/4:ih*3/4", // Bottom-right quadrant
"crop=10:10:iw/2:ih/2", // Center
}
for _, region := range sampleRegions {
cmd := exec.Command("ffmpeg",
"-v", "error",
"-i", filePath,
"-vf", fmt.Sprintf("%s,scale=1:1", region),
"-f", "rawvideo",
"-pix_fmt", "gbrpf32le",
"-",
)
output, err := cmd.Output()
if err != nil {
continue // Skip this region if sampling fails
}
// Parse the float32 values (4 bytes per float, 3 channels RGB)
if len(output) >= 12 { // At least 3 floats (RGB) = 12 bytes
for i := 0; i < len(output)-11; i += 12 {
// Read RGB values (little-endian float32)
r := float32FromBytes(output[i : i+4])
g := float32FromBytes(output[i+4 : i+8])
b := float32FromBytes(output[i+8 : i+12])
// Check if any channel exceeds 1.0 (SDR range)
if r > 1.0 || g > 1.0 || b > 1.0 {
maxVal := max(r, max(g, b))
ctx.Info(fmt.Sprintf("Detected HDR content in EXR file: %s (max value: %.2f)", filepath.Base(filePath), maxVal))
return true
}
}
}
}
// If we sampled multiple regions and none exceed 1.0, it's likely SDR content
// But since it's 32-bit float format, user can still manually enable HDR if needed
return false
}
// float32FromBytes converts 4 bytes (little-endian) to float32
func float32FromBytes(bytes []byte) float32 {
if len(bytes) < 4 {
return 0
}
bits := uint32(bytes[0]) | uint32(bytes[1])<<8 | uint32(bytes[2])<<16 | uint32(bytes[3])<<24
return math.Float32frombits(bits)
}
// max returns the maximum of two float32 values
func max(a, b float32) float32 {
if a > b {
return a
}
return b
}
func extractFrameNumber(filename string) int {
parts := strings.Split(filepath.Base(filename), "_")
if len(parts) < 2 {
return 0
}
framePart := strings.Split(parts[1], ".")[0]
var frameNum int
fmt.Sscanf(framePart, "%d", &frameNum)
return frameNum
}
func checkFFmpegSizeError(output string) error {
outputLower := strings.ToLower(output)
if strings.Contains(outputLower, "hardware does not support encoding at size") {
constraintsMatch := regexp.MustCompile(`constraints:\s*width\s+(\d+)-(\d+)\s+height\s+(\d+)-(\d+)`).FindStringSubmatch(output)
if len(constraintsMatch) == 5 {
return fmt.Errorf("video frame size is outside hardware encoder limits. Hardware requires: width %s-%s, height %s-%s",
constraintsMatch[1], constraintsMatch[2], constraintsMatch[3], constraintsMatch[4])
}
return fmt.Errorf("video frame size is outside hardware encoder limits")
}
if strings.Contains(outputLower, "picture size") && strings.Contains(outputLower, "is invalid") {
sizeMatch := regexp.MustCompile(`picture size\s+(\d+)x(\d+)`).FindStringSubmatch(output)
if len(sizeMatch) == 3 {
return fmt.Errorf("invalid video frame size: %sx%s", sizeMatch[1], sizeMatch[2])
}
return fmt.Errorf("invalid video frame size")
}
if strings.Contains(outputLower, "error while opening encoder") &&
(strings.Contains(outputLower, "width") || strings.Contains(outputLower, "height") || strings.Contains(outputLower, "size")) {
sizeMatch := regexp.MustCompile(`at size\s+(\d+)x(\d+)`).FindStringSubmatch(output)
if len(sizeMatch) == 3 {
return fmt.Errorf("hardware encoder cannot encode frame size %sx%s", sizeMatch[1], sizeMatch[2])
}
return fmt.Errorf("hardware encoder error: frame size may be invalid")
}
if strings.Contains(outputLower, "invalid") &&
(strings.Contains(outputLower, "width") || strings.Contains(outputLower, "height") || strings.Contains(outputLower, "dimension")) {
return fmt.Errorf("invalid frame dimensions detected")
}
return nil
}

View File

@@ -0,0 +1,156 @@
// Package tasks provides task processing implementations.
package tasks
import (
"jiggablend/internal/runner/api"
"jiggablend/internal/runner/blender"
"jiggablend/internal/runner/encoding"
"jiggablend/internal/runner/workspace"
"jiggablend/pkg/executils"
"jiggablend/pkg/types"
)
// Processor handles a specific task type.
type Processor interface {
Process(ctx *Context) error
}
// Context provides task execution context.
type Context struct {
TaskID int64
JobID int64
JobName string
Frame int
TaskType string
WorkDir string
JobToken string
Metadata *types.BlendMetadata
Manager *api.ManagerClient
JobConn *api.JobConnection
Workspace *workspace.Manager
Blender *blender.Manager
Encoder *encoding.Selector
Processes *executils.ProcessTracker
}
// NewContext creates a new task context.
func NewContext(
taskID, jobID int64,
jobName string,
frame int,
taskType string,
workDir string,
jobToken string,
metadata *types.BlendMetadata,
manager *api.ManagerClient,
jobConn *api.JobConnection,
ws *workspace.Manager,
blenderMgr *blender.Manager,
encoder *encoding.Selector,
processes *executils.ProcessTracker,
) *Context {
return &Context{
TaskID: taskID,
JobID: jobID,
JobName: jobName,
Frame: frame,
TaskType: taskType,
WorkDir: workDir,
JobToken: jobToken,
Metadata: metadata,
Manager: manager,
JobConn: jobConn,
Workspace: ws,
Blender: blenderMgr,
Encoder: encoder,
Processes: processes,
}
}
// Log sends a log entry to the manager.
func (c *Context) Log(level types.LogLevel, message string) {
if c.JobConn != nil {
c.JobConn.Log(c.TaskID, level, message)
}
}
// Info logs an info message.
func (c *Context) Info(message string) {
c.Log(types.LogLevelInfo, message)
}
// Warn logs a warning message.
func (c *Context) Warn(message string) {
c.Log(types.LogLevelWarn, message)
}
// Error logs an error message.
func (c *Context) Error(message string) {
c.Log(types.LogLevelError, message)
}
// Progress sends a progress update.
func (c *Context) Progress(progress float64) {
if c.JobConn != nil {
c.JobConn.Progress(c.TaskID, progress)
}
}
// OutputUploaded notifies that an output file was uploaded.
func (c *Context) OutputUploaded(fileName string) {
if c.JobConn != nil {
c.JobConn.OutputUploaded(c.TaskID, fileName)
}
}
// Complete sends task completion.
func (c *Context) Complete(success bool, errorMsg error) {
if c.JobConn != nil {
c.JobConn.Complete(c.TaskID, success, errorMsg)
}
}
// GetOutputFormat returns the output format from metadata or default.
func (c *Context) GetOutputFormat() string {
if c.Metadata != nil && c.Metadata.RenderSettings.OutputFormat != "" {
return c.Metadata.RenderSettings.OutputFormat
}
return "PNG"
}
// GetFrameRate returns the frame rate from metadata or default.
func (c *Context) GetFrameRate() float64 {
if c.Metadata != nil && c.Metadata.RenderSettings.FrameRate > 0 {
return c.Metadata.RenderSettings.FrameRate
}
return 24.0
}
// GetBlenderVersion returns the Blender version from metadata.
func (c *Context) GetBlenderVersion() string {
if c.Metadata != nil {
return c.Metadata.BlenderVersion
}
return ""
}
// ShouldUnhideObjects returns whether to unhide objects.
func (c *Context) ShouldUnhideObjects() bool {
return c.Metadata != nil && c.Metadata.UnhideObjects != nil && *c.Metadata.UnhideObjects
}
// ShouldEnableExecution returns whether to enable auto-execution.
func (c *Context) ShouldEnableExecution() bool {
return c.Metadata != nil && c.Metadata.EnableExecution != nil && *c.Metadata.EnableExecution
}
// ShouldPreserveHDR returns whether to preserve HDR range for EXR encoding.
func (c *Context) ShouldPreserveHDR() bool {
return c.Metadata != nil && c.Metadata.PreserveHDR != nil && *c.Metadata.PreserveHDR
}
// ShouldPreserveAlpha returns whether to preserve alpha channel for EXR encoding.
func (c *Context) ShouldPreserveAlpha() bool {
return c.Metadata != nil && c.Metadata.PreserveAlpha != nil && *c.Metadata.PreserveAlpha
}

View File

@@ -0,0 +1,301 @@
package tasks
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"log"
"os"
"os/exec"
"path/filepath"
"strings"
"jiggablend/internal/runner/blender"
"jiggablend/internal/runner/workspace"
"jiggablend/pkg/scripts"
"jiggablend/pkg/types"
)
// RenderProcessor handles render tasks.
type RenderProcessor struct{}
// NewRenderProcessor creates a new render processor.
func NewRenderProcessor() *RenderProcessor {
return &RenderProcessor{}
}
// Process executes a render task.
func (p *RenderProcessor) Process(ctx *Context) error {
ctx.Info(fmt.Sprintf("Starting task: job %d, frame %d, format: %s",
ctx.JobID, ctx.Frame, ctx.GetOutputFormat()))
log.Printf("Processing task %d: job %d, frame %d", ctx.TaskID, ctx.JobID, ctx.Frame)
// Find .blend file
blendFile, err := workspace.FindFirstBlendFile(ctx.WorkDir)
if err != nil {
return fmt.Errorf("failed to find blend file: %w", err)
}
// Get Blender binary
blenderBinary := "blender"
if version := ctx.GetBlenderVersion(); version != "" {
ctx.Info(fmt.Sprintf("Job requires Blender %s", version))
binaryPath, err := ctx.Blender.GetBinaryPath(version)
if err != nil {
ctx.Warn(fmt.Sprintf("Could not get Blender %s, using system blender: %v", version, err))
} else {
blenderBinary = binaryPath
ctx.Info(fmt.Sprintf("Using Blender binary: %s", blenderBinary))
}
} else {
ctx.Info("No Blender version specified, using system blender")
}
// Create output directory
outputDir := filepath.Join(ctx.WorkDir, "output")
if err := os.MkdirAll(outputDir, 0755); err != nil {
return fmt.Errorf("failed to create output directory: %w", err)
}
// Create home directory for Blender inside workspace
blenderHome := filepath.Join(ctx.WorkDir, "home")
if err := os.MkdirAll(blenderHome, 0755); err != nil {
return fmt.Errorf("failed to create Blender home directory: %w", err)
}
// Determine render format
outputFormat := ctx.GetOutputFormat()
renderFormat := outputFormat
if outputFormat == "EXR_264_MP4" || outputFormat == "EXR_AV1_MP4" || outputFormat == "EXR_VP9_WEBM" {
renderFormat = "EXR" // Use EXR for maximum quality
}
// Create render script
if err := p.createRenderScript(ctx, renderFormat); err != nil {
return err
}
// Render
ctx.Info(fmt.Sprintf("Starting Blender render for frame %d...", ctx.Frame))
if err := p.runBlender(ctx, blenderBinary, blendFile, outputDir, renderFormat, blenderHome); err != nil {
ctx.Error(fmt.Sprintf("Blender render failed: %v", err))
return err
}
// Verify output
if _, err := p.findOutputFile(ctx, outputDir, renderFormat); err != nil {
ctx.Error(fmt.Sprintf("Output verification failed: %v", err))
return err
}
ctx.Info(fmt.Sprintf("Blender render completed for frame %d", ctx.Frame))
return nil
}
func (p *RenderProcessor) createRenderScript(ctx *Context, renderFormat string) error {
formatFilePath := filepath.Join(ctx.WorkDir, "output_format.txt")
renderSettingsFilePath := filepath.Join(ctx.WorkDir, "render_settings.json")
// Build unhide code conditionally
unhideCode := ""
if ctx.ShouldUnhideObjects() {
unhideCode = scripts.UnhideObjects
}
// Load template and replace placeholders
scriptContent := scripts.RenderBlenderTemplate
scriptContent = strings.ReplaceAll(scriptContent, "{{UNHIDE_CODE}}", unhideCode)
scriptContent = strings.ReplaceAll(scriptContent, "{{FORMAT_FILE_PATH}}", fmt.Sprintf("%q", formatFilePath))
scriptContent = strings.ReplaceAll(scriptContent, "{{RENDER_SETTINGS_FILE}}", fmt.Sprintf("%q", renderSettingsFilePath))
scriptPath := filepath.Join(ctx.WorkDir, "enable_gpu.py")
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0644); err != nil {
errMsg := fmt.Sprintf("failed to create GPU enable script: %v", err)
ctx.Error(errMsg)
return errors.New(errMsg)
}
// Write output format
outputFormat := ctx.GetOutputFormat()
ctx.Info(fmt.Sprintf("Writing output format '%s' to format file", outputFormat))
if err := os.WriteFile(formatFilePath, []byte(outputFormat), 0644); err != nil {
errMsg := fmt.Sprintf("failed to create format file: %v", err)
ctx.Error(errMsg)
return errors.New(errMsg)
}
// Write render settings if available
if ctx.Metadata != nil && ctx.Metadata.RenderSettings.EngineSettings != nil {
settingsJSON, err := json.Marshal(ctx.Metadata.RenderSettings)
if err == nil {
if err := os.WriteFile(renderSettingsFilePath, settingsJSON, 0644); err != nil {
ctx.Warn(fmt.Sprintf("Failed to write render settings file: %v", err))
}
}
}
return nil
}
func (p *RenderProcessor) runBlender(ctx *Context, blenderBinary, blendFile, outputDir, renderFormat, blenderHome string) error {
scriptPath := filepath.Join(ctx.WorkDir, "enable_gpu.py")
args := []string{"-b", blendFile, "--python", scriptPath}
if ctx.ShouldEnableExecution() {
args = append(args, "--enable-autoexec")
}
// Output pattern
outputPattern := filepath.Join(outputDir, fmt.Sprintf("frame_####.%s", strings.ToLower(renderFormat)))
outputAbsPattern, _ := filepath.Abs(outputPattern)
args = append(args, "-o", outputAbsPattern)
args = append(args, "-f", fmt.Sprintf("%d", ctx.Frame))
// Wrap with xvfb-run
xvfbArgs := []string{"-a", "-s", "-screen 0 800x600x24", blenderBinary}
xvfbArgs = append(xvfbArgs, args...)
cmd := exec.Command("xvfb-run", xvfbArgs...)
cmd.Dir = ctx.WorkDir
// Set up environment with custom HOME directory
env := os.Environ()
// Remove existing HOME if present and add our custom one
newEnv := make([]string, 0, len(env)+1)
for _, e := range env {
if !strings.HasPrefix(e, "HOME=") {
newEnv = append(newEnv, e)
}
}
newEnv = append(newEnv, fmt.Sprintf("HOME=%s", blenderHome))
cmd.Env = newEnv
// Set up pipes
stdoutPipe, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("failed to create stdout pipe: %w", err)
}
stderrPipe, err := cmd.StderrPipe()
if err != nil {
return fmt.Errorf("failed to create stderr pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return fmt.Errorf("failed to start blender: %w", err)
}
// Track process
ctx.Processes.Track(ctx.TaskID, cmd)
defer ctx.Processes.Untrack(ctx.TaskID)
// Stream stdout
stdoutDone := make(chan bool)
go func() {
defer close(stdoutDone)
scanner := bufio.NewScanner(stdoutPipe)
for scanner.Scan() {
line := scanner.Text()
if line != "" {
shouldFilter, logLevel := blender.FilterLog(line)
if !shouldFilter {
ctx.Log(logLevel, line)
}
}
}
}()
// Stream stderr
stderrDone := make(chan bool)
go func() {
defer close(stderrDone)
scanner := bufio.NewScanner(stderrPipe)
for scanner.Scan() {
line := scanner.Text()
if line != "" {
shouldFilter, logLevel := blender.FilterLog(line)
if !shouldFilter {
if logLevel == types.LogLevelInfo {
logLevel = types.LogLevelWarn
}
ctx.Log(logLevel, line)
}
}
}
}()
// Wait for completion
err = cmd.Wait()
<-stdoutDone
<-stderrDone
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
if exitErr.ExitCode() == 137 {
return errors.New("Blender was killed due to excessive memory usage (OOM)")
}
}
return fmt.Errorf("blender failed: %w", err)
}
return nil
}
func (p *RenderProcessor) findOutputFile(ctx *Context, outputDir, renderFormat string) (string, error) {
entries, err := os.ReadDir(outputDir)
if err != nil {
return "", fmt.Errorf("failed to read output directory: %w", err)
}
ctx.Info("Checking output directory for files...")
// Try exact match first
expectedFile := filepath.Join(outputDir, fmt.Sprintf("frame_%04d.%s", ctx.Frame, strings.ToLower(renderFormat)))
if _, err := os.Stat(expectedFile); err == nil {
ctx.Info(fmt.Sprintf("Found output file: %s", filepath.Base(expectedFile)))
return expectedFile, nil
}
// Try without zero padding
altFile := filepath.Join(outputDir, fmt.Sprintf("frame_%d.%s", ctx.Frame, strings.ToLower(renderFormat)))
if _, err := os.Stat(altFile); err == nil {
ctx.Info(fmt.Sprintf("Found output file: %s", filepath.Base(altFile)))
return altFile, nil
}
// Try just frame number
altFile2 := filepath.Join(outputDir, fmt.Sprintf("%04d.%s", ctx.Frame, strings.ToLower(renderFormat)))
if _, err := os.Stat(altFile2); err == nil {
ctx.Info(fmt.Sprintf("Found output file: %s", filepath.Base(altFile2)))
return altFile2, nil
}
// Search through all files
for _, entry := range entries {
if !entry.IsDir() {
fileName := entry.Name()
if strings.Contains(fileName, "%04d") || strings.Contains(fileName, "%d") {
ctx.Warn(fmt.Sprintf("Skipping file with literal pattern: %s", fileName))
continue
}
frameStr := fmt.Sprintf("%d", ctx.Frame)
frameStrPadded := fmt.Sprintf("%04d", ctx.Frame)
if strings.Contains(fileName, frameStrPadded) ||
(strings.Contains(fileName, frameStr) && strings.HasSuffix(strings.ToLower(fileName), strings.ToLower(renderFormat))) {
outputFile := filepath.Join(outputDir, fileName)
ctx.Info(fmt.Sprintf("Found output file: %s", fileName))
return outputFile, nil
}
}
}
// Not found
fileList := []string{}
for _, entry := range entries {
if !entry.IsDir() {
fileList = append(fileList, entry.Name())
}
}
return "", fmt.Errorf("output file not found: %s\nFiles in output directory: %v", expectedFile, fileList)
}

View File

@@ -0,0 +1,146 @@
package workspace
import (
"archive/tar"
"fmt"
"io"
"log"
"os"
"path/filepath"
"strings"
)
// ExtractTar extracts a tar archive from a reader to a directory.
func ExtractTar(reader io.Reader, destDir string) error {
if err := os.MkdirAll(destDir, 0755); err != nil {
return fmt.Errorf("failed to create destination directory: %w", err)
}
tarReader := tar.NewReader(reader)
for {
header, err := tarReader.Next()
if err == io.EOF {
break
}
if err != nil {
return fmt.Errorf("failed to read tar header: %w", err)
}
// Sanitize path to prevent directory traversal
targetPath := filepath.Join(destDir, header.Name)
if !strings.HasPrefix(filepath.Clean(targetPath), filepath.Clean(destDir)+string(os.PathSeparator)) {
return fmt.Errorf("invalid file path in tar: %s", header.Name)
}
switch header.Typeflag {
case tar.TypeDir:
if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
case tar.TypeReg:
if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
return fmt.Errorf("failed to create parent directory: %w", err)
}
outFile, err := os.Create(targetPath)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
if _, err := io.Copy(outFile, tarReader); err != nil {
outFile.Close()
return fmt.Errorf("failed to write file: %w", err)
}
outFile.Close()
if err := os.Chmod(targetPath, os.FileMode(header.Mode)); err != nil {
log.Printf("Warning: failed to set file permissions: %v", err)
}
}
}
return nil
}
// ExtractTarStripPrefix extracts a tar archive, stripping the top-level directory.
// Useful for Blender archives like "blender-4.2.3-linux-x64/".
func ExtractTarStripPrefix(reader io.Reader, destDir string) error {
if err := os.MkdirAll(destDir, 0755); err != nil {
return err
}
tarReader := tar.NewReader(reader)
stripPrefix := ""
for {
header, err := tarReader.Next()
if err == io.EOF {
break
}
if err != nil {
return err
}
// Determine strip prefix from first entry (e.g., "blender-4.2.3-linux-x64/")
if stripPrefix == "" {
parts := strings.SplitN(header.Name, "/", 2)
if len(parts) > 0 {
stripPrefix = parts[0] + "/"
}
}
// Strip the top-level directory
name := strings.TrimPrefix(header.Name, stripPrefix)
if name == "" {
continue
}
targetPath := filepath.Join(destDir, name)
switch header.Typeflag {
case tar.TypeDir:
if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil {
return err
}
case tar.TypeReg:
if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
return err
}
outFile, err := os.OpenFile(targetPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode))
if err != nil {
return err
}
if _, err := io.Copy(outFile, tarReader); err != nil {
outFile.Close()
return err
}
outFile.Close()
case tar.TypeSymlink:
if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
return err
}
os.Remove(targetPath) // Remove existing symlink if present
if err := os.Symlink(header.Linkname, targetPath); err != nil {
return err
}
}
}
return nil
}
// ExtractTarFile extracts a tar file to a directory.
func ExtractTarFile(tarPath, destDir string) error {
file, err := os.Open(tarPath)
if err != nil {
return fmt.Errorf("failed to open tar file: %w", err)
}
defer file.Close()
return ExtractTar(file, destDir)
}

View File

@@ -0,0 +1,217 @@
// Package workspace manages runner workspace directories.
package workspace
import (
"fmt"
"log"
"os"
"path/filepath"
"strings"
)
// Manager handles workspace directory operations.
type Manager struct {
baseDir string
runnerName string
}
// NewManager creates a new workspace manager.
func NewManager(runnerName string) *Manager {
m := &Manager{
runnerName: sanitizeName(runnerName),
}
m.init()
return m
}
func sanitizeName(name string) string {
name = strings.ReplaceAll(name, " ", "_")
name = strings.ReplaceAll(name, "/", "_")
name = strings.ReplaceAll(name, "\\", "_")
name = strings.ReplaceAll(name, ":", "_")
return name
}
func (m *Manager) init() {
// Prefer current directory if writable, otherwise use temp
baseDir := os.TempDir()
if cwd, err := os.Getwd(); err == nil {
baseDir = cwd
}
m.baseDir = filepath.Join(baseDir, "jiggablend-workspaces", m.runnerName)
if err := os.MkdirAll(m.baseDir, 0755); err != nil {
log.Printf("Warning: Failed to create workspace directory %s: %v", m.baseDir, err)
// Fallback to temp directory
m.baseDir = filepath.Join(os.TempDir(), "jiggablend-workspaces", m.runnerName)
if err := os.MkdirAll(m.baseDir, 0755); err != nil {
log.Printf("Error: Failed to create fallback workspace directory: %v", err)
// Last resort
m.baseDir = filepath.Join(os.TempDir(), "jiggablend-runner")
os.MkdirAll(m.baseDir, 0755)
}
}
log.Printf("Runner workspace initialized at: %s", m.baseDir)
}
// BaseDir returns the base workspace directory.
func (m *Manager) BaseDir() string {
return m.baseDir
}
// JobDir returns the directory for a specific job.
func (m *Manager) JobDir(jobID int64) string {
return filepath.Join(m.baseDir, fmt.Sprintf("job-%d", jobID))
}
// VideoDir returns the directory for encoding.
func (m *Manager) VideoDir(jobID int64) string {
return filepath.Join(m.baseDir, fmt.Sprintf("job-%d-video", jobID))
}
// BlenderDir returns the directory for Blender installations.
func (m *Manager) BlenderDir() string {
return filepath.Join(m.baseDir, "blender-versions")
}
// CreateJobDir creates and returns the job directory.
func (m *Manager) CreateJobDir(jobID int64) (string, error) {
dir := m.JobDir(jobID)
if err := os.MkdirAll(dir, 0755); err != nil {
return "", fmt.Errorf("failed to create job directory: %w", err)
}
return dir, nil
}
// CreateVideoDir creates and returns the encode directory.
func (m *Manager) CreateVideoDir(jobID int64) (string, error) {
dir := m.VideoDir(jobID)
if err := os.MkdirAll(dir, 0755); err != nil {
return "", fmt.Errorf("failed to create video directory: %w", err)
}
return dir, nil
}
// CleanupJobDir removes a job directory.
func (m *Manager) CleanupJobDir(jobID int64) error {
dir := m.JobDir(jobID)
return os.RemoveAll(dir)
}
// CleanupVideoDir removes an encode directory.
func (m *Manager) CleanupVideoDir(jobID int64) error {
dir := m.VideoDir(jobID)
return os.RemoveAll(dir)
}
// Cleanup removes the entire workspace directory.
func (m *Manager) Cleanup() {
if m.baseDir != "" {
log.Printf("Cleaning up workspace directory: %s", m.baseDir)
if err := os.RemoveAll(m.baseDir); err != nil {
log.Printf("Warning: Failed to remove workspace directory %s: %v", m.baseDir, err)
} else {
log.Printf("Successfully removed workspace directory: %s", m.baseDir)
}
}
// Also clean up any orphaned jiggablend directories
cleanupOrphanedWorkspaces()
}
// cleanupOrphanedWorkspaces removes any jiggablend workspace directories
// that might be left behind from previous runs or crashes.
func cleanupOrphanedWorkspaces() {
log.Printf("Cleaning up orphaned jiggablend workspace directories...")
dirsToCheck := []string{".", os.TempDir()}
for _, baseDir := range dirsToCheck {
workspaceDir := filepath.Join(baseDir, "jiggablend-workspaces")
if _, err := os.Stat(workspaceDir); err == nil {
log.Printf("Removing orphaned workspace directory: %s", workspaceDir)
if err := os.RemoveAll(workspaceDir); err != nil {
log.Printf("Warning: Failed to remove workspace directory %s: %v", workspaceDir, err)
} else {
log.Printf("Successfully removed workspace directory: %s", workspaceDir)
}
}
}
}
// FindBlendFiles finds all .blend files in a directory.
func FindBlendFiles(dir string) ([]string, error) {
var blendFiles []string
err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".blend") {
// Check it's not a Blender save file (.blend1, .blend2, etc.)
lower := strings.ToLower(info.Name())
idx := strings.LastIndex(lower, ".blend")
if idx != -1 {
suffix := lower[idx+len(".blend"):]
isSaveFile := false
if len(suffix) > 0 {
isSaveFile = true
for _, r := range suffix {
if r < '0' || r > '9' {
isSaveFile = false
break
}
}
}
if !isSaveFile {
relPath, _ := filepath.Rel(dir, path)
blendFiles = append(blendFiles, relPath)
}
}
}
return nil
})
return blendFiles, err
}
// FindFirstBlendFile finds the first .blend file in a directory.
func FindFirstBlendFile(dir string) (string, error) {
var blendFile string
err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".blend") {
lower := strings.ToLower(info.Name())
idx := strings.LastIndex(lower, ".blend")
if idx != -1 {
suffix := lower[idx+len(".blend"):]
isSaveFile := false
if len(suffix) > 0 {
isSaveFile = true
for _, r := range suffix {
if r < '0' || r > '9' {
isSaveFile = false
break
}
}
}
if !isSaveFile {
blendFile = path
return filepath.SkipAll
}
}
}
return nil
})
if err != nil {
return "", err
}
if blendFile == "" {
return "", fmt.Errorf("no .blend file found in %s", dir)
}
return blendFile, nil
}

View File

@@ -1,9 +1,11 @@
package storage
import (
"archive/tar"
"archive/zip"
"fmt"
"io"
"log"
"os"
"path/filepath"
"strings"
@@ -29,6 +31,7 @@ func (s *Storage) init() error {
s.basePath,
s.uploadsPath(),
s.outputsPath(),
s.tempPath(),
}
for _, dir := range dirs {
@@ -40,6 +43,28 @@ func (s *Storage) init() error {
return nil
}
// tempPath returns the path for temporary files
func (s *Storage) tempPath() string {
return filepath.Join(s.basePath, "temp")
}
// BasePath returns the storage base path (for cleanup tasks)
func (s *Storage) BasePath() string {
return s.basePath
}
// TempDir creates a temporary directory under the storage base path
// Returns the path to the temporary directory
func (s *Storage) TempDir(pattern string) (string, error) {
// Ensure temp directory exists
if err := os.MkdirAll(s.tempPath(), 0755); err != nil {
return "", fmt.Errorf("failed to create temp directory: %w", err)
}
// Create temp directory under storage base path
return os.MkdirTemp(s.tempPath(), pattern)
}
// uploadsPath returns the path for uploads
func (s *Storage) uploadsPath() string {
return filepath.Join(s.basePath, "uploads")
@@ -140,6 +165,13 @@ func (s *Storage) GetFileSize(filePath string) (int64, error) {
// ExtractZip extracts a ZIP file to the destination directory
// Returns a list of all extracted file paths
func (s *Storage) ExtractZip(zipPath, destDir string) ([]string, error) {
log.Printf("Extracting ZIP archive: %s -> %s", zipPath, destDir)
// Ensure destination directory exists
if err := os.MkdirAll(destDir, 0755); err != nil {
return nil, fmt.Errorf("failed to create destination directory: %w", err)
}
r, err := zip.OpenReader(zipPath)
if err != nil {
return nil, fmt.Errorf("failed to open ZIP file: %w", err)
@@ -147,12 +179,20 @@ func (s *Storage) ExtractZip(zipPath, destDir string) ([]string, error) {
defer r.Close()
var extractedFiles []string
fileCount := 0
dirCount := 0
log.Printf("ZIP contains %d entries", len(r.File))
for _, f := range r.File {
// Sanitize file path to prevent directory traversal
destPath := filepath.Join(destDir, f.Name)
if !strings.HasPrefix(destPath, filepath.Clean(destDir)+string(os.PathSeparator)) {
return nil, fmt.Errorf("invalid file path in ZIP: %s", f.Name)
cleanDestPath := filepath.Clean(destPath)
cleanDestDir := filepath.Clean(destDir)
if !strings.HasPrefix(cleanDestPath, cleanDestDir+string(os.PathSeparator)) && cleanDestPath != cleanDestDir {
log.Printf("ERROR: Invalid file path in ZIP - target: %s, destDir: %s", cleanDestPath, cleanDestDir)
return nil, fmt.Errorf("invalid file path in ZIP: %s (target: %s, destDir: %s)", f.Name, cleanDestPath, cleanDestDir)
}
// Create directory structure
@@ -160,6 +200,7 @@ func (s *Storage) ExtractZip(zipPath, destDir string) ([]string, error) {
if err := os.MkdirAll(destPath, 0755); err != nil {
return nil, fmt.Errorf("failed to create directory: %w", err)
}
dirCount++
continue
}
@@ -189,8 +230,381 @@ func (s *Storage) ExtractZip(zipPath, destDir string) ([]string, error) {
}
extractedFiles = append(extractedFiles, destPath)
fileCount++
}
log.Printf("ZIP extraction complete: %d files, %d directories extracted to %s", fileCount, dirCount, destDir)
return extractedFiles, nil
}
// findCommonPrefix finds the common leading directory prefix if all paths share the same first-level directory
// Returns the prefix to strip (with trailing slash) or empty string if no common prefix
func findCommonPrefix(relPaths []string) string {
if len(relPaths) == 0 {
return ""
}
// Get the first path component of each path
firstComponents := make([]string, 0, len(relPaths))
for _, path := range relPaths {
parts := strings.Split(filepath.ToSlash(path), "/")
if len(parts) > 0 && parts[0] != "" {
firstComponents = append(firstComponents, parts[0])
} else {
// If any path is at root level, no common prefix
return ""
}
}
// Check if all first components are the same
if len(firstComponents) == 0 {
return ""
}
commonFirst := firstComponents[0]
for _, comp := range firstComponents {
if comp != commonFirst {
// Not all paths share the same first directory
return ""
}
}
// All paths share the same first directory - return it with trailing slash
return commonFirst + "/"
}
// isBlenderSaveFile checks if a filename is a Blender save file (.blend1, .blend2, etc.)
// Returns true for files like "file.blend1", "file.blend2", but false for "file.blend"
func isBlenderSaveFile(filename string) bool {
lower := strings.ToLower(filename)
// Check if it ends with .blend followed by one or more digits
// Pattern: *.blend[digits]
if !strings.HasSuffix(lower, ".blend") {
// Doesn't end with .blend, check if it ends with .blend + digits
idx := strings.LastIndex(lower, ".blend")
if idx == -1 {
return false
}
// Check if there are digits after .blend
suffix := lower[idx+len(".blend"):]
if len(suffix) == 0 {
return false
}
// All remaining characters must be digits
for _, r := range suffix {
if r < '0' || r > '9' {
return false
}
}
return true
}
// Ends with .blend exactly - this is a regular blend file, not a save file
return false
}
// CreateJobContext creates a tar archive containing all job input files
// Filters out Blender save files (.blend1, .blend2, etc.)
// Uses temporary directories and streaming to handle large files efficiently
func (s *Storage) CreateJobContext(jobID int64) (string, error) {
jobPath := s.JobPath(jobID)
contextPath := filepath.Join(jobPath, "context.tar")
// Create temporary directory for staging
tmpDir, err := os.MkdirTemp("", "jiggablend-context-*")
if err != nil {
return "", fmt.Errorf("failed to create temporary directory: %w", err)
}
defer os.RemoveAll(tmpDir)
// Collect all files from job directory, excluding the context file itself and Blender save files
var filesToInclude []string
err = filepath.Walk(jobPath, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
// Skip directories
if info.IsDir() {
return nil
}
// Skip the context file itself if it exists
if path == contextPath {
return nil
}
// Skip Blender save files
if isBlenderSaveFile(info.Name()) {
return nil
}
// Get relative path from job directory
relPath, err := filepath.Rel(jobPath, path)
if err != nil {
return err
}
// Sanitize path - ensure it doesn't escape the job directory
cleanRelPath := filepath.Clean(relPath)
if strings.HasPrefix(cleanRelPath, "..") {
return fmt.Errorf("invalid file path: %s", relPath)
}
filesToInclude = append(filesToInclude, path)
return nil
})
if err != nil {
return "", fmt.Errorf("failed to walk job directory: %w", err)
}
if len(filesToInclude) == 0 {
return "", fmt.Errorf("no files found to include in context")
}
// Create the tar file using streaming
contextFile, err := os.Create(contextPath)
if err != nil {
return "", fmt.Errorf("failed to create context file: %w", err)
}
defer contextFile.Close()
tarWriter := tar.NewWriter(contextFile)
defer tarWriter.Close()
// Add each file to the tar archive
for _, filePath := range filesToInclude {
file, err := os.Open(filePath)
if err != nil {
return "", fmt.Errorf("failed to open file %s: %w", filePath, err)
}
// Use a function closure to ensure file is closed even on error
err = func() error {
defer file.Close()
info, err := file.Stat()
if err != nil {
return fmt.Errorf("failed to stat file %s: %w", filePath, err)
}
// Get relative path for tar header
relPath, err := filepath.Rel(jobPath, filePath)
if err != nil {
return fmt.Errorf("failed to get relative path for %s: %w", filePath, err)
}
// Normalize path separators for tar (use forward slashes)
tarPath := filepath.ToSlash(relPath)
// Create tar header
header, err := tar.FileInfoHeader(info, "")
if err != nil {
return fmt.Errorf("failed to create tar header for %s: %w", filePath, err)
}
header.Name = tarPath
// Write header
if err := tarWriter.WriteHeader(header); err != nil {
return fmt.Errorf("failed to write tar header for %s: %w", filePath, err)
}
// Copy file contents using streaming
if _, err := io.Copy(tarWriter, file); err != nil {
return fmt.Errorf("failed to write file %s to tar: %w", filePath, err)
}
return nil
}()
if err != nil {
return "", err
}
}
// Ensure all data is flushed
if err := tarWriter.Close(); err != nil {
return "", fmt.Errorf("failed to close tar writer: %w", err)
}
if err := contextFile.Close(); err != nil {
return "", fmt.Errorf("failed to close context file: %w", err)
}
return contextPath, nil
}
// CreateJobContextFromDir creates a context archive (tar) from files in a source directory
// This is used during upload to immediately create the context archive as the primary artifact
// excludeFiles is a set of relative paths (from sourceDir) to exclude from the context
func (s *Storage) CreateJobContextFromDir(sourceDir string, jobID int64, excludeFiles ...string) (string, error) {
jobPath := s.JobPath(jobID)
contextPath := filepath.Join(jobPath, "context.tar")
// Ensure job directory exists
if err := os.MkdirAll(jobPath, 0755); err != nil {
return "", fmt.Errorf("failed to create job directory: %w", err)
}
// Build set of files to exclude (normalize paths)
excludeSet := make(map[string]bool)
for _, excludeFile := range excludeFiles {
// Normalize the exclude path
excludePath := filepath.Clean(excludeFile)
excludeSet[excludePath] = true
// Also add with forward slash for cross-platform compatibility
excludeSet[filepath.ToSlash(excludePath)] = true
}
// Collect all files from source directory, excluding Blender save files and excluded files
var filesToInclude []string
err := filepath.Walk(sourceDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
// Skip directories
if info.IsDir() {
return nil
}
// Skip Blender save files
if isBlenderSaveFile(info.Name()) {
return nil
}
// Get relative path from source directory
relPath, err := filepath.Rel(sourceDir, path)
if err != nil {
return err
}
// Sanitize path - ensure it doesn't escape the source directory
cleanRelPath := filepath.Clean(relPath)
if strings.HasPrefix(cleanRelPath, "..") {
return fmt.Errorf("invalid file path: %s", relPath)
}
// Check if this file should be excluded
if excludeSet[cleanRelPath] || excludeSet[filepath.ToSlash(cleanRelPath)] {
return nil
}
filesToInclude = append(filesToInclude, path)
return nil
})
if err != nil {
return "", fmt.Errorf("failed to walk source directory: %w", err)
}
if len(filesToInclude) == 0 {
return "", fmt.Errorf("no files found to include in context archive")
}
// Collect relative paths to find common prefix
relPaths := make([]string, 0, len(filesToInclude))
for _, filePath := range filesToInclude {
relPath, err := filepath.Rel(sourceDir, filePath)
if err != nil {
return "", fmt.Errorf("failed to get relative path for %s: %w", filePath, err)
}
relPaths = append(relPaths, relPath)
}
// Find and strip common leading directory if all files share one
commonPrefix := findCommonPrefix(relPaths)
// Validate that there's exactly one .blend file at the root level after prefix stripping
blendFilesAtRoot := 0
for _, relPath := range relPaths {
tarPath := filepath.ToSlash(relPath)
// Strip common prefix if present
if commonPrefix != "" && strings.HasPrefix(tarPath, commonPrefix) {
tarPath = strings.TrimPrefix(tarPath, commonPrefix)
}
// Check if it's a .blend file at root (no path separators after prefix stripping)
if strings.HasSuffix(strings.ToLower(tarPath), ".blend") {
// Check if it's at root level (no directory separators)
if !strings.Contains(tarPath, "/") {
blendFilesAtRoot++
}
}
}
if blendFilesAtRoot == 0 {
return "", fmt.Errorf("no .blend file found at root level in context archive - .blend files must be at the root level of the uploaded archive, not in subdirectories")
}
if blendFilesAtRoot > 1 {
return "", fmt.Errorf("multiple .blend files found at root level in context archive (found %d, expected 1)", blendFilesAtRoot)
}
// Create the tar file using streaming
contextFile, err := os.Create(contextPath)
if err != nil {
return "", fmt.Errorf("failed to create context file: %w", err)
}
defer contextFile.Close()
tarWriter := tar.NewWriter(contextFile)
defer tarWriter.Close()
// Add each file to the tar archive
for i, filePath := range filesToInclude {
file, err := os.Open(filePath)
if err != nil {
return "", fmt.Errorf("failed to open file %s: %w", filePath, err)
}
// Use a function closure to ensure file is closed even on error
err = func() error {
defer file.Close()
info, err := file.Stat()
if err != nil {
return fmt.Errorf("failed to stat file %s: %w", filePath, err)
}
// Get relative path and strip common prefix if present
relPath := relPaths[i]
tarPath := filepath.ToSlash(relPath)
// Strip common prefix if found
if commonPrefix != "" && strings.HasPrefix(tarPath, commonPrefix) {
tarPath = strings.TrimPrefix(tarPath, commonPrefix)
}
// Create tar header
header, err := tar.FileInfoHeader(info, "")
if err != nil {
return fmt.Errorf("failed to create tar header for %s: %w", filePath, err)
}
header.Name = tarPath
// Write header
if err := tarWriter.WriteHeader(header); err != nil {
return fmt.Errorf("failed to write tar header for %s: %w", filePath, err)
}
// Copy file contents using streaming
if _, err := io.Copy(tarWriter, file); err != nil {
return fmt.Errorf("failed to write file %s to tar: %w", filePath, err)
}
return nil
}()
if err != nil {
return "", err
}
}
// Ensure all data is flushed
if err := tarWriter.Close(); err != nil {
return "", fmt.Errorf("failed to close tar writer: %w", err)
}
if err := contextFile.Close(); err != nil {
return "", fmt.Errorf("failed to close context file: %w", err)
}
return contextPath, nil
}

BIN
jiggablend Executable file

Binary file not shown.

366
pkg/executils/exec.go Normal file
View File

@@ -0,0 +1,366 @@
package executils
import (
"bufio"
"errors"
"fmt"
"os"
"os/exec"
"sync"
"time"
"jiggablend/pkg/types"
)
// DefaultTracker is the global default process tracker
// Use this for processes that should be tracked globally and killed on shutdown
var DefaultTracker = NewProcessTracker()
// ProcessTracker tracks running processes for cleanup
type ProcessTracker struct {
processes sync.Map // map[int64]*exec.Cmd - tracks running processes by task ID
}
// NewProcessTracker creates a new process tracker
func NewProcessTracker() *ProcessTracker {
return &ProcessTracker{}
}
// Track registers a process for tracking
func (pt *ProcessTracker) Track(taskID int64, cmd *exec.Cmd) {
pt.processes.Store(taskID, cmd)
}
// Untrack removes a process from tracking
func (pt *ProcessTracker) Untrack(taskID int64) {
pt.processes.Delete(taskID)
}
// Get returns the command for a task ID if it exists
func (pt *ProcessTracker) Get(taskID int64) (*exec.Cmd, bool) {
if val, ok := pt.processes.Load(taskID); ok {
return val.(*exec.Cmd), true
}
return nil, false
}
// Kill kills a specific process by task ID
// Returns true if the process was found and killed
func (pt *ProcessTracker) Kill(taskID int64) bool {
cmd, ok := pt.Get(taskID)
if !ok || cmd.Process == nil {
return false
}
// Try graceful kill first (SIGINT)
if err := cmd.Process.Signal(os.Interrupt); err != nil {
// If SIGINT fails, try SIGKILL
cmd.Process.Kill()
} else {
// Give it a moment to clean up gracefully
time.Sleep(100 * time.Millisecond)
// Force kill if still running
cmd.Process.Kill()
}
pt.Untrack(taskID)
return true
}
// KillAll kills all tracked processes
// Returns the number of processes killed
func (pt *ProcessTracker) KillAll() int {
var killedCount int
pt.processes.Range(func(key, value interface{}) bool {
taskID := key.(int64)
cmd := value.(*exec.Cmd)
if cmd.Process != nil {
// Try graceful kill first (SIGINT)
if err := cmd.Process.Signal(os.Interrupt); err == nil {
// Give it a moment to clean up
time.Sleep(100 * time.Millisecond)
}
// Force kill
cmd.Process.Kill()
killedCount++
}
pt.processes.Delete(taskID)
return true
})
return killedCount
}
// Count returns the number of tracked processes
func (pt *ProcessTracker) Count() int {
count := 0
pt.processes.Range(func(key, value interface{}) bool {
count++
return true
})
return count
}
// CommandResult holds the output from a command execution
type CommandResult struct {
Stdout string
Stderr string
ExitCode int
}
// RunCommand executes a command and returns the output
// If tracker is provided, the process will be registered for tracking
// This is useful for commands where you need to capture output (like metadata extraction)
func RunCommand(
cmdPath string,
args []string,
dir string,
env []string,
taskID int64,
tracker *ProcessTracker,
) (*CommandResult, error) {
cmd := exec.Command(cmdPath, args...)
cmd.Dir = dir
if env != nil {
cmd.Env = env
}
stdoutPipe, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("failed to create stdout pipe: %w", err)
}
stderrPipe, err := cmd.StderrPipe()
if err != nil {
return nil, fmt.Errorf("failed to create stderr pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("failed to start command: %w", err)
}
// Track the process if tracker is provided
if tracker != nil {
tracker.Track(taskID, cmd)
defer tracker.Untrack(taskID)
}
// Collect stdout
var stdoutBuf, stderrBuf []byte
var stdoutErr, stderrErr error
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
stdoutBuf, stdoutErr = readAll(stdoutPipe)
}()
go func() {
defer wg.Done()
stderrBuf, stderrErr = readAll(stderrPipe)
}()
waitErr := cmd.Wait()
wg.Wait()
// Check for read errors
if stdoutErr != nil {
return nil, fmt.Errorf("failed to read stdout: %w", stdoutErr)
}
if stderrErr != nil {
return nil, fmt.Errorf("failed to read stderr: %w", stderrErr)
}
result := &CommandResult{
Stdout: string(stdoutBuf),
Stderr: string(stderrBuf),
}
if waitErr != nil {
if exitErr, ok := waitErr.(*exec.ExitError); ok {
result.ExitCode = exitErr.ExitCode()
} else {
result.ExitCode = -1
}
return result, waitErr
}
result.ExitCode = 0
return result, nil
}
// readAll reads all data from a reader
func readAll(r interface{ Read([]byte) (int, error) }) ([]byte, error) {
var buf []byte
tmp := make([]byte, 4096)
for {
n, err := r.Read(tmp)
if n > 0 {
buf = append(buf, tmp[:n]...)
}
if err != nil {
if err.Error() == "EOF" {
break
}
return buf, err
}
}
return buf, nil
}
// LogSender is a function type for sending logs
type LogSender func(taskID int, level types.LogLevel, message string, stepName string)
// LineFilter is a function that processes a line and returns whether to filter it out and the log level
type LineFilter func(line string) (shouldFilter bool, level types.LogLevel)
// RunCommandWithStreaming executes a command with streaming output and OOM detection
// If tracker is provided, the process will be registered for tracking
func RunCommandWithStreaming(
cmdPath string,
args []string,
dir string,
env []string,
taskID int,
stepName string,
logSender LogSender,
stdoutFilter LineFilter,
stderrFilter LineFilter,
oomMessage string,
tracker *ProcessTracker,
) error {
cmd := exec.Command(cmdPath, args...)
cmd.Dir = dir
cmd.Env = env
stdoutPipe, err := cmd.StdoutPipe()
if err != nil {
errMsg := fmt.Sprintf("failed to create stdout pipe: %v", err)
logSender(taskID, types.LogLevelError, errMsg, stepName)
return errors.New(errMsg)
}
stderrPipe, err := cmd.StderrPipe()
if err != nil {
errMsg := fmt.Sprintf("failed to create stderr pipe: %v", err)
logSender(taskID, types.LogLevelError, errMsg, stepName)
return errors.New(errMsg)
}
if err := cmd.Start(); err != nil {
errMsg := fmt.Sprintf("failed to start command: %v", err)
logSender(taskID, types.LogLevelError, errMsg, stepName)
return errors.New(errMsg)
}
// Track the process if tracker is provided
if tracker != nil {
tracker.Track(int64(taskID), cmd)
defer tracker.Untrack(int64(taskID))
}
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
scanner := bufio.NewScanner(stdoutPipe)
for scanner.Scan() {
line := scanner.Text()
if line != "" {
shouldFilter, level := stdoutFilter(line)
if !shouldFilter {
logSender(taskID, level, line, stepName)
}
}
}
}()
go func() {
defer wg.Done()
scanner := bufio.NewScanner(stderrPipe)
for scanner.Scan() {
line := scanner.Text()
if line != "" {
shouldFilter, level := stderrFilter(line)
if !shouldFilter {
logSender(taskID, level, line, stepName)
}
}
}
}()
err = cmd.Wait()
wg.Wait()
if err != nil {
var errMsg string
if exitErr, ok := err.(*exec.ExitError); ok {
if exitErr.ExitCode() == 137 {
errMsg = oomMessage
} else {
errMsg = fmt.Sprintf("command failed: %v", err)
}
} else {
errMsg = fmt.Sprintf("command failed: %v", err)
}
logSender(taskID, types.LogLevelError, errMsg, stepName)
return errors.New(errMsg)
}
return nil
}
// ============================================================================
// Helper functions using DefaultTracker
// ============================================================================
// Run executes a command using the default tracker and returns the output
// This is a convenience wrapper around RunCommand that uses DefaultTracker
func Run(cmdPath string, args []string, dir string, env []string, taskID int64) (*CommandResult, error) {
return RunCommand(cmdPath, args, dir, env, taskID, DefaultTracker)
}
// RunStreaming executes a command with streaming output using the default tracker
// This is a convenience wrapper around RunCommandWithStreaming that uses DefaultTracker
func RunStreaming(
cmdPath string,
args []string,
dir string,
env []string,
taskID int,
stepName string,
logSender LogSender,
stdoutFilter LineFilter,
stderrFilter LineFilter,
oomMessage string,
) error {
return RunCommandWithStreaming(cmdPath, args, dir, env, taskID, stepName, logSender, stdoutFilter, stderrFilter, oomMessage, DefaultTracker)
}
// KillAll kills all processes tracked by the default tracker
// Returns the number of processes killed
func KillAll() int {
return DefaultTracker.KillAll()
}
// Kill kills a specific process by task ID using the default tracker
// Returns true if the process was found and killed
func Kill(taskID int64) bool {
return DefaultTracker.Kill(taskID)
}
// Track registers a process with the default tracker
func Track(taskID int64, cmd *exec.Cmd) {
DefaultTracker.Track(taskID, cmd)
}
// Untrack removes a process from the default tracker
func Untrack(taskID int64) {
DefaultTracker.Untrack(taskID)
}
// GetTrackedCount returns the number of processes tracked by the default tracker
func GetTrackedCount() int {
return DefaultTracker.Count()
}

13
pkg/scripts/scripts.go Normal file
View File

@@ -0,0 +1,13 @@
package scripts
import _ "embed"
//go:embed scripts/extract_metadata.py
var ExtractMetadata string
//go:embed scripts/unhide_objects.py
var UnhideObjects string
//go:embed scripts/render_blender.py.template
var RenderBlenderTemplate string

View File

@@ -0,0 +1,370 @@
import bpy
import json
import sys
# Make all file paths relative to the blend file location FIRST
# This must be done immediately after file load, before any other operations
# to prevent Blender from trying to access external files with absolute paths
try:
bpy.ops.file.make_paths_relative()
print("Made all file paths relative to blend file")
except Exception as e:
print(f"Warning: Could not make paths relative: {e}")
# Check for missing addons that the blend file requires
# Blender marks missing addons with "_missing" suffix in preferences
missing_files_info = {
"checked": False,
"has_missing": False,
"missing_files": [],
"missing_addons": []
}
try:
missing = []
for mod in bpy.context.preferences.addons:
if mod.module.endswith("_missing"):
missing.append(mod.module.rsplit("_", 1)[0])
missing_files_info["checked"] = True
if missing:
missing_files_info["has_missing"] = True
missing_files_info["missing_addons"] = missing
print("Missing add-ons required by this .blend:")
for name in missing:
print(" -", name)
else:
print("No missing add-ons detected file is headless-safe")
except Exception as e:
print(f"Warning: Could not check for missing addons: {e}")
missing_files_info["error"] = str(e)
# Get scene
scene = bpy.context.scene
# Extract frame range from scene settings
frame_start = scene.frame_start
frame_end = scene.frame_end
# Check for negative frames (not supported)
has_negative_start = frame_start < 0
has_negative_end = frame_end < 0
# Also check for actual animation range (keyframes)
# Find the earliest and latest keyframes across all objects
animation_start = None
animation_end = None
for obj in scene.objects:
if obj.animation_data and obj.animation_data.action:
action = obj.animation_data.action
# Check if action has fcurves attribute (varies by Blender version/context)
try:
fcurves = action.fcurves if hasattr(action, 'fcurves') else None
if fcurves:
for fcurve in fcurves:
if fcurve.keyframe_points:
for keyframe in fcurve.keyframe_points:
frame = int(keyframe.co[0])
if animation_start is None or frame < animation_start:
animation_start = frame
if animation_end is None or frame > animation_end:
animation_end = frame
except (AttributeError, TypeError) as e:
# Action doesn't have fcurves or fcurves is not iterable - skip this object
pass
# Use animation range if available, otherwise use scene frame range
# If scene range seems wrong (start == end), prefer animation range
if animation_start is not None and animation_end is not None:
if frame_start == frame_end or (animation_start < frame_start or animation_end > frame_end):
# Use animation range if scene range is invalid or animation extends beyond it
frame_start = animation_start
frame_end = animation_end
# Check for negative frames (not supported)
has_negative_start = frame_start < 0
has_negative_end = frame_end < 0
has_negative_animation = (animation_start is not None and animation_start < 0) or (animation_end is not None and animation_end < 0)
# Extract render settings
render = scene.render
resolution_x = render.resolution_x
resolution_y = render.resolution_y
frame_rate = render.fps / render.fps_base if render.fps_base != 0 else render.fps
engine = scene.render.engine.upper()
# Determine output format from file format
output_format = render.image_settings.file_format
# Extract engine-specific settings
engine_settings = {}
if engine == 'CYCLES':
cycles = scene.cycles
# Get denoiser settings - in Blender 3.0+ it's on the view layer
denoiser = 'OPENIMAGEDENOISE' # Default
denoising_use_gpu = False
denoising_input_passes = 'RGB_ALBEDO_NORMAL' # Default: Albedo and Normal
denoising_prefilter = 'ACCURATE' # Default
denoising_quality = 'HIGH' # Default (for OpenImageDenoise)
try:
view_layer = bpy.context.view_layer
if hasattr(view_layer, 'cycles'):
vl_cycles = view_layer.cycles
denoiser = getattr(vl_cycles, 'denoiser', 'OPENIMAGEDENOISE')
denoising_use_gpu = getattr(vl_cycles, 'denoising_use_gpu', False)
denoising_input_passes = getattr(vl_cycles, 'denoising_input_passes', 'RGB_ALBEDO_NORMAL')
denoising_prefilter = getattr(vl_cycles, 'denoising_prefilter', 'ACCURATE')
# Quality is only for OpenImageDenoise in Blender 4.0+
denoising_quality = getattr(vl_cycles, 'denoising_quality', 'HIGH')
except:
pass
engine_settings = {
# Sampling settings
"samples": getattr(cycles, 'samples', 4096), # Max Samples
"adaptive_min_samples": getattr(cycles, 'adaptive_min_samples', 0), # Min Samples
"use_adaptive_sampling": getattr(cycles, 'use_adaptive_sampling', True), # Noise Threshold enabled
"adaptive_threshold": getattr(cycles, 'adaptive_threshold', 0.01), # Noise Threshold value
"time_limit": getattr(cycles, 'time_limit', 0.0), # Time Limit (0 = disabled)
# Denoising settings
"use_denoising": getattr(cycles, 'use_denoising', False),
"denoiser": denoiser,
"denoising_use_gpu": denoising_use_gpu,
"denoising_input_passes": denoising_input_passes,
"denoising_prefilter": denoising_prefilter,
"denoising_quality": denoising_quality,
# Path Guiding settings
"use_guiding": getattr(cycles, 'use_guiding', False),
"guiding_training_samples": getattr(cycles, 'guiding_training_samples', 128),
"use_surface_guiding": getattr(cycles, 'use_surface_guiding', True),
"use_volume_guiding": getattr(cycles, 'use_volume_guiding', True),
# Lights settings
"use_light_tree": getattr(cycles, 'use_light_tree', True),
"light_sampling_threshold": getattr(cycles, 'light_sampling_threshold', 0.01),
# Device
"device": getattr(cycles, 'device', 'CPU'),
# Advanced/Seed settings
"seed": getattr(cycles, 'seed', 0),
"use_animated_seed": getattr(cycles, 'use_animated_seed', False),
"sampling_pattern": getattr(cycles, 'sampling_pattern', 'AUTOMATIC'),
"scrambling_distance": getattr(cycles, 'scrambling_distance', 1.0),
"auto_scrambling_distance_multiplier": getattr(cycles, 'auto_scrambling_distance_multiplier', 1.0),
"preview_scrambling_distance": getattr(cycles, 'preview_scrambling_distance', False),
"min_light_bounces": getattr(cycles, 'min_light_bounces', 0),
"min_transparent_bounces": getattr(cycles, 'min_transparent_bounces', 0),
# Clamping
"sample_clamp_direct": getattr(cycles, 'sample_clamp_direct', 0.0),
"sample_clamp_indirect": getattr(cycles, 'sample_clamp_indirect', 0.0),
# Light Paths / Bounces
"max_bounces": getattr(cycles, 'max_bounces', 12),
"diffuse_bounces": getattr(cycles, 'diffuse_bounces', 4),
"glossy_bounces": getattr(cycles, 'glossy_bounces', 4),
"transmission_bounces": getattr(cycles, 'transmission_bounces', 12),
"volume_bounces": getattr(cycles, 'volume_bounces', 0),
"transparent_max_bounces": getattr(cycles, 'transparent_max_bounces', 8),
# Caustics
"caustics_reflective": getattr(cycles, 'caustics_reflective', False),
"caustics_refractive": getattr(cycles, 'caustics_refractive', False),
"blur_glossy": getattr(cycles, 'blur_glossy', 0.0), # Filter Glossy
# Fast GI Approximation
"use_fast_gi": getattr(cycles, 'use_fast_gi', False),
"fast_gi_method": getattr(cycles, 'fast_gi_method', 'REPLACE'), # REPLACE or ADD
"ao_bounces": getattr(cycles, 'ao_bounces', 1), # Viewport bounces
"ao_bounces_render": getattr(cycles, 'ao_bounces_render', 1), # Render bounces
# Volumes
"volume_step_rate": getattr(cycles, 'volume_step_rate', 1.0),
"volume_preview_step_rate": getattr(cycles, 'volume_preview_step_rate', 1.0),
"volume_max_steps": getattr(cycles, 'volume_max_steps', 1024),
# Film
"film_exposure": getattr(cycles, 'film_exposure', 1.0),
"film_transparent": getattr(cycles, 'film_transparent', False),
"film_transparent_glass": getattr(cycles, 'film_transparent_glass', False),
"film_transparent_roughness": getattr(cycles, 'film_transparent_roughness', 0.1),
"filter_type": getattr(cycles, 'filter_type', 'BLACKMAN_HARRIS'), # BOX, GAUSSIAN, BLACKMAN_HARRIS
"filter_width": getattr(cycles, 'filter_width', 1.5),
"pixel_filter_type": getattr(cycles, 'pixel_filter_type', 'BLACKMAN_HARRIS'),
# Performance
"use_auto_tile": getattr(cycles, 'use_auto_tile', True),
"tile_size": getattr(cycles, 'tile_size', 2048),
"use_persistent_data": getattr(cycles, 'use_persistent_data', False),
# Hair/Curves
"use_hair": getattr(cycles, 'use_hair', True),
"hair_subdivisions": getattr(cycles, 'hair_subdivisions', 2),
"hair_shape": getattr(cycles, 'hair_shape', 'THICK'), # ROUND, RIBBONS, THICK
# Simplify (from scene.render)
"use_simplify": getattr(scene.render, 'use_simplify', False),
"simplify_subdivision_render": getattr(scene.render, 'simplify_subdivision_render', 6),
"simplify_child_particles_render": getattr(scene.render, 'simplify_child_particles_render', 1.0),
# Other
"use_light_linking": getattr(cycles, 'use_light_linking', False),
"use_layer_samples": getattr(cycles, 'use_layer_samples', False),
}
elif engine == 'EEVEE' or engine == 'EEVEE_NEXT':
# Treat EEVEE_NEXT as EEVEE (modern Blender uses EEVEE for what was EEVEE_NEXT)
eevee = scene.eevee
engine_settings = {
# Sampling
"taa_render_samples": getattr(eevee, 'taa_render_samples', 64),
"taa_samples": getattr(eevee, 'taa_samples', 16), # Viewport samples
"use_taa_reprojection": getattr(eevee, 'use_taa_reprojection', True),
# Clamping
"clamp_surface_direct": getattr(eevee, 'clamp_surface_direct', 0.0),
"clamp_surface_indirect": getattr(eevee, 'clamp_surface_indirect', 0.0),
"clamp_volume_direct": getattr(eevee, 'clamp_volume_direct', 0.0),
"clamp_volume_indirect": getattr(eevee, 'clamp_volume_indirect', 0.0),
# Shadows
"shadow_cube_size": getattr(eevee, 'shadow_cube_size', '512'),
"shadow_cascade_size": getattr(eevee, 'shadow_cascade_size', '1024'),
"use_shadow_high_bitdepth": getattr(eevee, 'use_shadow_high_bitdepth', False),
"use_soft_shadows": getattr(eevee, 'use_soft_shadows', True),
"light_threshold": getattr(eevee, 'light_threshold', 0.01),
# Raytracing (EEVEE Next / modern EEVEE)
"use_raytracing": getattr(eevee, 'use_raytracing', False),
"ray_tracing_method": getattr(eevee, 'ray_tracing_method', 'SCREEN'), # SCREEN or PROBE
"ray_tracing_options_trace_max_roughness": getattr(eevee, 'ray_tracing_options', {}).get('trace_max_roughness', 0.5) if hasattr(getattr(eevee, 'ray_tracing_options', None), 'get') else 0.5,
# Screen Space Reflections (legacy/fallback)
"use_ssr": getattr(eevee, 'use_ssr', False),
"use_ssr_refraction": getattr(eevee, 'use_ssr_refraction', False),
"use_ssr_halfres": getattr(eevee, 'use_ssr_halfres', True),
"ssr_quality": getattr(eevee, 'ssr_quality', 0.25),
"ssr_max_roughness": getattr(eevee, 'ssr_max_roughness', 0.5),
"ssr_thickness": getattr(eevee, 'ssr_thickness', 0.2),
"ssr_border_fade": getattr(eevee, 'ssr_border_fade', 0.075),
"ssr_firefly_fac": getattr(eevee, 'ssr_firefly_fac', 10.0),
# Ambient Occlusion
"use_gtao": getattr(eevee, 'use_gtao', False),
"gtao_distance": getattr(eevee, 'gtao_distance', 0.2),
"gtao_factor": getattr(eevee, 'gtao_factor', 1.0),
"gtao_quality": getattr(eevee, 'gtao_quality', 0.25),
"use_gtao_bent_normals": getattr(eevee, 'use_gtao_bent_normals', True),
"use_gtao_bounce": getattr(eevee, 'use_gtao_bounce', True),
# Bloom
"use_bloom": getattr(eevee, 'use_bloom', False),
"bloom_threshold": getattr(eevee, 'bloom_threshold', 0.8),
"bloom_knee": getattr(eevee, 'bloom_knee', 0.5),
"bloom_radius": getattr(eevee, 'bloom_radius', 6.5),
"bloom_color": list(getattr(eevee, 'bloom_color', (1.0, 1.0, 1.0))),
"bloom_intensity": getattr(eevee, 'bloom_intensity', 0.05),
"bloom_clamp": getattr(eevee, 'bloom_clamp', 0.0),
# Depth of Field
"bokeh_max_size": getattr(eevee, 'bokeh_max_size', 100.0),
"bokeh_threshold": getattr(eevee, 'bokeh_threshold', 1.0),
"bokeh_neighbor_max": getattr(eevee, 'bokeh_neighbor_max', 10.0),
"bokeh_denoise_fac": getattr(eevee, 'bokeh_denoise_fac', 0.75),
"use_bokeh_high_quality_slight_defocus": getattr(eevee, 'use_bokeh_high_quality_slight_defocus', False),
"use_bokeh_jittered": getattr(eevee, 'use_bokeh_jittered', False),
"bokeh_overblur": getattr(eevee, 'bokeh_overblur', 5.0),
# Subsurface Scattering
"sss_samples": getattr(eevee, 'sss_samples', 7),
"sss_jitter_threshold": getattr(eevee, 'sss_jitter_threshold', 0.3),
# Volumetrics
"use_volumetric_lights": getattr(eevee, 'use_volumetric_lights', True),
"use_volumetric_shadows": getattr(eevee, 'use_volumetric_shadows', False),
"volumetric_start": getattr(eevee, 'volumetric_start', 0.1),
"volumetric_end": getattr(eevee, 'volumetric_end', 100.0),
"volumetric_tile_size": getattr(eevee, 'volumetric_tile_size', '8'),
"volumetric_samples": getattr(eevee, 'volumetric_samples', 64),
"volumetric_sample_distribution": getattr(eevee, 'volumetric_sample_distribution', 0.8),
"volumetric_ray_depth": getattr(eevee, 'volumetric_ray_depth', 16),
# Motion Blur
"use_motion_blur": getattr(eevee, 'use_motion_blur', False),
"motion_blur_position": getattr(eevee, 'motion_blur_position', 'CENTER'),
"motion_blur_shutter": getattr(eevee, 'motion_blur_shutter', 0.5),
"motion_blur_depth_scale": getattr(eevee, 'motion_blur_depth_scale', 100.0),
"motion_blur_max": getattr(eevee, 'motion_blur_max', 32),
"motion_blur_steps": getattr(eevee, 'motion_blur_steps', 1),
# Film
"use_overscan": getattr(eevee, 'use_overscan', False),
"overscan_size": getattr(eevee, 'overscan_size', 3.0),
# Indirect Lighting
"gi_diffuse_bounces": getattr(eevee, 'gi_diffuse_bounces', 3),
"gi_cubemap_resolution": getattr(eevee, 'gi_cubemap_resolution', '512'),
"gi_visibility_resolution": getattr(eevee, 'gi_visibility_resolution', '32'),
"gi_irradiance_smoothing": getattr(eevee, 'gi_irradiance_smoothing', 0.1),
"gi_glossy_clamp": getattr(eevee, 'gi_glossy_clamp', 0.0),
"gi_filter_quality": getattr(eevee, 'gi_filter_quality', 3.0),
"gi_show_irradiance": getattr(eevee, 'gi_show_irradiance', False),
"gi_show_cubemaps": getattr(eevee, 'gi_show_cubemaps', False),
"gi_auto_bake": getattr(eevee, 'gi_auto_bake', False),
# Hair/Curves
"hair_type": getattr(eevee, 'hair_type', 'STRIP'), # STRIP or STRAND
# Performance
"use_shadow_jitter_viewport": getattr(eevee, 'use_shadow_jitter_viewport', True),
# Simplify (from scene.render)
"use_simplify": getattr(scene.render, 'use_simplify', False),
"simplify_subdivision_render": getattr(scene.render, 'simplify_subdivision_render', 6),
"simplify_child_particles_render": getattr(scene.render, 'simplify_child_particles_render', 1.0),
}
else:
# For other engines, extract basic samples if available
engine_settings = {
"samples": getattr(scene, 'samples', 128) if hasattr(scene, 'samples') else 128
}
# Extract scene info
camera_count = len([obj for obj in scene.objects if obj.type == 'CAMERA'])
object_count = len(scene.objects)
material_count = len(bpy.data.materials)
# Extract Blender version info
# bpy.data.version gives the version the file was saved with
blender_version = ".".join(map(str, bpy.data.version)) if hasattr(bpy.data, 'version') else bpy.app.version_string
# Build metadata dictionary
metadata = {
"frame_start": frame_start,
"frame_end": frame_end,
"has_negative_frames": has_negative_start or has_negative_end or has_negative_animation,
"blender_version": blender_version,
"render_settings": {
"resolution_x": resolution_x,
"resolution_y": resolution_y,
"frame_rate": frame_rate,
"output_format": output_format,
"engine": engine.lower(),
"engine_settings": engine_settings
},
"scene_info": {
"camera_count": camera_count,
"object_count": object_count,
"material_count": material_count
},
"missing_files_info": missing_files_info
}
# Output as JSON
print(json.dumps(metadata))
sys.stdout.flush()

View File

@@ -0,0 +1,653 @@
import bpy
import sys
import os
import json
# Make all file paths relative to the blend file location FIRST
# This must be done immediately after file load, before any other operations
# to prevent Blender from trying to access external files with absolute paths
try:
bpy.ops.file.make_paths_relative()
print("Made all file paths relative to blend file")
except Exception as e:
print(f"Warning: Could not make paths relative: {e}")
# Auto-enable addons from blender_addons folder in context
# Supports .zip files (installed via Blender API) and already-extracted addons
blend_dir = os.path.dirname(bpy.data.filepath) if bpy.data.filepath else os.getcwd()
addons_dir = os.path.join(blend_dir, "blender_addons")
if os.path.isdir(addons_dir):
print(f"Found blender_addons folder: {addons_dir}")
for item in os.listdir(addons_dir):
item_path = os.path.join(addons_dir, item)
try:
if item.endswith('.zip'):
# Install and enable zip addon using Blender's API
bpy.ops.preferences.addon_install(filepath=item_path)
# Get module name from zip (usually the folder name inside)
import zipfile
with zipfile.ZipFile(item_path, 'r') as zf:
# Find the top-level module name
names = zf.namelist()
if names:
module_name = names[0].split('/')[0]
if module_name.endswith('.py'):
module_name = module_name[:-3]
bpy.ops.preferences.addon_enable(module=module_name)
print(f" Installed and enabled addon: {module_name}")
elif item.endswith('.py') and not item.startswith('__'):
# Single-file addon
bpy.ops.preferences.addon_install(filepath=item_path)
module_name = item[:-3]
bpy.ops.preferences.addon_enable(module=module_name)
print(f" Installed and enabled addon: {module_name}")
elif os.path.isdir(item_path) and os.path.exists(os.path.join(item_path, '__init__.py')):
# Multi-file addon directory - add to path and enable
if addons_dir not in sys.path:
sys.path.insert(0, addons_dir)
bpy.ops.preferences.addon_enable(module=item)
print(f" Enabled addon: {item}")
except Exception as e:
print(f" Error with addon {item}: {e}")
else:
print(f"No blender_addons folder found at: {addons_dir}")
{{UNHIDE_CODE}}
# Read output format from file (created by Go code)
format_file_path = {{FORMAT_FILE_PATH}}
output_format_override = None
if os.path.exists(format_file_path):
try:
with open(format_file_path, 'r') as f:
output_format_override = f.read().strip().upper()
print(f"Read output format from file: '{output_format_override}'")
except Exception as e:
print(f"Warning: Could not read output format file: {e}")
else:
print(f"Warning: Output format file does not exist: {format_file_path}")
# Read render settings from JSON file (created by Go code)
render_settings_file = {{RENDER_SETTINGS_FILE}}
render_settings_override = None
if os.path.exists(render_settings_file):
try:
with open(render_settings_file, 'r') as f:
render_settings_override = json.load(f)
print(f"Loaded render settings from job metadata")
except Exception as e:
print(f"Warning: Could not read render settings file: {e}")
# Get current scene settings (preserve blend file preferences)
scene = bpy.context.scene
current_engine = scene.render.engine
current_device = scene.cycles.device if hasattr(scene, 'cycles') and scene.cycles else None
current_output_format = scene.render.image_settings.file_format
print(f"Blend file render engine: {current_engine}")
if current_device:
print(f"Blend file device setting: {current_device}")
print(f"Blend file output format: {current_output_format}")
# Override output format if specified
# The format file always takes precedence (it's written specifically for this job)
if output_format_override:
print(f"Overriding output format from '{current_output_format}' to '{output_format_override}'")
# Map common format names to Blender's format constants
# For video formats, we render as appropriate frame format first
format_to_use = output_format_override.upper()
if format_to_use in ['EXR_264_MP4', 'EXR_AV1_MP4', 'EXR_VP9_WEBM']:
format_to_use = 'EXR' # Render as EXR for EXR video formats
format_map = {
'PNG': 'PNG',
'JPEG': 'JPEG',
'JPG': 'JPEG',
'EXR': 'OPEN_EXR',
'OPEN_EXR': 'OPEN_EXR',
'TARGA': 'TARGA',
'TIFF': 'TIFF',
'BMP': 'BMP',
}
blender_format = format_map.get(format_to_use, format_to_use)
try:
scene.render.image_settings.file_format = blender_format
print(f"Successfully set output format to: {blender_format}")
except Exception as e:
print(f"Warning: Could not set output format to {blender_format}: {e}")
print(f"Using blend file's format: {current_output_format}")
else:
print(f"Using blend file's output format: {current_output_format}")
# Apply render settings from job metadata if provided
# Note: output_format is NOT applied from render_settings_override - it's already set from format file above
if render_settings_override:
engine_override = render_settings_override.get('engine', '').upper()
engine_settings = render_settings_override.get('engine_settings', {})
# Switch engine if specified
if engine_override and engine_override != current_engine.upper():
print(f"Switching render engine from '{current_engine}' to '{engine_override}'")
try:
scene.render.engine = engine_override
current_engine = engine_override
print(f"Successfully switched to {engine_override} engine")
except Exception as e:
print(f"Warning: Could not switch engine to {engine_override}: {e}")
print(f"Using blend file's engine: {current_engine}")
# Apply engine-specific settings
if engine_settings:
if current_engine.upper() == 'CYCLES':
cycles = scene.cycles
print("Applying Cycles render settings from job metadata...")
for key, value in engine_settings.items():
try:
if hasattr(cycles, key):
setattr(cycles, key, value)
print(f" Set Cycles.{key} = {value}")
else:
print(f" Warning: Cycles has no attribute '{key}'")
except Exception as e:
print(f" Warning: Could not set Cycles.{key} = {value}: {e}")
elif current_engine.upper() in ['EEVEE', 'EEVEE_NEXT']:
eevee = scene.eevee
print("Applying EEVEE render settings from job metadata...")
for key, value in engine_settings.items():
try:
if hasattr(eevee, key):
setattr(eevee, key, value)
print(f" Set EEVEE.{key} = {value}")
else:
print(f" Warning: EEVEE has no attribute '{key}'")
except Exception as e:
print(f" Warning: Could not set EEVEE.{key} = {value}: {e}")
# Apply resolution if specified
if 'resolution_x' in render_settings_override:
try:
scene.render.resolution_x = render_settings_override['resolution_x']
print(f"Set resolution_x = {render_settings_override['resolution_x']}")
except Exception as e:
print(f"Warning: Could not set resolution_x: {e}")
if 'resolution_y' in render_settings_override:
try:
scene.render.resolution_y = render_settings_override['resolution_y']
print(f"Set resolution_y = {render_settings_override['resolution_y']}")
except Exception as e:
print(f"Warning: Could not set resolution_y: {e}")
# Only override device selection if using Cycles (other engines handle GPU differently)
if current_engine == 'CYCLES':
# Check if CPU rendering is forced
force_cpu = False
if render_settings_override and render_settings_override.get('force_cpu'):
force_cpu = render_settings_override.get('force_cpu', False)
print("Force CPU rendering is enabled - skipping GPU detection")
# Ensure Cycles addon is enabled
try:
if 'cycles' not in bpy.context.preferences.addons:
bpy.ops.preferences.addon_enable(module='cycles')
print("Enabled Cycles addon")
except Exception as e:
print(f"Warning: Could not enable Cycles addon: {e}")
# If CPU is forced, skip GPU detection and set CPU directly
if force_cpu:
scene.cycles.device = 'CPU'
print("Forced CPU rendering (skipping GPU detection)")
else:
# Access Cycles preferences
prefs = bpy.context.preferences
try:
cycles_prefs = prefs.addons['cycles'].preferences
except (KeyError, AttributeError):
try:
cycles_addon = prefs.addons.get('cycles')
if cycles_addon:
cycles_prefs = cycles_addon.preferences
else:
raise Exception("Cycles addon not found")
except Exception as e:
print(f"ERROR: Could not access Cycles preferences: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
# Check all devices and choose the best GPU type
# Device type preference order (most performant first)
device_type_preference = ['OPTIX', 'CUDA', 'HIP', 'ONEAPI', 'METAL']
gpu_available = False
best_device_type = None
best_gpu_devices = []
devices_by_type = {} # {device_type: [devices]}
seen_device_ids = set() # Track device IDs to avoid duplicates
print("Checking for GPU availability...")
# Try to get all devices - try each device type to see what's available
for device_type in device_type_preference:
try:
cycles_prefs.compute_device_type = device_type
cycles_prefs.refresh_devices()
# Get devices for this type
devices = None
if hasattr(cycles_prefs, 'devices'):
try:
devices_prop = cycles_prefs.devices
if devices_prop:
devices = list(devices_prop) if hasattr(devices_prop, '__iter__') else [devices_prop]
except Exception as e:
pass
if not devices or len(devices) == 0:
try:
devices = cycles_prefs.get_devices()
except Exception as e:
pass
if devices and len(devices) > 0:
# Categorize devices by their type attribute, avoiding duplicates
for device in devices:
if hasattr(device, 'type'):
device_type_str = str(device.type).upper()
device_id = getattr(device, 'id', None)
# Use device ID to avoid duplicates (same device appears when checking different compute_device_types)
if device_id and device_id in seen_device_ids:
continue
if device_id:
seen_device_ids.add(device_id)
if device_type_str not in devices_by_type:
devices_by_type[device_type_str] = []
devices_by_type[device_type_str].append(device)
except (ValueError, AttributeError, KeyError, TypeError):
# Device type not supported, continue
continue
except Exception as e:
# Other errors - log but continue
print(f" Error checking {device_type}: {e}")
continue
# Print what we found
print(f"Found devices by type: {list(devices_by_type.keys())}")
for dev_type, dev_list in devices_by_type.items():
print(f" {dev_type}: {len(dev_list)} device(s)")
for device in dev_list:
device_name = getattr(device, 'name', 'Unknown')
print(f" - {device_name}")
# Choose the best GPU type based on preference
for preferred_type in device_type_preference:
if preferred_type in devices_by_type:
gpu_devices = [d for d in devices_by_type[preferred_type] if preferred_type in ['CUDA', 'OPENCL', 'OPTIX', 'HIP', 'METAL', 'ONEAPI']]
if gpu_devices:
best_device_type = preferred_type
best_gpu_devices = [(d, preferred_type) for d in gpu_devices]
print(f"Selected {preferred_type} as best GPU type with {len(gpu_devices)} device(s)")
break
# Second pass: Enable the best GPU we found
if best_device_type and best_gpu_devices:
print(f"\nEnabling GPU devices for {best_device_type}...")
try:
# Set the device type again
cycles_prefs.compute_device_type = best_device_type
cycles_prefs.refresh_devices()
# First, disable all CPU devices to ensure only GPU is used
print(f" Disabling CPU devices...")
all_devices = cycles_prefs.devices if hasattr(cycles_prefs, 'devices') else cycles_prefs.get_devices()
if all_devices:
for device in all_devices:
if hasattr(device, 'type') and str(device.type).upper() == 'CPU':
try:
device.use = False
device_name = getattr(device, 'name', 'Unknown')
print(f" Disabled CPU: {device_name}")
except Exception as e:
print(f" Warning: Could not disable CPU device {getattr(device, 'name', 'Unknown')}: {e}")
# Enable all GPU devices
enabled_count = 0
for device, device_type in best_gpu_devices:
try:
device.use = True
enabled_count += 1
device_name = getattr(device, 'name', 'Unknown')
print(f" Enabled: {device_name}")
except Exception as e:
print(f" Warning: Could not enable device {getattr(device, 'name', 'Unknown')}: {e}")
# Enable ray tracing acceleration for supported device types
try:
if best_device_type == 'HIP':
# HIPRT (HIP Ray Tracing) for AMD GPUs
if hasattr(cycles_prefs, 'use_hiprt'):
cycles_prefs.use_hiprt = True
print(f" Enabled HIPRT (HIP Ray Tracing) for faster rendering")
elif hasattr(scene.cycles, 'use_hiprt'):
scene.cycles.use_hiprt = True
print(f" Enabled HIPRT (HIP Ray Tracing) for faster rendering")
else:
print(f" HIPRT not available (requires Blender 4.0+)")
elif best_device_type == 'OPTIX':
# OptiX is already enabled when using OPTIX device type
# But we can check if there are any OptiX-specific settings
if hasattr(scene.cycles, 'use_optix_denoising'):
scene.cycles.use_optix_denoising = True
print(f" Enabled OptiX denoising")
print(f" OptiX ray tracing is active (using OPTIX device type)")
elif best_device_type == 'CUDA':
# CUDA can use OptiX if available, but it's usually automatic
# Check if we can prefer OptiX over CUDA
if hasattr(scene.cycles, 'use_optix_denoising'):
scene.cycles.use_optix_denoising = True
print(f" Enabled OptiX denoising (if OptiX available)")
print(f" CUDA ray tracing active")
elif best_device_type == 'METAL':
# MetalRT for Apple Silicon (if available)
if hasattr(scene.cycles, 'use_metalrt'):
scene.cycles.use_metalrt = True
print(f" Enabled MetalRT (Metal Ray Tracing) for faster rendering")
elif hasattr(cycles_prefs, 'use_metalrt'):
cycles_prefs.use_metalrt = True
print(f" Enabled MetalRT (Metal Ray Tracing) for faster rendering")
else:
print(f" MetalRT not available")
elif best_device_type == 'ONEAPI':
# Intel oneAPI - Embree might be available
if hasattr(scene.cycles, 'use_embree'):
scene.cycles.use_embree = True
print(f" Enabled Embree for faster CPU ray tracing")
print(f" oneAPI ray tracing active")
except Exception as e:
print(f" Could not enable ray tracing acceleration: {e}")
print(f"SUCCESS: Enabled {enabled_count} GPU device(s) for {best_device_type}")
gpu_available = True
except Exception as e:
print(f"ERROR: Failed to enable GPU devices: {e}")
import traceback
traceback.print_exc()
# Set device based on availability (prefer GPU, fallback to CPU)
if gpu_available:
scene.cycles.device = 'GPU'
print(f"Using GPU for rendering (blend file had: {current_device})")
# Auto-enable GPU denoising when using GPU (OpenImageDenoise supports all GPUs)
try:
view_layer = bpy.context.view_layer
if hasattr(view_layer, 'cycles') and hasattr(view_layer.cycles, 'denoising_use_gpu'):
view_layer.cycles.denoising_use_gpu = True
print("Auto-enabled GPU denoising (OpenImageDenoise)")
except Exception as e:
print(f"Could not auto-enable GPU denoising: {e}")
else:
scene.cycles.device = 'CPU'
print(f"GPU not available, using CPU for rendering (blend file had: {current_device})")
# Ensure GPU denoising is disabled when using CPU
try:
view_layer = bpy.context.view_layer
if hasattr(view_layer, 'cycles') and hasattr(view_layer.cycles, 'denoising_use_gpu'):
view_layer.cycles.denoising_use_gpu = False
print("Using CPU denoising")
except Exception as e:
pass
# Verify device setting
if current_engine == 'CYCLES':
final_device = scene.cycles.device
print(f"Final Cycles device: {final_device}")
else:
# For other engines (EEVEE, etc.), respect blend file settings
print(f"Using {current_engine} engine - respecting blend file settings")
# Enable GPU acceleration for EEVEE viewport rendering (if using EEVEE)
if current_engine == 'EEVEE' or current_engine == 'EEVEE_NEXT':
try:
if hasattr(bpy.context.preferences.system, 'gpu_backend'):
bpy.context.preferences.system.gpu_backend = 'OPENGL'
print("Enabled OpenGL GPU backend for EEVEE")
except Exception as e:
print(f"Could not set EEVEE GPU backend: {e}")
# Enable GPU acceleration for compositing (if compositing is enabled)
try:
if scene.use_nodes and hasattr(scene, 'node_tree') and scene.node_tree:
if hasattr(scene.node_tree, 'use_gpu_compositing'):
scene.node_tree.use_gpu_compositing = True
print("Enabled GPU compositing")
except Exception as e:
print(f"Could not enable GPU compositing: {e}")
# CRITICAL: Initialize headless rendering to prevent black images
# This ensures the render engine is properly initialized before rendering
print("Initializing headless rendering context...")
try:
# Ensure world exists and has proper settings
if not scene.world:
# Create a default world if none exists
world = bpy.data.worlds.new("World")
scene.world = world
print("Created default world")
# Ensure world has a background shader (not just black)
if scene.world:
# Enable nodes if not already enabled
if not scene.world.use_nodes:
scene.world.use_nodes = True
print("Enabled world nodes")
world_nodes = scene.world.node_tree
if world_nodes:
# Find or create background shader
bg_shader = None
for node in world_nodes.nodes:
if node.type == 'BACKGROUND':
bg_shader = node
break
if not bg_shader:
bg_shader = world_nodes.nodes.new(type='ShaderNodeBackground')
# Connect to output
output = world_nodes.nodes.get('World Output')
if not output:
output = world_nodes.nodes.new(type='ShaderNodeOutputWorld')
output.name = 'World Output'
if output and bg_shader:
# Connect background to surface input
if 'Surface' in output.inputs and 'Background' in bg_shader.outputs:
world_nodes.links.new(bg_shader.outputs['Background'], output.inputs['Surface'])
print("Created background shader for world")
# Ensure background has some color (not pure black)
if bg_shader:
# Only set if it's pure black (0,0,0)
if hasattr(bg_shader.inputs, 'Color'):
color = bg_shader.inputs['Color'].default_value
if len(color) >= 3 and color[0] == 0.0 and color[1] == 0.0 and color[2] == 0.0:
# Set to a very dark gray instead of pure black
bg_shader.inputs['Color'].default_value = (0.01, 0.01, 0.01, 1.0)
print("Adjusted world background color to prevent black renders")
else:
# Fallback: use legacy world color if nodes aren't working
if hasattr(scene.world, 'color'):
color = scene.world.color
if len(color) >= 3 and color[0] == 0.0 and color[1] == 0.0 and color[2] == 0.0:
scene.world.color = (0.01, 0.01, 0.01)
print("Adjusted legacy world color to prevent black renders")
# For EEVEE, force viewport update to initialize render engine
if current_engine in ['EEVEE', 'EEVEE_NEXT']:
# Force EEVEE to update its internal state
try:
# Update depsgraph to ensure everything is initialized
depsgraph = bpy.context.evaluated_depsgraph_get()
if depsgraph:
# Force update
depsgraph.update()
print("Forced EEVEE depsgraph update for headless rendering")
except Exception as e:
print(f"Warning: Could not force EEVEE update: {e}")
# Ensure EEVEE settings are applied
try:
# Force a material update to ensure shaders are compiled
for obj in scene.objects:
if obj.type == 'MESH' and obj.data.materials:
for mat in obj.data.materials:
if mat and mat.use_nodes:
# Touch the material to force update
mat.use_nodes = mat.use_nodes
print("Forced material updates for EEVEE")
except Exception as e:
print(f"Warning: Could not update materials: {e}")
# For Cycles, ensure proper initialization
if current_engine == 'CYCLES':
# Ensure samples are set (even if 1 for preview)
if not hasattr(scene.cycles, 'samples') or scene.cycles.samples < 1:
scene.cycles.samples = 1
print("Set minimum Cycles samples")
# Check for lights in the scene
lights = [obj for obj in scene.objects if obj.type == 'LIGHT']
print(f"Found {len(lights)} light(s) in scene")
if len(lights) == 0:
print("WARNING: No lights found in scene - rendering may be black!")
print(" Consider adding lights or ensuring world background emits light")
# Ensure world background emits light (critical for Cycles)
if scene.world and scene.world.use_nodes:
world_nodes = scene.world.node_tree
if world_nodes:
bg_shader = None
for node in world_nodes.nodes:
if node.type == 'BACKGROUND':
bg_shader = node
break
if bg_shader:
# Check and set strength - Cycles needs this to emit light!
if hasattr(bg_shader.inputs, 'Strength'):
strength = bg_shader.inputs['Strength'].default_value
if strength <= 0.0:
bg_shader.inputs['Strength'].default_value = 1.0
print("Set world background strength to 1.0 for Cycles lighting")
else:
print(f"World background strength: {strength}")
# Also ensure color is not pure black
if hasattr(bg_shader.inputs, 'Color'):
color = bg_shader.inputs['Color'].default_value
if len(color) >= 3 and color[0] == 0.0 and color[1] == 0.0 and color[2] == 0.0:
bg_shader.inputs['Color'].default_value = (1.0, 1.0, 1.0, 1.0)
print("Set world background color to white for Cycles lighting")
# Check film_transparent setting - if enabled, background will be transparent/black
if hasattr(scene.cycles, 'film_transparent') and scene.cycles.film_transparent:
print("WARNING: film_transparent is enabled - background will be transparent")
print(" If you see black renders, try disabling film_transparent")
# Force Cycles to update/compile materials and shaders
try:
# Update depsgraph to ensure everything is initialized
depsgraph = bpy.context.evaluated_depsgraph_get()
if depsgraph:
depsgraph.update()
print("Forced Cycles depsgraph update")
# Force material updates to ensure shaders are compiled
for obj in scene.objects:
if obj.type == 'MESH' and obj.data.materials:
for mat in obj.data.materials:
if mat and mat.use_nodes:
# Force material update
mat.use_nodes = mat.use_nodes
print("Forced Cycles material updates")
except Exception as e:
print(f"Warning: Could not force Cycles updates: {e}")
# Verify device is actually set correctly
if hasattr(scene.cycles, 'device'):
actual_device = scene.cycles.device
print(f"Cycles device setting: {actual_device}")
if actual_device == 'GPU':
# Try to verify GPU is actually available
try:
prefs = bpy.context.preferences
cycles_prefs = prefs.addons['cycles'].preferences
devices = cycles_prefs.devices
enabled_devices = [d for d in devices if d.use]
if len(enabled_devices) == 0:
print("WARNING: GPU device set but no GPU devices are enabled!")
print(" Falling back to CPU may cause issues")
except Exception as e:
print(f"Could not verify GPU devices: {e}")
# Ensure camera exists and is active
if scene.camera is None:
# Find first camera in scene
for obj in scene.objects:
if obj.type == 'CAMERA':
scene.camera = obj
print(f"Set active camera: {obj.name}")
break
print("Headless rendering initialization complete")
except Exception as e:
print(f"Warning: Headless rendering initialization had issues: {e}")
import traceback
traceback.print_exc()
# Final verification before rendering
print("\n=== Pre-render verification ===")
try:
scene = bpy.context.scene
print(f"Render engine: {scene.render.engine}")
print(f"Active camera: {scene.camera.name if scene.camera else 'None'}")
if scene.render.engine == 'CYCLES':
print(f"Cycles device: {scene.cycles.device}")
print(f"Cycles samples: {scene.cycles.samples}")
lights = [obj for obj in scene.objects if obj.type == 'LIGHT']
print(f"Lights in scene: {len(lights)}")
if scene.world:
if scene.world.use_nodes:
world_nodes = scene.world.node_tree
if world_nodes:
bg_shader = None
for node in world_nodes.nodes:
if node.type == 'BACKGROUND':
bg_shader = node
break
if bg_shader:
if hasattr(bg_shader.inputs, 'Strength'):
strength = bg_shader.inputs['Strength'].default_value
print(f"World background strength: {strength}")
if hasattr(bg_shader.inputs, 'Color'):
color = bg_shader.inputs['Color'].default_value
print(f"World background color: ({color[0]:.2f}, {color[1]:.2f}, {color[2]:.2f})")
else:
print("World exists but nodes are disabled")
else:
print("WARNING: No world in scene!")
print("=== Verification complete ===\n")
except Exception as e:
print(f"Warning: Verification failed: {e}")
print("Device configuration complete - blend file settings preserved, device optimized")
sys.stdout.flush()

View File

@@ -0,0 +1,29 @@
# Fix objects and collections hidden from render
vl = bpy.context.view_layer
# 1. Objects hidden in view layer
print("Checking for objects hidden from render that need to be enabled...")
try:
for obj in bpy.data.objects:
if obj.hide_get(view_layer=vl):
if any(k in obj.name.lower() for k in ["scrotum|","cage","genital","penis","dick","collision","body.001","couch"]):
obj.hide_set(False, view_layer=vl)
print("Enabled object:", obj.name)
except Exception as e:
print(f"Warning: Could not check/fix hidden render objects: {e}")
# 2. Collections disabled in renders OR set to Holdout (the final killer)
print("Checking for collections hidden from render that need to be enabled...")
try:
for col in bpy.data.collections:
if col.hide_render or (vl.layer_collection.children.get(col.name) and not vl.layer_collection.children[col.name].exclude == False):
if any(k in col.name.lower() for k in ["genital","nsfw","dick","private","hidden","cage","scrotum","collision","dick"]):
col.hide_render = False
if col.name in vl.layer_collection.children:
vl.layer_collection.children[col.name].exclude = False
vl.layer_collection.children[col.name].holdout = False
vl.layer_collection.children[col.name].indirect_only = False
print("Enabled collection:", col.name)
except Exception as e:
print(f"Warning: Could not check/fix hidden render collections: {e}")

View File

@@ -27,28 +27,25 @@ const (
type JobType string
const (
JobTypeMetadata JobType = "metadata" // Metadata extraction job - only needs blend file
JobTypeRender JobType = "render" // Render job - needs frame range, format, etc.
JobTypeRender JobType = "render" // Render job - needs frame range, format, etc.
)
// Job represents a job (metadata extraction or render)
// Job represents a render job
type Job struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
JobType JobType `json:"job_type"` // "metadata" or "render"
Name string `json:"name"`
Status JobStatus `json:"status"`
Progress float64 `json:"progress"` // 0.0 to 100.0
FrameStart *int `json:"frame_start,omitempty"` // Only for render jobs
FrameEnd *int `json:"frame_end,omitempty"` // Only for render jobs
OutputFormat *string `json:"output_format,omitempty"` // Only for render jobs - PNG, JPEG, EXR, etc.
AllowParallelRunners *bool `json:"allow_parallel_runners,omitempty"` // Only for render jobs
TimeoutSeconds int `json:"timeout_seconds"` // Job-level timeout (24 hours default)
BlendMetadata *BlendMetadata `json:"blend_metadata,omitempty"` // Extracted metadata from blend file
CreatedAt time.Time `json:"created_at"`
StartedAt *time.Time `json:"started_at,omitempty"`
CompletedAt *time.Time `json:"completed_at,omitempty"`
ErrorMessage string `json:"error_message,omitempty"`
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
JobType JobType `json:"job_type"` // "render"
Name string `json:"name"`
Status JobStatus `json:"status"`
Progress float64 `json:"progress"` // 0.0 to 100.0
FrameStart *int `json:"frame_start,omitempty"` // Only for render jobs
FrameEnd *int `json:"frame_end,omitempty"` // Only for render jobs
OutputFormat *string `json:"output_format,omitempty"` // Only for render jobs - PNG, JPEG, EXR, etc.
BlendMetadata *BlendMetadata `json:"blend_metadata,omitempty"` // Extracted metadata from blend file
CreatedAt time.Time `json:"created_at"`
StartedAt *time.Time `json:"started_at,omitempty"`
CompletedAt *time.Time `json:"completed_at,omitempty"`
ErrorMessage string `json:"error_message,omitempty"`
}
// RunnerStatus represents the status of a runner
@@ -87,9 +84,8 @@ const (
type TaskType string
const (
TaskTypeRender TaskType = "render"
TaskTypeMetadata TaskType = "metadata"
TaskTypeVideoGeneration TaskType = "video_generation"
TaskTypeRender TaskType = "render"
TaskTypeEncode TaskType = "encode"
)
// Task represents a render task assigned to a runner
@@ -97,8 +93,7 @@ type Task struct {
ID int64 `json:"id"`
JobID int64 `json:"job_id"`
RunnerID *int64 `json:"runner_id,omitempty"`
FrameStart int `json:"frame_start"`
FrameEnd int `json:"frame_end"`
Frame int `json:"frame"`
TaskType TaskType `json:"task_type"`
Status TaskStatus `json:"status"`
CurrentStep string `json:"current_step,omitempty"`
@@ -133,13 +128,18 @@ type JobFile struct {
// CreateJobRequest represents a request to create a new job
type CreateJobRequest struct {
JobType JobType `json:"job_type"` // "metadata" or "render"
Name string `json:"name"`
FrameStart *int `json:"frame_start,omitempty"` // Required for render jobs
FrameEnd *int `json:"frame_end,omitempty"` // Required for render jobs
OutputFormat *string `json:"output_format,omitempty"` // Required for render jobs
AllowParallelRunners *bool `json:"allow_parallel_runners,omitempty"` // Optional for render jobs, defaults to true
MetadataJobID *int64 `json:"metadata_job_id,omitempty"` // Optional: ID of metadata job to copy input files from
JobType JobType `json:"job_type"` // "render"
Name string `json:"name"`
FrameStart *int `json:"frame_start,omitempty"` // Required for render jobs
FrameEnd *int `json:"frame_end,omitempty"` // Required for render jobs
OutputFormat *string `json:"output_format,omitempty"` // Required for render jobs
RenderSettings *RenderSettings `json:"render_settings,omitempty"` // Optional: Override blend file render settings
UploadSessionID *string `json:"upload_session_id,omitempty"` // Optional: Session ID from file upload
UnhideObjects *bool `json:"unhide_objects,omitempty"` // Optional: Enable unhide tweaks for objects/collections
EnableExecution *bool `json:"enable_execution,omitempty"` // Optional: Enable auto-execution in Blender (adds --enable-autoexec flag, defaults to false)
BlenderVersion *string `json:"blender_version,omitempty"` // Optional: Override Blender version (e.g., "4.2" or "4.2.3")
PreserveHDR *bool `json:"preserve_hdr,omitempty"` // Optional: Preserve HDR range for EXR encoding (uses HLG with bt709 primaries)
PreserveAlpha *bool `json:"preserve_alpha,omitempty"` // Optional: Preserve alpha channel for EXR encoding (requires AV1 or VP9 codec)
}
// UpdateJobProgressRequest represents a request to update job progress
@@ -151,7 +151,7 @@ type UpdateJobProgressRequest struct {
type RegisterRunnerRequest struct {
Name string `json:"name"`
Hostname string `json:"hostname"`
IPAddress string `json:"ip_address"`
IPAddress string `json:"ip_address,omitempty"` // Optional, extracted from request by manager
Capabilities string `json:"capabilities"`
Priority *int `json:"priority,omitempty"` // Optional, defaults to 100 if not provided
}
@@ -225,19 +225,37 @@ type TaskLogEntry struct {
// BlendMetadata represents extracted metadata from a blend file
type BlendMetadata struct {
FrameStart int `json:"frame_start"`
FrameEnd int `json:"frame_end"`
RenderSettings RenderSettings `json:"render_settings"`
SceneInfo SceneInfo `json:"scene_info"`
FrameStart int `json:"frame_start"`
FrameEnd int `json:"frame_end"`
HasNegativeFrames bool `json:"has_negative_frames"` // True if blend file has negative frame numbers (not supported)
RenderSettings RenderSettings `json:"render_settings"`
SceneInfo SceneInfo `json:"scene_info"`
MissingFilesInfo *MissingFilesInfo `json:"missing_files_info,omitempty"`
UnhideObjects *bool `json:"unhide_objects,omitempty"` // Enable unhide tweaks for objects/collections
EnableExecution *bool `json:"enable_execution,omitempty"` // Enable auto-execution in Blender (adds --enable-autoexec flag, defaults to false)
BlenderVersion string `json:"blender_version,omitempty"` // Detected or overridden Blender version (e.g., "4.2" or "4.2.3")
PreserveHDR *bool `json:"preserve_hdr,omitempty"` // Preserve HDR range for EXR encoding (uses HLG with bt709 primaries)
PreserveAlpha *bool `json:"preserve_alpha,omitempty"` // Preserve alpha channel for EXR encoding (requires AV1 or VP9 codec)
}
// MissingFilesInfo represents information about missing files/addons
type MissingFilesInfo struct {
Checked bool `json:"checked"`
HasMissing bool `json:"has_missing"`
MissingFiles []string `json:"missing_files,omitempty"`
MissingAddons []string `json:"missing_addons,omitempty"`
Error string `json:"error,omitempty"`
}
// RenderSettings represents render settings from a blend file
type RenderSettings struct {
ResolutionX int `json:"resolution_x"`
ResolutionY int `json:"resolution_y"`
Samples int `json:"samples"`
OutputFormat string `json:"output_format"`
Engine string `json:"engine"`
ResolutionX int `json:"resolution_x"`
ResolutionY int `json:"resolution_y"`
FrameRate float64 `json:"frame_rate"`
Samples int `json:"samples,omitempty"` // Deprecated, use EngineSettings
OutputFormat string `json:"output_format"`
Engine string `json:"engine"`
EngineSettings map[string]interface{} `json:"engine_settings,omitempty"`
}
// SceneInfo represents scene information from a blend file

16
version/version.go Normal file
View File

@@ -0,0 +1,16 @@
// version/version.go
package version
import "time"
var Version string
var Date string
func init() {
if Version == "" {
Version = "0.0.0-dev"
}
if Date == "" {
Date = time.Now().Format("2006-01-02 15:04:05")
}
}

View File

@@ -241,7 +241,6 @@ function displayRunners(runners) {
<h3>${escapeHtml(runner.name)}</h3>
<div class="runner-info">
<span>Hostname: ${escapeHtml(runner.hostname)}</span>
<span>IP: ${escapeHtml(runner.ip_address)}</span>
<span>Last heartbeat: ${lastHeartbeat.toLocaleString()}</span>
</div>
<div class="runner-status ${isOnline ? 'online' : 'offline'}">

45
web/embed.go Normal file
View File

@@ -0,0 +1,45 @@
package web
import (
"embed"
"io/fs"
"net/http"
"strings"
)
//go:embed dist/*
var distFS embed.FS
// GetFileSystem returns an http.FileSystem for the embedded web UI files
func GetFileSystem() http.FileSystem {
subFS, err := fs.Sub(distFS, "dist")
if err != nil {
panic(err)
}
return http.FS(subFS)
}
// SPAHandler returns an http.Handler that serves the embedded SPA
// It serves static files if they exist, otherwise falls back to index.html
func SPAHandler() http.Handler {
fsys := GetFileSystem()
fileServer := http.FileServer(fsys)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
// Try to open the file
f, err := fsys.Open(strings.TrimPrefix(path, "/"))
if err != nil {
// File doesn't exist, serve index.html for SPA routing
r.URL.Path = "/"
fileServer.ServeHTTP(w, r)
return
}
f.Close()
// File exists, serve it
fileServer.ServeHTTP(w, r)
})
}

View File

@@ -5,6 +5,8 @@ import Layout from './components/Layout';
import JobList from './components/JobList';
import JobSubmission from './components/JobSubmission';
import AdminPanel from './components/AdminPanel';
import ErrorBoundary from './components/ErrorBoundary';
import LoadingSpinner from './components/LoadingSpinner';
import './styles/index.css';
function App() {
@@ -17,7 +19,7 @@ function App() {
if (loading) {
return (
<div className="min-h-screen flex items-center justify-center bg-gray-900">
<div className="animate-spin rounded-full h-12 w-12 border-b-2 border-orange-500"></div>
<LoadingSpinner size="md" />
</div>
);
}
@@ -26,26 +28,20 @@ function App() {
return loginComponent;
}
// Wrapper to check auth before changing tabs
const handleTabChange = async (newTab) => {
// Check auth before allowing navigation
try {
await refresh();
// If refresh succeeds, user is still authenticated
setActiveTab(newTab);
} catch (error) {
// Auth check failed, user will be set to null and login will show
console.error('Auth check failed on navigation:', error);
}
// Wrapper to change tabs - only check auth on mount, not on every navigation
const handleTabChange = (newTab) => {
setActiveTab(newTab);
};
return (
<Layout activeTab={activeTab} onTabChange={handleTabChange}>
{activeTab === 'jobs' && <JobList />}
{activeTab === 'submit' && (
<JobSubmission onSuccess={() => handleTabChange('jobs')} />
)}
{activeTab === 'admin' && <AdminPanel />}
<ErrorBoundary>
{activeTab === 'jobs' && <JobList />}
{activeTab === 'submit' && (
<JobSubmission onSuccess={() => handleTabChange('jobs')} />
)}
{activeTab === 'admin' && <AdminPanel />}
</ErrorBoundary>
</Layout>
);
}

View File

@@ -1,41 +1,136 @@
import { useState, useEffect } from 'react';
import { admin } from '../utils/api';
import { useState, useEffect, useRef } from 'react';
import { admin, jobs, normalizeArrayResponse } from '../utils/api';
import { wsManager } from '../utils/websocket';
import UserJobs from './UserJobs';
import PasswordChange from './PasswordChange';
import LoadingSpinner from './LoadingSpinner';
export default function AdminPanel() {
const [activeSection, setActiveSection] = useState('tokens');
const [tokens, setTokens] = useState([]);
const [activeSection, setActiveSection] = useState('api-keys');
const [apiKeys, setApiKeys] = useState([]);
const [runners, setRunners] = useState([]);
const [users, setUsers] = useState([]);
const [loading, setLoading] = useState(false);
const [newTokenExpires, setNewTokenExpires] = useState(24);
const [newToken, setNewToken] = useState(null);
const [newAPIKeyName, setNewAPIKeyName] = useState('');
const [newAPIKeyDescription, setNewAPIKeyDescription] = useState('');
const [newAPIKeyScope, setNewAPIKeyScope] = useState('user'); // Default to user scope
const [newAPIKey, setNewAPIKey] = useState(null);
const [selectedUser, setSelectedUser] = useState(null);
const [registrationEnabled, setRegistrationEnabled] = useState(true);
const [passwordChangeUser, setPasswordChangeUser] = useState(null);
const listenerIdRef = useRef(null); // Listener ID for shared WebSocket
const subscribedChannelsRef = useRef(new Set()); // Track confirmed subscribed channels
const pendingSubscriptionsRef = useRef(new Set()); // Track pending subscriptions (waiting for confirmation)
// Connect to shared WebSocket on mount
useEffect(() => {
listenerIdRef.current = wsManager.subscribe('adminpanel', {
open: () => {
console.log('AdminPanel: Shared WebSocket connected');
// Subscribe to runners if already viewing runners section
if (activeSection === 'runners') {
subscribeToRunners();
}
},
message: (data) => {
// Handle subscription responses - update both local refs and wsManager
if (data.type === 'subscribed' && data.channel) {
pendingSubscriptionsRef.current.delete(data.channel);
subscribedChannelsRef.current.add(data.channel);
wsManager.confirmSubscription(data.channel);
console.log('Successfully subscribed to channel:', data.channel);
} else if (data.type === 'subscription_error' && data.channel) {
pendingSubscriptionsRef.current.delete(data.channel);
subscribedChannelsRef.current.delete(data.channel);
wsManager.failSubscription(data.channel);
console.error('Subscription failed for channel:', data.channel, data.error);
}
// Handle runners channel messages
if (data.channel === 'runners' && data.type === 'runner_status') {
// Update runner in list
setRunners(prev => {
const index = prev.findIndex(r => r.id === data.runner_id);
if (index >= 0 && data.data) {
const updated = [...prev];
updated[index] = { ...updated[index], ...data.data };
return updated;
}
return prev;
});
}
},
error: (error) => {
console.error('AdminPanel: Shared WebSocket error:', error);
},
close: (event) => {
console.log('AdminPanel: Shared WebSocket closed:', event);
subscribedChannelsRef.current.clear();
pendingSubscriptionsRef.current.clear();
}
});
// Ensure connection is established
wsManager.connect();
return () => {
// Unsubscribe from all channels before unmounting
unsubscribeFromRunners();
if (listenerIdRef.current) {
wsManager.unsubscribe(listenerIdRef.current);
listenerIdRef.current = null;
}
};
}, []);
const subscribeToRunners = () => {
const channel = 'runners';
// Don't subscribe if already subscribed or pending
if (subscribedChannelsRef.current.has(channel) || pendingSubscriptionsRef.current.has(channel)) {
return;
}
wsManager.subscribeToChannel(channel);
subscribedChannelsRef.current.add(channel);
pendingSubscriptionsRef.current.add(channel);
console.log('Subscribing to runners channel');
};
const unsubscribeFromRunners = () => {
const channel = 'runners';
if (!subscribedChannelsRef.current.has(channel)) {
return; // Not subscribed
}
wsManager.unsubscribeFromChannel(channel);
subscribedChannelsRef.current.delete(channel);
pendingSubscriptionsRef.current.delete(channel);
console.log('Unsubscribed from runners channel');
};
useEffect(() => {
if (activeSection === 'tokens') {
loadTokens();
if (activeSection === 'api-keys') {
loadAPIKeys();
unsubscribeFromRunners();
} else if (activeSection === 'runners') {
loadRunners();
subscribeToRunners();
} else if (activeSection === 'users') {
loadUsers();
unsubscribeFromRunners();
} else if (activeSection === 'settings') {
loadSettings();
unsubscribeFromRunners();
}
}, [activeSection]);
const loadTokens = async () => {
const loadAPIKeys = async () => {
setLoading(true);
try {
const data = await admin.listTokens();
setTokens(Array.isArray(data) ? data : []);
const data = await admin.listAPIKeys();
setApiKeys(normalizeArrayResponse(data));
} catch (error) {
console.error('Failed to load tokens:', error);
setTokens([]);
alert('Failed to load tokens');
console.error('Failed to load API keys:', error);
setApiKeys([]);
alert('Failed to load API keys');
} finally {
setLoading(false);
}
@@ -45,7 +140,7 @@ export default function AdminPanel() {
setLoading(true);
try {
const data = await admin.listRunners();
setRunners(Array.isArray(data) ? data : []);
setRunners(normalizeArrayResponse(data));
} catch (error) {
console.error('Failed to load runners:', error);
setRunners([]);
@@ -59,7 +154,7 @@ export default function AdminPanel() {
setLoading(true);
try {
const data = await admin.listUsers();
setUsers(Array.isArray(data) ? data : []);
setUsers(normalizeArrayResponse(data));
} catch (error) {
console.error('Failed to load users:', error);
setUsers([]);
@@ -97,54 +192,61 @@ export default function AdminPanel() {
}
};
const generateToken = async () => {
const generateAPIKey = async () => {
if (!newAPIKeyName.trim()) {
alert('API key name is required');
return;
}
setLoading(true);
try {
const data = await admin.generateToken(newTokenExpires);
setNewToken(data.token);
await loadTokens();
const data = await admin.generateAPIKey(newAPIKeyName.trim(), newAPIKeyDescription.trim() || undefined, newAPIKeyScope);
setNewAPIKey(data);
setNewAPIKeyName('');
setNewAPIKeyDescription('');
setNewAPIKeyScope('user');
await loadAPIKeys();
} catch (error) {
console.error('Failed to generate token:', error);
alert('Failed to generate token');
console.error('Failed to generate API key:', error);
alert('Failed to generate API key');
} finally {
setLoading(false);
}
};
const revokeToken = async (tokenId) => {
if (!confirm('Are you sure you want to revoke this token?')) {
const [deletingKeyId, setDeletingKeyId] = useState(null);
const [deletingRunnerId, setDeletingRunnerId] = useState(null);
const revokeAPIKey = async (keyId) => {
if (!confirm('Are you sure you want to delete this API key? This action cannot be undone.')) {
return;
}
setDeletingKeyId(keyId);
try {
await admin.revokeToken(tokenId);
await loadTokens();
await admin.deleteAPIKey(keyId);
await loadAPIKeys();
} catch (error) {
console.error('Failed to revoke token:', error);
alert('Failed to revoke token');
console.error('Failed to delete API key:', error);
alert('Failed to delete API key');
} finally {
setDeletingKeyId(null);
}
};
const verifyRunner = async (runnerId) => {
try {
await admin.verifyRunner(runnerId);
await loadRunners();
alert('Runner verified');
} catch (error) {
console.error('Failed to verify runner:', error);
alert('Failed to verify runner');
}
};
const deleteRunner = async (runnerId) => {
if (!confirm('Are you sure you want to delete this runner?')) {
return;
}
setDeletingRunnerId(runnerId);
try {
await admin.deleteRunner(runnerId);
await loadRunners();
} catch (error) {
console.error('Failed to delete runner:', error);
alert('Failed to delete runner');
} finally {
setDeletingRunnerId(null);
}
};
@@ -153,12 +255,8 @@ export default function AdminPanel() {
alert('Copied to clipboard!');
};
const isTokenExpired = (expiresAt) => {
return new Date(expiresAt) < new Date();
};
const isTokenUsed = (used) => {
return used;
const isAPIKeyActive = (isActive) => {
return isActive;
};
return (
@@ -166,16 +264,16 @@ export default function AdminPanel() {
<div className="flex space-x-4 border-b border-gray-700">
<button
onClick={() => {
setActiveSection('tokens');
setActiveSection('api-keys');
setSelectedUser(null);
}}
className={`py-2 px-4 border-b-2 font-medium ${
activeSection === 'tokens'
activeSection === 'api-keys'
? 'border-orange-500 text-orange-500'
: 'border-transparent text-gray-400 hover:text-gray-300'
}`}
>
Registration Tokens
API Keys
</button>
<button
onClick={() => {
@@ -218,76 +316,112 @@ export default function AdminPanel() {
</button>
</div>
{activeSection === 'tokens' && (
{activeSection === 'api-keys' && (
<div className="space-y-6">
<div className="bg-gray-800 rounded-lg shadow-md p-6 border border-gray-700">
<h2 className="text-xl font-semibold mb-4 text-gray-100">Generate Registration Token</h2>
<div className="flex gap-4 items-end">
<div>
<label className="block text-sm font-medium text-gray-300 mb-2">
Expires in (hours)
</label>
<input
type="number"
min="1"
max="168"
value={newTokenExpires}
onChange={(e) => setNewTokenExpires(parseInt(e.target.value) || 24)}
className="w-32 px-3 py-2 bg-gray-900 border border-gray-600 rounded-lg text-gray-100 focus:ring-2 focus:ring-orange-500 focus:border-transparent"
/>
<h2 className="text-xl font-semibold mb-4 text-gray-100">Generate API Key</h2>
<div className="space-y-4">
<div className="grid grid-cols-1 md:grid-cols-3 gap-4">
<div>
<label className="block text-sm font-medium text-gray-300 mb-2">
Name *
</label>
<input
type="text"
value={newAPIKeyName}
onChange={(e) => setNewAPIKeyName(e.target.value)}
placeholder="e.g., production-runner-01"
className="w-full px-3 py-2 bg-gray-900 border border-gray-600 rounded-lg text-gray-100 focus:ring-2 focus:ring-orange-500 focus:border-transparent"
required
/>
</div>
<div>
<label className="block text-sm font-medium text-gray-300 mb-2">
Description
</label>
<input
type="text"
value={newAPIKeyDescription}
onChange={(e) => setNewAPIKeyDescription(e.target.value)}
placeholder="Optional description"
className="w-full px-3 py-2 bg-gray-900 border border-gray-600 rounded-lg text-gray-100 focus:ring-2 focus:ring-orange-500 focus:border-transparent"
/>
</div>
<div>
<label className="block text-sm font-medium text-gray-300 mb-2">
Scope
</label>
<select
value={newAPIKeyScope}
onChange={(e) => setNewAPIKeyScope(e.target.value)}
className="w-full px-3 py-2 bg-gray-900 border border-gray-600 rounded-lg text-gray-100 focus:ring-2 focus:ring-orange-500 focus:border-transparent"
>
<option value="user">User - Only jobs from API key owner</option>
<option value="manager">Manager - All jobs from any user</option>
</select>
</div>
</div>
<div className="flex justify-end">
<button
onClick={generateAPIKey}
disabled={loading || !newAPIKeyName.trim()}
className="px-6 py-2 bg-orange-600 text-white rounded-lg hover:bg-orange-500 disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
>
Generate API Key
</button>
</div>
<button
onClick={generateToken}
disabled={loading}
className="px-6 py-2 bg-orange-600 text-white rounded-lg hover:bg-orange-500 disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
>
Generate Token
</button>
</div>
{newToken && (
{newAPIKey && (
<div className="mt-4 p-4 bg-green-400/20 border border-green-400/50 rounded-lg">
<p className="text-sm font-medium text-green-400 mb-2">New Token Generated:</p>
<div className="flex items-center gap-2">
<code className="flex-1 px-3 py-2 bg-gray-900 border border-green-400/50 rounded text-sm font-mono break-all text-gray-100">
{newToken}
</code>
<button
onClick={() => copyToClipboard(newToken)}
className="px-4 py-2 bg-green-600 text-white rounded hover:bg-green-500 transition-colors text-sm"
>
Copy
</button>
<p className="text-sm font-medium text-green-400 mb-2">New API Key Generated:</p>
<div className="space-y-2">
<div className="flex items-center gap-2">
<code className="flex-1 px-3 py-2 bg-gray-900 border border-green-400/50 rounded text-sm font-mono break-all text-gray-100">
{newAPIKey.key}
</code>
<button
onClick={() => copyToClipboard(newAPIKey.key)}
className="px-4 py-2 bg-green-600 text-white rounded hover:bg-green-500 transition-colors text-sm whitespace-nowrap"
>
Copy Key
</button>
</div>
<div className="text-xs text-green-400/80">
<p><strong>Name:</strong> {newAPIKey.name}</p>
{newAPIKey.description && <p><strong>Description:</strong> {newAPIKey.description}</p>}
</div>
<p className="text-xs text-green-400/80 mt-2">
Save this API key securely. It will not be shown again.
</p>
</div>
<p className="text-xs text-green-400/80 mt-2">
Save this token securely. It will not be shown again.
</p>
</div>
)}
</div>
<div className="bg-gray-800 rounded-lg shadow-md p-6 border border-gray-700">
<h2 className="text-xl font-semibold mb-4 text-gray-100">Active Tokens</h2>
<h2 className="text-xl font-semibold mb-4 text-gray-100">API Keys</h2>
{loading ? (
<div className="flex justify-center py-8">
<div className="animate-spin rounded-full h-8 w-8 border-b-2 border-orange-500"></div>
</div>
) : !tokens || tokens.length === 0 ? (
<p className="text-gray-400 text-center py-8">No tokens generated yet.</p>
<LoadingSpinner size="sm" className="py-8" />
) : !apiKeys || apiKeys.length === 0 ? (
<p className="text-gray-400 text-center py-8">No API keys generated yet.</p>
) : (
<div className="overflow-x-auto">
<table className="min-w-full divide-y divide-gray-700">
<thead className="bg-gray-900">
<tr>
<th className="px-6 py-3 text-left text-xs font-medium text-gray-400 uppercase tracking-wider">
Token
Name
</th>
<th className="px-6 py-3 text-left text-xs font-medium text-gray-400 uppercase tracking-wider">
Scope
</th>
<th className="px-6 py-3 text-left text-xs font-medium text-gray-400 uppercase tracking-wider">
Key Prefix
</th>
<th className="px-6 py-3 text-left text-xs font-medium text-gray-400 uppercase tracking-wider">
Status
</th>
<th className="px-6 py-3 text-left text-xs font-medium text-gray-400 uppercase tracking-wider">
Expires At
</th>
<th className="px-6 py-3 text-left text-xs font-medium text-gray-400 uppercase tracking-wider">
Created At
</th>
@@ -297,46 +431,54 @@ export default function AdminPanel() {
</tr>
</thead>
<tbody className="bg-gray-800 divide-y divide-gray-700">
{tokens.map((token) => {
const expired = isTokenExpired(token.expires_at);
const used = isTokenUsed(token.used);
return (
<tr key={token.id}>
<td className="px-6 py-4 whitespace-nowrap">
<code className="text-sm font-mono text-gray-100">
{token.token.substring(0, 16)}...
</code>
</td>
<td className="px-6 py-4 whitespace-nowrap">
{expired ? (
<span className="px-2 py-1 text-xs font-medium rounded-full bg-red-400/20 text-red-400">
Expired
</span>
) : used ? (
<span className="px-2 py-1 text-xs font-medium rounded-full bg-yellow-400/20 text-yellow-400">
Used
</span>
) : (
<span className="px-2 py-1 text-xs font-medium rounded-full bg-green-400/20 text-green-400">
Active
</span>
)}
</td>
<td className="px-6 py-4 whitespace-nowrap text-sm text-gray-400">
{new Date(token.expires_at).toLocaleString()}
</td>
<td className="px-6 py-4 whitespace-nowrap text-sm text-gray-400">
{new Date(token.created_at).toLocaleString()}
</td>
<td className="px-6 py-4 whitespace-nowrap text-sm">
{!used && !expired && (
<button
onClick={() => revokeToken(token.id)}
className="text-red-400 hover:text-red-300 font-medium"
>
Revoke
</button>
{apiKeys.map((key) => {
return (
<tr key={key.id}>
<td className="px-6 py-4 whitespace-nowrap">
<div>
<div className="text-sm font-medium text-gray-100">{key.name}</div>
{key.description && (
<div className="text-sm text-gray-400">{key.description}</div>
)}
</div>
</td>
<td className="px-6 py-4 whitespace-nowrap">
<span className={`px-2 py-1 text-xs font-medium rounded-full ${
key.scope === 'manager'
? 'bg-purple-400/20 text-purple-400'
: 'bg-blue-400/20 text-blue-400'
}`}>
{key.scope === 'manager' ? 'Manager' : 'User'}
</span>
</td>
<td className="px-6 py-4 whitespace-nowrap">
<code className="text-sm font-mono text-gray-300">
{key.key_prefix}
</code>
</td>
<td className="px-6 py-4 whitespace-nowrap">
{!key.is_active ? (
<span className="px-2 py-1 text-xs font-medium rounded-full bg-gray-500/20 text-gray-400">
Revoked
</span>
) : (
<span className="px-2 py-1 text-xs font-medium rounded-full bg-green-400/20 text-green-400">
Active
</span>
)}
</td>
<td className="px-6 py-4 whitespace-nowrap text-sm text-gray-400">
{new Date(key.created_at).toLocaleString()}
</td>
<td className="px-6 py-4 whitespace-nowrap text-sm space-x-2">
<button
onClick={() => revokeAPIKey(key.id)}
disabled={deletingKeyId === key.id}
className="text-red-400 hover:text-red-300 font-medium disabled:opacity-50 disabled:cursor-not-allowed"
title="Delete API key"
>
{deletingKeyId === key.id ? 'Deleting...' : 'Delete'}
</button>
</td>
</tr>
);
@@ -353,9 +495,7 @@ export default function AdminPanel() {
<div className="bg-gray-800 rounded-lg shadow-md p-6 border border-gray-700">
<h2 className="text-xl font-semibold mb-4 text-gray-100">Runner Management</h2>
{loading ? (
<div className="flex justify-center py-8">
<div className="animate-spin rounded-full h-8 w-8 border-b-2 border-orange-500"></div>
</div>
<LoadingSpinner size="sm" className="py-8" />
) : !runners || runners.length === 0 ? (
<p className="text-gray-400 text-center py-8">No runners registered.</p>
) : (
@@ -369,14 +509,11 @@ export default function AdminPanel() {
<th className="px-6 py-3 text-left text-xs font-medium text-gray-400 uppercase tracking-wider">
Hostname
</th>
<th className="px-6 py-3 text-left text-xs font-medium text-gray-400 uppercase tracking-wider">
IP Address
</th>
<th className="px-6 py-3 text-left text-xs font-medium text-gray-400 uppercase tracking-wider">
Status
</th>
<th className="px-6 py-3 text-left text-xs font-medium text-gray-400 uppercase tracking-wider">
Verified
API Key
</th>
<th className="px-6 py-3 text-left text-xs font-medium text-gray-400 uppercase tracking-wider">
Priority
@@ -403,9 +540,6 @@ export default function AdminPanel() {
<td className="px-6 py-4 whitespace-nowrap text-sm text-gray-400">
{runner.hostname}
</td>
<td className="px-6 py-4 whitespace-nowrap text-sm text-gray-400">
{runner.ip_address}
</td>
<td className="px-6 py-4 whitespace-nowrap">
<span
className={`px-2 py-1 text-xs font-medium rounded-full ${
@@ -417,16 +551,10 @@ export default function AdminPanel() {
{isOnline ? 'Online' : 'Offline'}
</span>
</td>
<td className="px-6 py-4 whitespace-nowrap">
<span
className={`px-2 py-1 text-xs font-medium rounded-full ${
runner.verified
? 'bg-green-400/20 text-green-400'
: 'bg-yellow-400/20 text-yellow-400'
}`}
>
{runner.verified ? 'Verified' : 'Unverified'}
</span>
<td className="px-6 py-4 whitespace-nowrap text-sm text-gray-400">
<code className="text-xs font-mono bg-gray-900 px-2 py-1 rounded">
jk_r{runner.id % 10}_...
</code>
</td>
<td className="px-6 py-4 whitespace-nowrap text-sm text-gray-400">
{runner.priority}
@@ -452,20 +580,13 @@ export default function AdminPanel() {
<td className="px-6 py-4 whitespace-nowrap text-sm text-gray-400">
{new Date(runner.last_heartbeat).toLocaleString()}
</td>
<td className="px-6 py-4 whitespace-nowrap text-sm space-x-2">
{!runner.verified && (
<button
onClick={() => verifyRunner(runner.id)}
className="text-orange-400 hover:text-orange-300 font-medium"
>
Verify
</button>
)}
<td className="px-6 py-4 whitespace-nowrap text-sm">
<button
onClick={() => deleteRunner(runner.id)}
className="text-red-400 hover:text-red-300 font-medium"
disabled={deletingRunnerId === runner.id}
className="text-red-400 hover:text-red-300 font-medium disabled:opacity-50 disabled:cursor-not-allowed"
>
Delete
{deletingRunnerId === runner.id ? 'Deleting...' : 'Delete'}
</button>
</td>
</tr>
@@ -515,9 +636,7 @@ export default function AdminPanel() {
<div className="bg-gray-800 rounded-lg shadow-md p-6 border border-gray-700">
<h2 className="text-xl font-semibold mb-4 text-gray-100">User Management</h2>
{loading ? (
<div className="flex justify-center py-8">
<div className="animate-spin rounded-full h-8 w-8 border-b-2 border-orange-500"></div>
</div>
<LoadingSpinner size="sm" className="py-8" />
) : !users || users.length === 0 ? (
<p className="text-gray-400 text-center py-8">No users found.</p>
) : (

View File

@@ -0,0 +1,41 @@
import React from 'react';
class ErrorBoundary extends React.Component {
constructor(props) {
super(props);
this.state = { hasError: false, error: null };
}
static getDerivedStateFromError(error) {
return { hasError: true, error };
}
componentDidCatch(error, errorInfo) {
console.error('ErrorBoundary caught an error:', error, errorInfo);
}
render() {
if (this.state.hasError) {
return (
<div className="p-6 bg-red-400/20 border border-red-400/50 rounded-lg text-red-400">
<h2 className="text-xl font-semibold mb-2">Something went wrong</h2>
<p className="mb-4">{this.state.error?.message || 'An unexpected error occurred'}</p>
<button
onClick={() => {
this.setState({ hasError: false, error: null });
window.location.reload();
}}
className="px-4 py-2 bg-red-600 text-white rounded-lg hover:bg-red-500 transition-colors"
>
Reload Page
</button>
</div>
);
}
return this.props.children;
}
}
export default ErrorBoundary;

View File

@@ -0,0 +1,26 @@
import React from 'react';
/**
* Shared ErrorMessage component for consistent error display
* Sanitizes error messages to prevent XSS
*/
export default function ErrorMessage({ error, className = '' }) {
if (!error) return null;
// Sanitize error message - escape HTML entities
const sanitize = (text) => {
const div = document.createElement('div');
div.textContent = text;
return div.innerHTML;
};
const sanitizedError = typeof error === 'string' ? sanitize(error) : sanitize(error.message || 'An error occurred');
return (
<div className={`p-4 bg-red-400/20 border border-red-400/50 rounded-lg text-red-400 ${className}`}>
<p className="font-semibold">Error:</p>
<p dangerouslySetInnerHTML={{ __html: sanitizedError }} />
</div>
);
}

View File

@@ -0,0 +1,191 @@
import { useState } from 'react';
export default function FileExplorer({ files, onDownload, onPreview, onVideoPreview, isImageFile }) {
const [expandedPaths, setExpandedPaths] = useState(new Set()); // Root folder collapsed by default
// Build directory tree from file paths
const buildTree = (files) => {
const tree = {};
files.forEach(file => {
const path = file.file_name;
// Handle both paths with slashes and single filenames
const parts = path.includes('/') ? path.split('/').filter(p => p) : [path];
// If it's a single file at root (no slashes), treat it specially
if (parts.length === 1 && !path.includes('/')) {
tree[parts[0]] = {
name: parts[0],
isFile: true,
file: file,
children: {},
path: parts[0]
};
return;
}
let current = tree;
parts.forEach((part, index) => {
if (!current[part]) {
current[part] = {
name: part,
isFile: index === parts.length - 1,
file: index === parts.length - 1 ? file : null,
children: {},
path: parts.slice(0, index + 1).join('/')
};
}
current = current[part].children;
});
});
return tree;
};
const togglePath = (path) => {
const newExpanded = new Set(expandedPaths);
if (newExpanded.has(path)) {
newExpanded.delete(path);
} else {
newExpanded.add(path);
}
setExpandedPaths(newExpanded);
};
const renderTree = (node, level = 0, parentPath = '') => {
const items = Object.values(node).sort((a, b) => {
// Directories first, then files
if (a.isFile !== b.isFile) {
return a.isFile ? 1 : -1;
}
return a.name.localeCompare(b.name);
});
return items.map((item) => {
const fullPath = parentPath ? `${parentPath}/${item.name}` : item.name;
const isExpanded = expandedPaths.has(fullPath);
const indent = level * 20;
if (item.isFile) {
const file = item.file;
const isImage = isImageFile && isImageFile(file.file_name);
const isVideo = file.file_name.toLowerCase().endsWith('.mp4');
const sizeMB = (file.file_size / 1024 / 1024).toFixed(2);
const isArchive = file.file_name.endsWith('.tar') || file.file_name.endsWith('.zip');
return (
<div key={fullPath} className="flex items-center justify-between py-1.5 hover:bg-gray-800/50 rounded px-2" style={{ paddingLeft: `${indent + 8}px` }}>
<div className="flex items-center gap-2 flex-1 min-w-0">
<span className="text-gray-500 text-sm">{isArchive ? '📦' : isVideo ? '🎬' : '📄'}</span>
<span className="text-gray-200 text-sm truncate" title={item.name}>
{item.name}
</span>
<span className="text-gray-500 text-xs ml-2">{sizeMB} MB</span>
</div>
<div className="flex gap-2 ml-4 shrink-0">
{isVideo && onVideoPreview && (
<button
onClick={() => onVideoPreview(file)}
className="px-2 py-1 bg-purple-600 text-white rounded text-xs hover:bg-purple-500 transition-colors"
title="Play Video"
>
</button>
)}
{isImage && onPreview && (
<button
onClick={() => onPreview(file)}
className="px-2 py-1 bg-blue-600 text-white rounded text-xs hover:bg-blue-500 transition-colors"
title="Preview"
>
👁
</button>
)}
{onDownload && file.id && (
<button
onClick={() => onDownload(file.id, file.file_name)}
className="px-2 py-1 bg-orange-600 text-white rounded text-xs hover:bg-orange-500 transition-colors"
title="Download"
>
</button>
)}
</div>
</div>
);
} else {
const hasChildren = Object.keys(item.children).length > 0;
return (
<div key={fullPath}>
<div
className="flex items-center gap-2 py-1.5 hover:bg-gray-800/50 rounded px-2 cursor-pointer select-none"
style={{ paddingLeft: `${indent + 8}px` }}
onClick={() => hasChildren && togglePath(fullPath)}
>
<span className="text-gray-400 text-xs w-4 flex items-center justify-center">
{hasChildren ? (isExpanded ? '▼' : '▶') : '•'}
</span>
<span className="text-gray-500 text-sm">
{hasChildren ? (isExpanded ? '📂' : '📁') : '📁'}
</span>
<span className="text-gray-300 text-sm font-medium">{item.name}</span>
{hasChildren && (
<span className="text-gray-500 text-xs ml-2">
({Object.keys(item.children).length})
</span>
)}
</div>
{hasChildren && isExpanded && (
<div className="ml-2">
{renderTree(item.children, level + 1, fullPath)}
</div>
)}
</div>
);
}
});
};
const tree = buildTree(files);
if (Object.keys(tree).length === 0) {
return (
<div className="text-gray-400 text-sm py-4 text-center">
No files
</div>
);
}
// Wrap tree in a root folder
const rootExpanded = expandedPaths.has('');
return (
<div className="bg-gray-900 rounded-lg border border-gray-700 p-3">
<div className="space-y-1">
<div>
<div
className="flex items-center gap-2 py-1.5 hover:bg-gray-800/50 rounded px-2 cursor-pointer select-none"
onClick={() => togglePath('')}
>
<span className="text-gray-400 text-xs w-4 flex items-center justify-center">
{rootExpanded ? '▼' : '▶'}
</span>
<span className="text-gray-500 text-sm">
{rootExpanded ? '📂' : '📁'}
</span>
<span className="text-gray-300 text-sm font-medium">Files</span>
<span className="text-gray-500 text-xs ml-2">
({Object.keys(tree).length})
</span>
</div>
{rootExpanded && (
<div className="ml-2">
{renderTree(tree)}
</div>
)}
</div>
</div>
</div>
);
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,29 +1,142 @@
import { useState, useEffect } from 'react';
import { jobs } from '../utils/api';
import { useState, useEffect, useRef } from 'react';
import { jobs, normalizeArrayResponse } from '../utils/api';
import { wsManager } from '../utils/websocket';
import JobDetails from './JobDetails';
import LoadingSpinner from './LoadingSpinner';
export default function JobList() {
const [jobList, setJobList] = useState([]);
const [loading, setLoading] = useState(true);
const [selectedJob, setSelectedJob] = useState(null);
const [pagination, setPagination] = useState({ total: 0, limit: 50, offset: 0 });
const [hasMore, setHasMore] = useState(true);
const listenerIdRef = useRef(null);
useEffect(() => {
loadJobs();
const interval = setInterval(loadJobs, 5000);
return () => clearInterval(interval);
// Use shared WebSocket manager for real-time updates
listenerIdRef.current = wsManager.subscribe('joblist', {
open: () => {
console.log('JobList: Shared WebSocket connected');
// Load initial job list via HTTP to get current state
loadJobs();
},
message: (data) => {
console.log('JobList: Client WebSocket message received:', data.type, data.channel, data);
// Handle jobs channel messages (always broadcasted)
if (data.channel === 'jobs') {
if (data.type === 'job_update' && data.data) {
console.log('JobList: Updating job:', data.job_id, data.data);
// Update job in list
setJobList(prev => {
const prevArray = Array.isArray(prev) ? prev : [];
const index = prevArray.findIndex(j => j.id === data.job_id);
if (index >= 0) {
const updated = [...prevArray];
updated[index] = { ...updated[index], ...data.data };
console.log('JobList: Updated job at index', index, updated[index]);
return updated;
}
// If job not in current page, reload to get updated list
if (data.data.status === 'completed' || data.data.status === 'failed') {
loadJobs();
}
return prevArray;
});
} else if (data.type === 'job_created' && data.data) {
console.log('JobList: New job created:', data.job_id, data.data);
// New job created - add to list
setJobList(prev => {
const prevArray = Array.isArray(prev) ? prev : [];
// Check if job already exists (avoid duplicates)
if (prevArray.findIndex(j => j.id === data.job_id) >= 0) {
return prevArray;
}
// Add new job at the beginning
return [data.data, ...prevArray];
});
}
} else if (data.type === 'connected') {
// Connection established
console.log('JobList: WebSocket connected');
}
},
error: (error) => {
console.error('JobList: Shared WebSocket error:', error);
},
close: (event) => {
console.log('JobList: Shared WebSocket closed:', event);
}
});
// Ensure connection is established
wsManager.connect();
return () => {
if (listenerIdRef.current) {
wsManager.unsubscribe(listenerIdRef.current);
listenerIdRef.current = null;
}
};
}, []);
const loadJobs = async () => {
const loadJobs = async (append = false) => {
try {
const data = await jobs.list();
setJobList(data);
const offset = append ? pagination.offset + pagination.limit : 0;
const result = await jobs.listSummary({
limit: pagination.limit,
offset,
sort: 'created_at:desc'
});
// Handle both old format (array) and new format (object with data, total, etc.)
const jobsArray = normalizeArrayResponse(result);
const total = result.total !== undefined ? result.total : jobsArray.length;
if (append) {
setJobList(prev => {
const prevArray = Array.isArray(prev) ? prev : [];
return [...prevArray, ...jobsArray];
});
setPagination(prev => ({ ...prev, offset, total }));
} else {
setJobList(jobsArray);
setPagination({ total, limit: result.limit || pagination.limit, offset: result.offset || 0 });
}
setHasMore(offset + jobsArray.length < total);
} catch (error) {
console.error('Failed to load jobs:', error);
// Ensure jobList is always an array even on error
if (!append) {
setJobList([]);
}
} finally {
setLoading(false);
}
};
const loadMore = () => {
if (!loading && hasMore) {
loadJobs(true);
}
};
// Keep selectedJob in sync with the job list when it refreshes
useEffect(() => {
if (selectedJob && jobList.length > 0) {
const freshJob = jobList.find(j => j.id === selectedJob.id);
if (freshJob) {
// Update to the fresh object from the list to keep it in sync
setSelectedJob(freshJob);
} else {
// Job was deleted or no longer exists, clear selection
setSelectedJob(null);
}
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [jobList]); // Only depend on jobList, not selectedJob to avoid infinite loops
const handleCancel = async (jobId) => {
if (!confirm('Are you sure you want to cancel this job?')) return;
try {
@@ -37,12 +150,21 @@ export default function JobList() {
const handleDelete = async (jobId) => {
if (!confirm('Are you sure you want to permanently delete this job? This action cannot be undone.')) return;
try {
await jobs.delete(jobId);
loadJobs();
// Optimistically update the list
setJobList(prev => {
const prevArray = Array.isArray(prev) ? prev : [];
return prevArray.filter(j => j.id !== jobId);
});
if (selectedJob && selectedJob.id === jobId) {
setSelectedJob(null);
}
// Then actually delete
await jobs.delete(jobId);
// Reload to ensure consistency
loadJobs();
} catch (error) {
// On error, reload to restore correct state
loadJobs();
alert('Failed to delete job: ' + error.message);
}
};
@@ -58,12 +180,8 @@ export default function JobList() {
return colors[status] || colors.pending;
};
if (loading) {
return (
<div className="flex justify-center items-center h-64">
<div className="animate-spin rounded-full h-12 w-12 border-b-2 border-orange-500"></div>
</div>
);
if (loading && jobList.length === 0) {
return <LoadingSpinner size="md" className="h-64" />;
}
if (jobList.length === 0) {
@@ -90,8 +208,10 @@ export default function JobList() {
</div>
<div className="space-y-2 text-sm text-gray-400 mb-4">
<p>Frames: {job.frame_start} - {job.frame_end}</p>
<p>Format: {job.output_format}</p>
{job.frame_start !== undefined && job.frame_end !== undefined && (
<p>Frames: {job.frame_start} - {job.frame_end}</p>
)}
{job.output_format && <p>Format: {job.output_format}</p>}
<p>Created: {new Date(job.created_at).toLocaleString()}</p>
</div>
@@ -110,7 +230,15 @@ export default function JobList() {
<div className="flex gap-2">
<button
onClick={() => setSelectedJob(job)}
onClick={() => {
// Fetch full job details when viewing
jobs.get(job.id).then(fullJob => {
setSelectedJob(fullJob);
}).catch(err => {
console.error('Failed to load job details:', err);
setSelectedJob(job); // Fallback to summary
});
}}
className="flex-1 px-4 py-2 bg-orange-600 text-white rounded-lg hover:bg-orange-500 transition-colors font-medium"
>
View Details
@@ -137,6 +265,18 @@ export default function JobList() {
))}
</div>
{hasMore && (
<div className="flex justify-center mt-6">
<button
onClick={loadMore}
disabled={loading}
className="px-6 py-2 bg-gray-700 text-gray-200 rounded-lg hover:bg-gray-600 transition-colors font-medium disabled:opacity-50"
>
{loading ? 'Loading...' : 'Load More'}
</button>
</div>
)}
{selectedJob && (
<JobDetails
job={selectedJob}
@@ -147,4 +287,3 @@ export default function JobList() {
</>
);
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,19 @@
import React from 'react';
/**
* Shared LoadingSpinner component with size variants
*/
export default function LoadingSpinner({ size = 'md', className = '', borderColor = 'border-orange-500' }) {
const sizeClasses = {
sm: 'h-8 w-8',
md: 'h-12 w-12',
lg: 'h-16 w-16',
};
return (
<div className={`flex justify-center items-center ${className}`}>
<div className={`animate-spin rounded-full border-b-2 ${borderColor} ${sizeClasses[size]}`}></div>
</div>
);
}

View File

@@ -1,5 +1,6 @@
import { useState, useEffect } from 'react';
import { auth } from '../utils/api';
import ErrorMessage from './ErrorMessage';
export default function Login() {
const [providers, setProviders] = useState({
@@ -92,11 +93,7 @@ export default function Login() {
</div>
<div className="space-y-4">
{error && (
<div className="p-4 bg-red-400/20 border border-red-400/50 rounded-lg text-red-400 text-sm">
{error}
</div>
)}
<ErrorMessage error={error} className="text-sm" />
{providers.local && (
<div className="pb-4 border-b border-gray-700">
<div className="flex gap-2 mb-4">

View File

@@ -1,5 +1,6 @@
import { useState } from 'react';
import { auth } from '../utils/api';
import ErrorMessage from './ErrorMessage';
import { useAuth } from '../hooks/useAuth';
export default function PasswordChange({ targetUserId = null, targetUserName = null, onSuccess }) {
@@ -64,11 +65,7 @@ export default function PasswordChange({ targetUserId = null, targetUserName = n
{isChangingOtherUser ? `Change Password for ${targetUserName || 'User'}` : 'Change Password'}
</h2>
{error && (
<div className="mb-4 p-3 bg-red-400/20 border border-red-400/50 rounded-lg text-red-400 text-sm">
{error}
</div>
)}
<ErrorMessage error={error} className="mb-4 text-sm" />
{success && (
<div className="mb-4 p-3 bg-green-400/20 border border-green-400/50 rounded-lg text-green-400 text-sm">

View File

@@ -1,22 +1,71 @@
import { useState, useEffect } from 'react';
import { admin } from '../utils/api';
import { useState, useEffect, useRef } from 'react';
import { admin, normalizeArrayResponse } from '../utils/api';
import { wsManager } from '../utils/websocket';
import JobDetails from './JobDetails';
import LoadingSpinner from './LoadingSpinner';
export default function UserJobs({ userId, userName, onBack }) {
const [jobList, setJobList] = useState([]);
const [loading, setLoading] = useState(true);
const [selectedJob, setSelectedJob] = useState(null);
const listenerIdRef = useRef(null);
useEffect(() => {
loadJobs();
const interval = setInterval(loadJobs, 5000);
return () => clearInterval(interval);
// Use shared WebSocket manager for real-time updates instead of polling
listenerIdRef.current = wsManager.subscribe(`userjobs_${userId}`, {
open: () => {
console.log('UserJobs: Shared WebSocket connected');
loadJobs();
},
message: (data) => {
// Handle jobs channel messages (always broadcasted)
if (data.channel === 'jobs') {
if (data.type === 'job_update' && data.data) {
// Update job in list if it belongs to this user
setJobList(prev => {
const prevArray = Array.isArray(prev) ? prev : [];
const index = prevArray.findIndex(j => j.id === data.job_id);
if (index >= 0) {
const updated = [...prevArray];
updated[index] = { ...updated[index], ...data.data };
return updated;
}
// If job not in current list, reload to get updated list
if (data.data.status === 'completed' || data.data.status === 'failed') {
loadJobs();
}
return prevArray;
});
} else if (data.type === 'job_created' && data.data) {
// New job created - reload to check if it belongs to this user
loadJobs();
}
}
},
error: (error) => {
console.error('UserJobs: Shared WebSocket error:', error);
},
close: (event) => {
console.log('UserJobs: Shared WebSocket closed:', event);
}
});
// Ensure connection is established
wsManager.connect();
return () => {
if (listenerIdRef.current) {
wsManager.unsubscribe(listenerIdRef.current);
listenerIdRef.current = null;
}
};
}, [userId]);
const loadJobs = async () => {
try {
const data = await admin.getUserJobs(userId);
setJobList(Array.isArray(data) ? data : []);
setJobList(normalizeArrayResponse(data));
} catch (error) {
console.error('Failed to load jobs:', error);
setJobList([]);
@@ -47,11 +96,7 @@ export default function UserJobs({ userId, userName, onBack }) {
}
if (loading) {
return (
<div className="flex justify-center items-center h-64">
<div className="animate-spin rounded-full h-12 w-12 border-b-2 border-orange-500"></div>
</div>
);
return <LoadingSpinner size="md" className="h-64" />;
}
return (

View File

@@ -1,4 +1,6 @@
import { useState, useRef, useEffect } from 'react';
import ErrorMessage from './ErrorMessage';
import LoadingSpinner from './LoadingSpinner';
export default function VideoPlayer({ videoUrl, onClose }) {
const videoRef = useRef(null);
@@ -55,10 +57,10 @@ export default function VideoPlayer({ videoUrl, onClose }) {
if (error) {
return (
<div className="bg-red-50 border border-red-200 rounded-lg p-4 text-red-700">
{error}
<div className="mt-2 text-sm text-red-600">
<a href={videoUrl} download className="underline">Download video instead</a>
<div>
<ErrorMessage error={error} />
<div className="mt-2 text-sm text-gray-400">
<a href={videoUrl} download className="text-orange-400 hover:text-orange-300 underline">Download video instead</a>
</div>
</div>
);
@@ -68,7 +70,7 @@ export default function VideoPlayer({ videoUrl, onClose }) {
<div className="relative bg-black rounded-lg overflow-hidden">
{loading && (
<div className="absolute inset-0 flex items-center justify-center bg-black bg-opacity-50 z-10">
<div className="animate-spin rounded-full h-12 w-12 border-b-2 border-white"></div>
<LoadingSpinner size="lg" className="border-white" />
</div>
)}
<video

View File

@@ -3,12 +3,96 @@ const API_BASE = '/api';
// Global auth error handler - will be set by useAuth hook
let onAuthError = null;
// Request debouncing and deduplication
const pendingRequests = new Map(); // key: endpoint+params, value: Promise
const requestQueue = new Map(); // key: endpoint+params, value: { resolve, reject, timestamp }
const DEBOUNCE_DELAY = 100; // 100ms debounce delay
const DEDUPE_WINDOW = 5000; // 5 seconds - same request within this window uses cached promise
// Generate cache key from endpoint and params
function getCacheKey(endpoint, options = {}) {
const params = new URLSearchParams();
Object.keys(options).sort().forEach(key => {
if (options[key] !== undefined && options[key] !== null) {
params.append(key, String(options[key]));
}
});
const query = params.toString();
return `${endpoint}${query ? '?' + query : ''}`;
}
// Utility function to normalize array responses (handles both old and new formats)
export function normalizeArrayResponse(response) {
const data = response?.data || response;
return Array.isArray(data) ? data : [];
}
// Sentinel value to indicate a request was superseded (instead of rejecting)
// Export it so components can check for it
export const REQUEST_SUPERSEDED = Symbol('REQUEST_SUPERSEDED');
// Debounced request wrapper
function debounceRequest(key, requestFn, delay = DEBOUNCE_DELAY) {
return new Promise((resolve, reject) => {
// Check if there's a pending request for this key
if (pendingRequests.has(key)) {
const pending = pendingRequests.get(key);
// If request is very recent (within dedupe window), reuse it
const now = Date.now();
if (pending.timestamp && (now - pending.timestamp) < DEDUPE_WINDOW) {
pending.promise.then(resolve).catch(reject);
return;
} else {
// Request is older than dedupe window - remove it and create new one
pendingRequests.delete(key);
}
}
// Clear any existing timeout for this key
if (requestQueue.has(key)) {
const queued = requestQueue.get(key);
clearTimeout(queued.timeout);
// Resolve with sentinel value instead of rejecting - this prevents errors from propagating
// The new request will handle the actual response
queued.resolve(REQUEST_SUPERSEDED);
}
// Queue new request
const timeout = setTimeout(() => {
requestQueue.delete(key);
const promise = requestFn();
const timestamp = Date.now();
pendingRequests.set(key, { promise, timestamp });
promise
.then(result => {
pendingRequests.delete(key);
resolve(result);
})
.catch(error => {
pendingRequests.delete(key);
reject(error);
});
}, delay);
requestQueue.set(key, { resolve, reject, timeout });
});
}
export const setAuthErrorHandler = (handler) => {
onAuthError = handler;
};
const handleAuthError = (response) => {
// Whitelist of endpoints that should NOT trigger auth error handling
// These are endpoints that can legitimately return 401/403 without meaning the user is logged out
const AUTH_CHECK_ENDPOINTS = ['/auth/me', '/auth/logout'];
const handleAuthError = (response, endpoint) => {
if (response.status === 401 || response.status === 403) {
// Don't trigger auth error handler for endpoints that check auth status
if (AUTH_CHECK_ENDPOINTS.includes(endpoint)) {
return;
}
// Trigger auth error handler if set (this will clear user state)
if (onAuthError) {
onAuthError();
@@ -22,60 +106,79 @@ const handleAuthError = (response) => {
}
};
// Extract error message from response - centralized to avoid duplication
async function extractErrorMessage(response) {
try {
const errorData = await response.json();
return errorData?.error || response.statusText;
} catch {
return response.statusText;
}
}
export const api = {
async get(endpoint) {
async get(endpoint, options = {}) {
const abortController = options.signal || new AbortController();
const response = await fetch(`${API_BASE}${endpoint}`, {
credentials: 'include', // Include cookies for session
signal: abortController.signal,
});
if (!response.ok) {
// Handle auth errors before parsing response
// Don't redirect on /auth/me - that's the auth check itself
if ((response.status === 401 || response.status === 403) && !endpoint.startsWith('/auth/')) {
handleAuthError(response);
// Don't redirect - let React handle UI change through state
}
const errorData = await response.json().catch(() => null);
const errorMessage = errorData?.error || response.statusText;
handleAuthError(response, endpoint);
const errorMessage = await extractErrorMessage(response);
throw new Error(errorMessage);
}
return response.json();
},
async post(endpoint, data) {
async post(endpoint, data, options = {}) {
const abortController = options.signal || new AbortController();
const response = await fetch(`${API_BASE}${endpoint}`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(data),
credentials: 'include', // Include cookies for session
signal: abortController.signal,
});
if (!response.ok) {
// Handle auth errors before parsing response
// Don't redirect on /auth/* endpoints - those are login/logout
if ((response.status === 401 || response.status === 403) && !endpoint.startsWith('/auth/')) {
handleAuthError(response);
// Don't redirect - let React handle UI change through state
}
const errorData = await response.json().catch(() => null);
const errorMessage = errorData?.error || response.statusText;
handleAuthError(response, endpoint);
const errorMessage = await extractErrorMessage(response);
throw new Error(errorMessage);
}
return response.json();
},
async delete(endpoint) {
async patch(endpoint, data, options = {}) {
const abortController = options.signal || new AbortController();
const response = await fetch(`${API_BASE}${endpoint}`, {
method: 'DELETE',
method: 'PATCH',
headers: { 'Content-Type': 'application/json' },
body: data ? JSON.stringify(data) : undefined,
credentials: 'include', // Include cookies for session
signal: abortController.signal,
});
if (!response.ok) {
// Handle auth errors before parsing response
// Don't redirect on /auth/* endpoints
if ((response.status === 401 || response.status === 403) && !endpoint.startsWith('/auth/')) {
handleAuthError(response);
// Don't redirect - let React handle UI change through state
}
const errorData = await response.json().catch(() => null);
const errorMessage = errorData?.error || response.statusText;
handleAuthError(response, endpoint);
const errorMessage = await extractErrorMessage(response);
throw new Error(errorMessage);
}
return response.json();
},
async delete(endpoint, options = {}) {
const abortController = options.signal || new AbortController();
const response = await fetch(`${API_BASE}${endpoint}`, {
method: 'DELETE',
credentials: 'include', // Include cookies for session
signal: abortController.signal,
});
if (!response.ok) {
// Handle auth errors before parsing response
handleAuthError(response, endpoint);
const errorMessage = await extractErrorMessage(response);
throw new Error(errorMessage);
}
return response.json();
@@ -112,8 +215,7 @@ export const api = {
} else {
// Handle auth errors
if (xhr.status === 401 || xhr.status === 403) {
handleAuthError({ status: xhr.status });
// Don't redirect - let React handle UI change through state
handleAuthError({ status: xhr.status }, endpoint);
}
try {
const errorData = JSON.parse(xhr.responseText);
@@ -174,12 +276,53 @@ export const auth = {
};
export const jobs = {
async list() {
return api.get('/jobs');
async list(options = {}) {
const key = getCacheKey('/jobs', options);
return debounceRequest(key, () => {
const params = new URLSearchParams();
if (options.limit) params.append('limit', options.limit.toString());
if (options.offset) params.append('offset', options.offset.toString());
if (options.status) params.append('status', options.status);
if (options.sort) params.append('sort', options.sort);
const query = params.toString();
return api.get(`/jobs${query ? '?' + query : ''}`);
});
},
async get(id) {
return api.get(`/jobs/${id}`);
async listSummary(options = {}) {
const key = getCacheKey('/jobs/summary', options);
return debounceRequest(key, () => {
const params = new URLSearchParams();
if (options.limit) params.append('limit', options.limit.toString());
if (options.offset) params.append('offset', options.offset.toString());
if (options.status) params.append('status', options.status);
if (options.sort) params.append('sort', options.sort);
const query = params.toString();
return api.get(`/jobs/summary${query ? '?' + query : ''}`, options);
});
},
async get(id, options = {}) {
const key = getCacheKey(`/jobs/${id}`, options);
return debounceRequest(key, async () => {
if (options.etag) {
// Include ETag in request headers for conditional requests
const headers = { 'If-None-Match': options.etag };
const response = await fetch(`${API_BASE}/jobs/${id}`, {
credentials: 'include',
headers,
});
if (response.status === 304) {
return null; // Not modified
}
if (!response.ok) {
const errorData = await response.json().catch(() => null);
throw new Error(errorData?.error || response.statusText);
}
return response.json();
}
return api.get(`/jobs/${id}`, options);
});
},
async create(jobData) {
@@ -198,31 +341,82 @@ export const jobs = {
return api.uploadFile(`/jobs/${jobId}/upload`, file, onProgress, mainBlendFile);
},
async getFiles(jobId) {
return api.get(`/jobs/${jobId}/files`);
async uploadFileForJobCreation(file, onProgress, mainBlendFile) {
return api.uploadFile(`/jobs/upload`, file, onProgress, mainBlendFile);
},
async getFiles(jobId, options = {}) {
const key = getCacheKey(`/jobs/${jobId}/files`, options);
return debounceRequest(key, () => {
const params = new URLSearchParams();
if (options.limit) params.append('limit', options.limit.toString());
if (options.offset) params.append('offset', options.offset.toString());
if (options.file_type) params.append('file_type', options.file_type);
if (options.extension) params.append('extension', options.extension);
const query = params.toString();
return api.get(`/jobs/${jobId}/files${query ? '?' + query : ''}`, options);
});
},
async getFilesCount(jobId, options = {}) {
const key = getCacheKey(`/jobs/${jobId}/files/count`, options);
return debounceRequest(key, () => {
const params = new URLSearchParams();
if (options.file_type) params.append('file_type', options.file_type);
const query = params.toString();
return api.get(`/jobs/${jobId}/files/count${query ? '?' + query : ''}`);
});
},
async getContextArchive(jobId, options = {}) {
return api.get(`/jobs/${jobId}/context`, options);
},
downloadFile(jobId, fileId) {
return `${API_BASE}/jobs/${jobId}/files/${fileId}/download`;
},
previewEXR(jobId, fileId) {
return `${API_BASE}/jobs/${jobId}/files/${fileId}/preview-exr`;
},
getVideoUrl(jobId) {
return `${API_BASE}/jobs/${jobId}/video`;
},
async getTaskLogs(jobId, taskId, options = {}) {
const params = new URLSearchParams();
if (options.stepName) params.append('step_name', options.stepName);
if (options.logLevel) params.append('log_level', options.logLevel);
if (options.limit) params.append('limit', options.limit.toString());
const query = params.toString();
return api.get(`/jobs/${jobId}/tasks/${taskId}/logs${query ? '?' + query : ''}`);
const key = getCacheKey(`/jobs/${jobId}/tasks/${taskId}/logs`, options);
return debounceRequest(key, async () => {
const params = new URLSearchParams();
if (options.stepName) params.append('step_name', options.stepName);
if (options.logLevel) params.append('log_level', options.logLevel);
if (options.limit) params.append('limit', options.limit.toString());
if (options.sinceId) params.append('since_id', options.sinceId.toString());
const query = params.toString();
const result = await api.get(`/jobs/${jobId}/tasks/${taskId}/logs${query ? '?' + query : ''}`, options);
// Handle both old format (array) and new format (object with logs, last_id, limit)
if (Array.isArray(result)) {
return { logs: result, last_id: result.length > 0 ? result[result.length - 1].id : 0, limit: options.limit || 100 };
}
return result;
});
},
async getTaskSteps(jobId, taskId) {
return api.get(`/jobs/${jobId}/tasks/${taskId}/steps`);
async getTaskSteps(jobId, taskId, options = {}) {
return api.get(`/jobs/${jobId}/tasks/${taskId}/steps`, options);
},
// New unified client WebSocket - DEPRECATED: Use wsManager from websocket.js instead
// This is kept for backwards compatibility but should not be used
streamClientWebSocket() {
console.warn('streamClientWebSocket() is deprecated - use wsManager from websocket.js instead');
const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const wsHost = window.location.host;
const url = `${wsProtocol}//${wsHost}${API_BASE}/ws`;
return new WebSocket(url);
},
// Old WebSocket methods (to be removed after migration)
streamTaskLogsWebSocket(jobId, taskId, lastId = 0) {
// Convert HTTP to WebSocket URL
const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
@@ -231,6 +425,20 @@ export const jobs = {
return new WebSocket(url);
},
streamJobsWebSocket() {
const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const wsHost = window.location.host;
const url = `${wsProtocol}//${wsHost}${API_BASE}/jobs/ws-old`;
return new WebSocket(url);
},
streamJobWebSocket(jobId) {
const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const wsHost = window.location.host;
const url = `${wsProtocol}//${wsHost}${API_BASE}/jobs/${jobId}/ws`;
return new WebSocket(url);
},
async retryTask(jobId, taskId) {
return api.post(`/jobs/${jobId}/tasks/${taskId}/retry`);
},
@@ -239,8 +447,50 @@ export const jobs = {
return api.get(`/jobs/${jobId}/metadata`);
},
async getTasks(jobId) {
return api.get(`/jobs/${jobId}/tasks`);
async getTasks(jobId, options = {}) {
const key = getCacheKey(`/jobs/${jobId}/tasks`, options);
return debounceRequest(key, () => {
const params = new URLSearchParams();
if (options.limit) params.append('limit', options.limit.toString());
if (options.offset) params.append('offset', options.offset.toString());
if (options.status) params.append('status', options.status);
if (options.frameStart) params.append('frame_start', options.frameStart.toString());
if (options.frameEnd) params.append('frame_end', options.frameEnd.toString());
if (options.sort) params.append('sort', options.sort);
const query = params.toString();
return api.get(`/jobs/${jobId}/tasks${query ? '?' + query : ''}`, options);
});
},
async getTasksSummary(jobId, options = {}) {
const key = getCacheKey(`/jobs/${jobId}/tasks/summary`, options);
return debounceRequest(key, () => {
const params = new URLSearchParams();
if (options.limit) params.append('limit', options.limit.toString());
if (options.offset) params.append('offset', options.offset.toString());
if (options.status) params.append('status', options.status);
if (options.sort) params.append('sort', options.sort);
const query = params.toString();
return api.get(`/jobs/${jobId}/tasks/summary${query ? '?' + query : ''}`, options);
});
},
async batchGetJobs(jobIds) {
// Sort jobIds for consistent cache key
const sortedIds = [...jobIds].sort((a, b) => a - b);
const key = getCacheKey('/jobs/batch', { job_ids: sortedIds.join(',') });
return debounceRequest(key, () => {
return api.post('/jobs/batch', { job_ids: jobIds });
});
},
async batchGetTasks(jobId, taskIds) {
// Sort taskIds for consistent cache key
const sortedIds = [...taskIds].sort((a, b) => a - b);
const key = getCacheKey(`/jobs/${jobId}/tasks/batch`, { task_ids: sortedIds.join(',') });
return debounceRequest(key, () => {
return api.post(`/jobs/${jobId}/tasks/batch`, { task_ids: taskIds });
});
},
};
@@ -249,16 +499,22 @@ export const runners = {
};
export const admin = {
async generateToken(expiresInHours) {
return api.post('/admin/runners/tokens', { expires_in_hours: expiresInHours });
async generateAPIKey(name, description, scope) {
const data = { name, scope };
if (description) data.description = description;
return api.post('/admin/runners/api-keys', data);
},
async listTokens() {
return api.get('/admin/runners/tokens');
async listAPIKeys() {
return api.get('/admin/runners/api-keys');
},
async revokeToken(tokenId) {
return api.delete(`/admin/runners/tokens/${tokenId}`);
async revokeAPIKey(keyId) {
return api.patch(`/admin/runners/api-keys/${keyId}/revoke`);
},
async deleteAPIKey(keyId) {
return api.delete(`/admin/runners/api-keys/${keyId}`);
},
async listRunners() {

271
web/src/utils/websocket.js Normal file
View File

@@ -0,0 +1,271 @@
// Shared WebSocket connection manager
// All components should use this instead of creating their own connections
class WebSocketManager {
constructor() {
this.ws = null;
this.listeners = new Map(); // Map of listener IDs to callback functions
this.reconnectTimeout = null;
this.reconnectDelay = 2000;
this.isConnecting = false;
this.listenerIdCounter = 0;
this.verboseLogging = false; // Set to true to enable verbose WebSocket logging
// Track server-side channel subscriptions for re-subscription on reconnect
this.serverSubscriptions = new Set(); // Channels we want to be subscribed to
this.confirmedSubscriptions = new Set(); // Channels confirmed by server
this.pendingSubscriptions = new Set(); // Channels waiting for confirmation
}
connect() {
// If already connected or connecting, don't create a new connection
if (this.ws && (this.ws.readyState === WebSocket.CONNECTING || this.ws.readyState === WebSocket.OPEN)) {
return;
}
if (this.isConnecting) {
return;
}
this.isConnecting = true;
try {
const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const wsHost = window.location.host;
const API_BASE = '/api';
const url = `${wsProtocol}//${wsHost}${API_BASE}/jobs/ws`;
this.ws = new WebSocket(url);
this.ws.onopen = () => {
if (this.verboseLogging) {
console.log('Shared WebSocket connected');
}
this.isConnecting = false;
// Re-subscribe to all channels that were previously subscribed
this.resubscribeToChannels();
this.notifyListeners('open', {});
};
this.ws.onmessage = (event) => {
try {
const data = JSON.parse(event.data);
if (this.verboseLogging) {
console.log('WebSocketManager: Message received:', data.type, data.channel || 'no channel', data);
}
this.notifyListeners('message', data);
} catch (error) {
console.error('WebSocketManager: Failed to parse message:', error, 'Raw data:', event.data);
}
};
this.ws.onerror = (error) => {
console.error('Shared WebSocket error:', error);
this.isConnecting = false;
this.notifyListeners('error', error);
};
this.ws.onclose = (event) => {
if (this.verboseLogging) {
console.log('Shared WebSocket closed:', {
code: event.code,
reason: event.reason,
wasClean: event.wasClean
});
}
this.ws = null;
this.isConnecting = false;
// Clear confirmed/pending but keep serverSubscriptions for re-subscription
this.confirmedSubscriptions.clear();
this.pendingSubscriptions.clear();
this.notifyListeners('close', event);
// Always retry connection if we have listeners
if (this.listeners.size > 0) {
if (this.reconnectTimeout) {
clearTimeout(this.reconnectTimeout);
}
this.reconnectTimeout = setTimeout(() => {
if (!this.ws || this.ws.readyState === WebSocket.CLOSED) {
this.connect();
}
}, this.reconnectDelay);
}
};
} catch (error) {
console.error('Failed to create WebSocket:', error);
this.isConnecting = false;
// Retry after delay
this.reconnectTimeout = setTimeout(() => {
this.connect();
}, this.reconnectDelay);
}
}
subscribe(listenerId, callbacks) {
// Generate ID if not provided
if (!listenerId) {
listenerId = `listener_${this.listenerIdCounter++}`;
}
if (this.verboseLogging) {
console.log('WebSocketManager: Subscribing listener:', listenerId, 'WebSocket state:', this.ws ? this.ws.readyState : 'no connection');
}
this.listeners.set(listenerId, callbacks);
// Connect if not already connected
if (!this.ws || this.ws.readyState === WebSocket.CLOSED) {
if (this.verboseLogging) {
console.log('WebSocketManager: WebSocket not connected, connecting...');
}
this.connect();
}
// If already open, notify immediately
if (this.ws && this.ws.readyState === WebSocket.OPEN && callbacks.open) {
if (this.verboseLogging) {
console.log('WebSocketManager: WebSocket already open, calling open callback for listener:', listenerId);
}
// Use setTimeout to ensure this happens after the listener is registered
setTimeout(() => {
if (callbacks.open) {
callbacks.open();
}
}, 0);
}
return listenerId;
}
unsubscribe(listenerId) {
this.listeners.delete(listenerId);
// If no more listeners, we could close the connection, but let's keep it open
// in case other components need it
}
send(data) {
if (this.ws && this.ws.readyState === WebSocket.OPEN) {
if (this.verboseLogging) {
console.log('WebSocketManager: Sending message:', data);
}
this.ws.send(JSON.stringify(data));
} else {
console.warn('WebSocketManager: Cannot send message - connection not open. State:', this.ws ? this.ws.readyState : 'no connection', 'Message:', data);
}
}
notifyListeners(eventType, data) {
this.listeners.forEach((callbacks) => {
if (callbacks[eventType]) {
try {
callbacks[eventType](data);
} catch (error) {
console.error('Error in WebSocket listener:', error);
}
}
});
}
getReadyState() {
return this.ws ? this.ws.readyState : WebSocket.CLOSED;
}
// Subscribe to a server-side channel (will be re-subscribed on reconnect)
subscribeToChannel(channel) {
if (this.serverSubscriptions.has(channel)) {
// Already subscribed or pending
return;
}
this.serverSubscriptions.add(channel);
if (this.ws && this.ws.readyState === WebSocket.OPEN) {
if (!this.confirmedSubscriptions.has(channel) && !this.pendingSubscriptions.has(channel)) {
this.pendingSubscriptions.add(channel);
this.send({ type: 'subscribe', channel });
if (this.verboseLogging) {
console.log('WebSocketManager: Subscribing to channel:', channel);
}
}
}
}
// Unsubscribe from a server-side channel (won't be re-subscribed on reconnect)
unsubscribeFromChannel(channel) {
this.serverSubscriptions.delete(channel);
this.confirmedSubscriptions.delete(channel);
this.pendingSubscriptions.delete(channel);
if (this.ws && this.ws.readyState === WebSocket.OPEN) {
this.send({ type: 'unsubscribe', channel });
if (this.verboseLogging) {
console.log('WebSocketManager: Unsubscribing from channel:', channel);
}
}
}
// Mark a channel subscription as confirmed (call this when server confirms)
confirmSubscription(channel) {
this.pendingSubscriptions.delete(channel);
this.confirmedSubscriptions.add(channel);
if (this.verboseLogging) {
console.log('WebSocketManager: Subscription confirmed for channel:', channel);
}
}
// Mark a channel subscription as failed (call this when server rejects)
failSubscription(channel) {
this.pendingSubscriptions.delete(channel);
this.serverSubscriptions.delete(channel);
if (this.verboseLogging) {
console.log('WebSocketManager: Subscription failed for channel:', channel);
}
}
// Check if subscribed to a channel
isSubscribedToChannel(channel) {
return this.confirmedSubscriptions.has(channel);
}
// Re-subscribe to all channels after reconnect
resubscribeToChannels() {
if (this.serverSubscriptions.size === 0) {
return;
}
if (this.verboseLogging) {
console.log('WebSocketManager: Re-subscribing to channels:', Array.from(this.serverSubscriptions));
}
for (const channel of this.serverSubscriptions) {
if (!this.pendingSubscriptions.has(channel)) {
this.pendingSubscriptions.add(channel);
this.send({ type: 'subscribe', channel });
}
}
}
disconnect() {
if (this.reconnectTimeout) {
clearTimeout(this.reconnectTimeout);
this.reconnectTimeout = null;
}
if (this.ws) {
this.ws.close();
this.ws = null;
}
this.listeners.clear();
this.serverSubscriptions.clear();
this.confirmedSubscriptions.clear();
this.pendingSubscriptions.clear();
}
}
// Export singleton instance
export const wsManager = new WebSocketManager();