# Compare commits: 2a0ff98834...0.0.2

21 commits
| Author | SHA1 | Date |
|---|---|---|
| | bb57ce8659 | |
| | 1a8836e6aa | |
| | b51b96a618 | |
| | 8e561922c9 | |
| | 1c4bd78f56 | |
| | 3f2982ddb3 | |
| | 0b852c5087 | |
| | 5e56c7f0e8 | |
| | 0a8f40b9cb | |
| | 7440511740 | |
| | c7c8762164 | |
| | 94490237fe | |
| | edc8ea160c | |
| | 11e7552b5b | |
| | 690e6b13f8 | |
| | a53ea4dce7 | |
| | 3217bbfe4d | |
| | 4ac05d50a1 | |
| | a029714e08 | |
| | f9ff4d0138 | |
| | f7e1766d8b | |
## .gitea/workflows/release-tag.yaml (new file, +24)

```yaml
name: Release Tag
on:
  push:
    tags:
      - '*'

jobs:
  release:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@main
        with:
          fetch-depth: 0
      - run: git fetch --force --tags
      - uses: actions/setup-go@main
        with:
          go-version-file: 'go.mod'
      - uses: goreleaser/goreleaser-action@master
        with:
          distribution: goreleaser
          version: 'latest'
          args: release
        env:
          GITEA_TOKEN: ${{secrets.RELEASE_TOKEN}}
```
## .gitea/workflows/test-pr.yaml (new file, +17)

```yaml
name: PR Check
on:
  - pull_request

jobs:
  check-and-test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@main
      - uses: actions/setup-go@main
        with:
          go-version-file: 'go.mod'
      - uses: FedericoCarboni/setup-ffmpeg@v3
      - run: go mod tidy
      - run: cd web && npm install && npm run build
      - run: go build ./...
      - run: go test -race -v -shuffle=on ./...
```
## .gitignore (vendored, 6 changed lines)

```gitignore
@@ -27,7 +27,11 @@ go.work
jiggablend.db
jiggablend.db.wal
jiggablend.db-shm
jiggablend.db-journal

# Log files
*.log
logs/
# Secrets and configuration
runner-secrets.json
runner-secrets-*.json
@@ -61,6 +65,7 @@ lerna-debug.log*
*.o
*.a
*.so
/dist/

# Temporary files
*.tmp
@@ -69,6 +74,7 @@ lerna-debug.log*

# Logs
*.log
/logs/

# OS files
Thumbs.db
```
## .goreleaser.yaml (new file, +48)

```yaml
version: 2

before:
  hooks:
    - go mod tidy -v
    - sh -c "cd web && npm install && npm run build"

builds:
  - id: default
    main: ./cmd/jiggablend
    binary: jiggablend
    ldflags:
      - -X jiggablend/version.Version={{.Version}}
      - -X jiggablend/version.Date={{.Date}}
    env:
      - CGO_ENABLED=1
    goos:
      - linux
    goarch:
      - amd64

checksum:
  name_template: "checksums.txt"

archives:
  - id: default
    name_template: "{{ .ProjectName }}-{{ .Os }}-{{ .Arch }}"
    formats: tar.gz
    format_overrides:
      - goos: windows
        formats: zip
    files:
      - README.md
      - LICENSE

changelog:
  sort: asc
  filters:
    exclude:
      - "^docs:"
      - "^test:"

release:
  name_template: "{{ .ProjectName }}-{{ .Version }}"

gitea_urls:
  api: https://git.s1d3sw1ped.com/api/v1
  download: https://git.s1d3sw1ped.com
```
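The `-X jiggablend/version.Version` and `-X jiggablend/version.Date` ldflags above imply a small `version` package exporting those two variables. That file isn't part of this diff, so the sketch below is an assumption beyond the two `-X` targets:

```go
// version/version.go — hedged sketch; only the Version and Date
// variables are implied by the -X ldflags in .goreleaser.yaml.
package version

// Overwritten at build time via
// -ldflags "-X jiggablend/version.Version=... -X jiggablend/version.Date=...".
var (
	Version = "dev"     // replaced with the release tag by GoReleaser
	Date    = "unknown" // replaced with the build date by GoReleaser
)
```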
## LICENSE (new file, +21)

```text
The MIT License (MIT)

Copyright © 2026 s1d3sw1ped

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
```
## Makefile (102 changed lines)

```makefile
@@ -1,58 +1,61 @@
.PHONY: build build-manager build-runner build-web run-manager run-runner run cleanup cleanup-manager cleanup-runner clean test
.PHONY: build build-web run run-manager run-runner cleanup cleanup-manager cleanup-runner clean-bin clean-web test help install

# Build all
build: clean-bin build-manager build-runner

# Build manager
build-manager: clean-bin build-web
	go build -o bin/manager ./cmd/manager

# Build runner
build-runner: clean-bin
	GOOS=linux GOARCH=amd64 go build -o bin/runner ./cmd/runner
# Build the jiggablend binary (includes embedded web UI)
build:
	@echo "Building with GoReleaser..."
	goreleaser build --clean --snapshot --single-target
	@mkdir -p bin
	@find dist -name jiggablend -type f -exec cp {} bin/jiggablend \;

# Build web UI
build-web: clean-web
	cd web && npm install && npm run build

# Cleanup manager (database and storage)
# Cleanup manager logs
cleanup-manager:
	@echo "Cleaning up manager database and storage..."
	@rm -f jiggablend.db 2>/dev/null || true
	@rm -f jiggablend.db-shm 2>/dev/null || true
	@rm -f jiggablend.db-wal 2>/dev/null || true
	@rm -rf jiggablend-storage 2>/dev/null || true
	@echo "Cleaning up manager logs..."
	@rm -rf logs/manager.log 2>/dev/null || true
	@echo "Manager cleanup complete"

# Cleanup runner (workspaces and secrets)
# Cleanup runner logs
cleanup-runner:
	@echo "Cleaning up runner workspaces and secrets..."
	@rm -rf jiggablend-workspaces jiggablend-workspace* *workspace* runner-secrets*.json 2>/dev/null || true
	@echo "Cleaning up runner logs..."
	@rm -rf logs/runner*.log 2>/dev/null || true
	@echo "Runner cleanup complete"

# Cleanup both manager and runner
# Cleanup both manager and runner logs
cleanup: cleanup-manager cleanup-runner

# Run all parallel
run: cleanup-manager cleanup-runner build-manager build-runner
# Run manager and runner in parallel (for testing)
run: cleanup build init-test
	@echo "Starting manager and runner in parallel..."
	@echo "Press Ctrl+C to stop both..."
	@trap 'kill $$MANAGER_PID $$RUNNER_PID 2>/dev/null; exit' INT TERM; \
	FIXED_REGISTRATION_TOKEN=test-token ENABLE_LOCAL_AUTH=true LOCAL_TEST_EMAIL=test@example.com LOCAL_TEST_PASSWORD=testpassword bin/manager & \
	bin/jiggablend manager -l manager.log & \
	MANAGER_PID=$$!; \
	REGISTRATION_TOKEN=test-token bin/runner & \
	sleep 2; \
	bin/jiggablend runner -l runner.log --api-key=jk_r0_test_key_123456789012345678901234567890 & \
	RUNNER_PID=$$!; \
	wait $$MANAGER_PID $$RUNNER_PID

# Run manager
# Note: ENABLE_LOCAL_AUTH enables local user registration/login
# LOCAL_TEST_EMAIL and LOCAL_TEST_PASSWORD create a test user on startup (if it doesn't exist)
run-manager: cleanup-manager build-manager
	FIXED_REGISTRATION_TOKEN=test-token ENABLE_LOCAL_AUTH=true LOCAL_TEST_EMAIL=test@example.com LOCAL_TEST_PASSWORD=testpassword bin/manager
# Run manager server
run-manager: cleanup-manager build init-test
	bin/jiggablend manager -l manager.log

# Run runner
run-runner: cleanup-runner build-runner
	REGISTRATION_TOKEN=test-token bin/runner
run-runner: cleanup-runner build
	bin/jiggablend runner -l runner.log --api-key=jk_r0_test_key_123456789012345678901234567890

# Initialize for testing (first run setup)
init-test: build
	@echo "Initializing test configuration..."
	bin/jiggablend manager config enable localauth
	bin/jiggablend manager config set fixed-apikey jk_r0_test_key_123456789012345678901234567890 -f -y
	bin/jiggablend manager config add user test@example.com testpassword --admin -f -y
	@echo "Test configuration complete!"
	@echo "fixed api key: jk_r0_test_key_123456789012345678901234567890"
	@echo "test user: test@example.com"
	@echo "test password: testpassword"

# Clean bin build artifacts
clean-bin:
@@ -66,3 +69,38 @@ clean-web:
test:
	go test ./... -timeout 30s

# Show help
help:
	@echo "Jiggablend Build and Run Makefile"
	@echo ""
	@echo "Build targets:"
	@echo "  build            - Build jiggablend binary with embedded web UI"
	@echo "  build-web        - Build web UI only"
	@echo ""
	@echo "Run targets:"
	@echo "  run              - Run manager and runner in parallel (for testing)"
	@echo "  run-manager      - Run manager server"
	@echo "  run-runner       - Run runner with test API key"
	@echo "  init-test        - Initialize test configuration (run once)"
	@echo ""
	@echo "Cleanup targets:"
	@echo "  cleanup          - Clean all logs"
	@echo "  cleanup-manager  - Clean manager logs"
	@echo "  cleanup-runner   - Clean runner logs"
	@echo ""
	@echo "Other targets:"
	@echo "  clean-bin        - Clean build artifacts"
	@echo "  clean-web        - Clean web build artifacts"
	@echo "  test             - Run Go tests"
	@echo "  help             - Show this help"
	@echo ""
	@echo "CLI Usage:"
	@echo "  jiggablend manager serve          - Start the manager server"
	@echo "  jiggablend runner                 - Start a runner"
	@echo "  jiggablend manager config show    - Show configuration"
	@echo "  jiggablend manager config enable localauth"
	@echo "  jiggablend manager config add user --email=x --password=y"
	@echo "  jiggablend manager config add apikey --name=mykey"
	@echo "  jiggablend manager config set fixed-apikey <key>"
	@echo "  jiggablend manager config list users"
	@echo "  jiggablend manager config list apikeys"
```
## README.md (266 changed lines)

````markdown
@@ -4,28 +4,41 @@ A distributed Blender render farm system built with Go. The system consists of a

## Architecture

- **Manager**: Central server with REST API, web UI, DuckDB database, and local file storage
- **Manager**: Central server with REST API, embedded web UI, SQLite database, and local file storage
- **Runner**: Linux amd64 client that connects to manager, receives jobs, executes Blender renders, and reports back

Both manager and runner are part of a single binary (`jiggablend`) with subcommands.

## Features

- OAuth authentication (Google and Discord)
- Web-based job submission and monitoring
- Distributed rendering across multiple runners
- Real-time job progress tracking
- File upload/download for Blender files and rendered outputs
- Runner health monitoring
- **Authentication**: OAuth (Google and Discord) and local authentication with user management
- **Web UI**: Modern React-based interface for job submission and monitoring
- **Distributed Rendering**: Scale across multiple runners with automatic job distribution
- **Real-time Updates**: WebSocket-based progress tracking and job status updates
- **Video Encoding**: Automatic video encoding from EXR/PNG sequences with multiple codec support:
  - H.264 (MP4) - SDR and HDR support
  - AV1 (MP4) - With alpha channel support
  - VP9 (WebM) - With alpha channel and HDR support
- **Output Formats**: PNG, JPEG, EXR, and video formats (MP4, WebM)
- **Blender Version Management**: Support for multiple Blender versions with automatic detection
- **Metadata Extraction**: Automatic extraction of scene metadata from Blender files
- **Admin Panel**: User and runner management interface
- **Runner Management**: API key-based authentication for runners with health monitoring
- **HDR Support**: Preserve HDR range in video encoding with HLG transfer function
- **Alpha Channel**: Preserve alpha channel in video encoding (AV1 and VP9)

## Prerequisites

### Manager
- Go 1.21 or later
- DuckDB (via Go driver)
- Go 1.25.4 or later
- SQLite (via Go driver)
- Blender installed and in PATH (for metadata extraction)
- ImageMagick installed (for EXR preview conversion)

### Runner
- Linux amd64
- Blender installed and in PATH
- FFmpeg installed (optional, for video processing)
- Blender installed (can use bundled versions from storage)
- FFmpeg installed (required for video encoding)

## Installation

@@ -44,44 +57,87 @@ go mod download

### Manager

Set the following environment variables for authentication (optional):
Configuration is managed through the CLI using `jiggablend manager config` commands. The configuration is stored in the SQLite database.

#### Initial Setup

For testing, use the Makefile helper:
```bash
make init-test
```

This will:
- Enable local authentication
- Set a fixed API key for testing
- Create a test admin user (test@example.com / testpassword)

#### Manual Configuration

```bash
# OAuth Providers (optional)
export GOOGLE_CLIENT_ID="your-google-client-id"
export GOOGLE_CLIENT_SECRET="your-google-client-secret"
export GOOGLE_REDIRECT_URL="http://localhost:8080/api/auth/google/callback"
# Enable local authentication
jiggablend manager config enable localauth

export DISCORD_CLIENT_ID="your-discord-client-id"
export DISCORD_CLIENT_SECRET="your-discord-client-secret"
export DISCORD_REDIRECT_URL="http://localhost:8080/api/auth/discord/callback"
# Add a user
jiggablend manager config add user <email> <password> --admin

# Local Authentication (optional)
export ENABLE_LOCAL_AUTH="true"
# Generate an API key for runners
jiggablend manager config add apikey <name> --scope manager

# Test User (optional, for testing only)
# Creates a local user on startup if it doesn't exist
export LOCAL_TEST_EMAIL="test@example.com"
export LOCAL_TEST_PASSWORD="testpassword"
# Set OAuth credentials
jiggablend manager config set google-oauth <client-id> <client-secret> --redirect-url <url>
jiggablend manager config set discord-oauth <client-id> <client-secret> --redirect-url <url>

# View current configuration
jiggablend manager config show

# List users and API keys
jiggablend manager config list users
jiggablend manager config list apikeys
```

#### Environment Variables

You can also use environment variables with the `JIGGABLEND_` prefix:
- `JIGGABLEND_PORT` - Server port (default: 8080)
- `JIGGABLEND_DB` - Database path (default: jiggablend.db)
- `JIGGABLEND_STORAGE` - Storage path (default: ./jiggablend-storage)
- `JIGGABLEND_LOG_FILE` - Log file path
- `JIGGABLEND_LOG_LEVEL` - Log level (debug, info, warn, error)
- `JIGGABLEND_VERBOSE` - Enable verbose logging

### Runner

No configuration required. Runner will auto-detect hostname and IP.
The runner requires an API key to connect to the manager. The runner will auto-detect hostname and IP.

## Usage

### Building

```bash
# Build the unified binary (includes embedded web UI)
make build

# Or build directly
go build -o bin/jiggablend ./cmd/jiggablend

# Build web UI separately
make build-web
```

### Running the Manager

```bash
# Using make
# Using make (includes test setup)
make run-manager

# Or directly
go run ./cmd/manager
bin/jiggablend manager

# With custom options
go run ./cmd/manager -port 8080 -db jiggablend.db -storage ./storage
bin/jiggablend manager --port 8080 --db jiggablend.db --storage ./jiggablend-storage --log-file manager.log

# Using environment variables
JIGGABLEND_PORT=8080 JIGGABLEND_DB=jiggablend.db bin/jiggablend manager
```

The manager will start on `http://localhost:8080` by default.
@@ -89,26 +145,28 @@ The manager will start on `http://localhost:8080` by default.

### Running a Runner

```bash
# Using make
# Using make (uses test API key)
make run-runner

# Or directly
go run ./cmd/runner
# Or directly (requires API key)
bin/jiggablend runner --api-key <your-api-key>

# With custom options
go run ./cmd/runner -manager http://localhost:8080 -name my-runner
bin/jiggablend runner --manager http://localhost:8080 --name my-runner --api-key <key> --log-file runner.log

# Using environment variables
JIGGABLEND_MANAGER=http://localhost:8080 JIGGABLEND_API_KEY=<key> bin/jiggablend runner
```

### Building
### Running Both (for Testing)

```bash
# Build manager
make build-manager

# Build runner (Linux amd64)
make build-runner
# Run manager and runner in parallel
make run
```

This will start both the manager and a test runner with a fixed API key.

## OAuth Setup

### Google OAuth
@@ -118,7 +176,10 @@ make build-runner

3. Enable Google+ API
4. Create OAuth 2.0 credentials
5. Add authorized redirect URI: `http://localhost:8080/api/auth/google/callback`
6. Set environment variables with Client ID and Secret
6. Configure using CLI:
```bash
jiggablend manager config set google-oauth <client-id> <client-secret> --redirect-url http://localhost:8080/api/auth/google/callback
```

### Discord OAuth

@@ -126,25 +187,39 @@ make build-runner

2. Create a new application
3. Go to OAuth2 section
4. Add redirect URI: `http://localhost:8080/api/auth/discord/callback`
5. Set environment variables with Client ID and Secret
5. Configure using CLI:
```bash
jiggablend manager config set discord-oauth <client-id> <client-secret> --redirect-url http://localhost:8080/api/auth/discord/callback
```

## Project Structure

```
jiggablend/
├── cmd/
│   ├── manager/ # Manager server application
│   └── runner/ # Runner client application
│   └── jiggablend/ # Unified CLI application
│       ├── cmd/ # Cobra command definitions
│       └── main.go # Entry point
├── internal/
│   ├── api/ # REST API handlers
│   ├── auth/ # OAuth authentication
│   ├── database/ # DuckDB database models and migrations
│   ├── queue/ # Job queue management
│   ├── storage/ # File storage operations
│   └── runner/ # Runner management logic
│   ├── auth/ # Authentication (OAuth, local, sessions)
│   ├── config/ # Configuration management
│   ├── database/ # SQLite database models and migrations
│   ├── logger/ # Logging utilities
│   ├── manager/ # Manager server logic
│   ├── runner/ # Runner client logic
│   │   ├── api/ # Manager API client
│   │   ├── blender/ # Blender version detection
│   │   ├── encoding/ # Video encoding (H.264, AV1, VP9)
│   │   ├── tasks/ # Task execution (render, encode, process)
│   │   └── workspace/ # Workspace management
│   └── storage/ # File storage operations
├── pkg/
│   └── types/ # Shared types and models
├── web/ # Static web UI files
│   ├── executils/ # Execution utilities
│   ├── scripts/ # Python scripts for Blender
│   └── types/ # Shared types and models
├── web/ # React web UI
│   ├── src/ # Source files
│   └── dist/ # Built files (embedded in binary)
├── go.mod
└── Makefile
```
@@ -156,27 +231,106 @@ jiggablend/
- `GET /api/auth/google/callback` - Google OAuth callback
- `GET /api/auth/discord/login` - Initiate Discord OAuth
- `GET /api/auth/discord/callback` - Discord OAuth callback
- `POST /api/auth/login` - Local authentication login
- `POST /api/auth/register` - User registration (if enabled)
- `POST /api/auth/logout` - Logout
- `GET /api/auth/me` - Get current user
- `POST /api/auth/password/change` - Change password

### Jobs
- `POST /api/jobs` - Create a new job
- `GET /api/jobs` - List user's jobs
- `GET /api/jobs/{id}` - Get job details
- `DELETE /api/jobs/{id}` - Cancel a job
- `POST /api/jobs/{id}/upload` - Upload job file
- `POST /api/jobs/{id}/upload` - Upload job file (Blender file)
- `GET /api/jobs/{id}/files` - List job files
- `GET /api/jobs/{id}/files/{fileId}/download` - Download job file
- `GET /api/jobs/{id}/metadata` - Extract metadata from uploaded file
- `GET /api/jobs/{id}/outputs` - List job output files

### Runners
- `GET /api/admin/runners` - List all runners (admin only)
- `POST /api/runner/register` - Register a runner (uses registration token)
- `POST /api/runner/heartbeat` - Update runner heartbeat (runner authenticated)
### Blender
- `GET /api/blender/versions` - List available Blender versions

### Runners (Internal API)
- `POST /api/runner/register` - Register a runner (uses API key)
- `POST /api/runner/heartbeat` - Update runner heartbeat
- `GET /api/runner/tasks` - Get pending tasks for runner
- `POST /api/runner/tasks/{id}/complete` - Mark task as complete
- `GET /api/runner/files/{jobId}/{fileName}` - Download file for runner
- `POST /api/runner/files/{jobId}/upload` - Upload file from runner

### Admin (Admin Only)
- `GET /api/admin/runners` - List all runners
- `GET /api/admin/jobs` - List all jobs
- `GET /api/admin/users` - List all users
- `GET /api/admin/stats` - System statistics

### WebSocket
- `WS /api/ws` - WebSocket connection for real-time updates
- Subscribe to job channels: `job:{jobId}`
- Receive job status updates, progress, and logs
````
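The WebSocket section above only names the channel format. As a rough illustration, here is a minimal Go subscriber for `/api/ws` using `github.com/gorilla/websocket`; the subscribe payload shape is an assumption, since this compare view does not show the actual wire format:

```go
// wsclient.go — hedged sketch of a /api/ws subscriber; the subscribe
// payload and message fields are assumptions, not the project's wire format.
package main

import (
	"log"

	"github.com/gorilla/websocket"
)

func main() {
	// Connect to the manager's WebSocket endpoint.
	conn, _, err := websocket.DefaultDialer.Dial("ws://localhost:8080/api/ws", nil)
	if err != nil {
		log.Fatalf("dial: %v", err)
	}
	defer conn.Close()

	// Hypothetical subscribe message for the job:{jobId} channel.
	sub := map[string]string{"type": "subscribe", "channel": "job:42"}
	if err := conn.WriteJSON(sub); err != nil {
		log.Fatalf("subscribe: %v", err)
	}

	// Print status/progress/log updates as they arrive.
	for {
		var msg map[string]any
		if err := conn.ReadJSON(&msg); err != nil {
			log.Fatalf("read: %v", err)
		}
		log.Printf("update: %v", msg)
	}
}
```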
````markdown
## Output Formats

The system supports the following output formats:

### Image Formats
- **PNG** - Standard PNG output
- **JPEG** - JPEG output
- **EXR** - OpenEXR format (HDR)

### Video Formats
- **EXR_264_MP4** - H.264 encoded MP4 from EXR sequence (SDR or HDR)
- **EXR_AV1_MP4** - AV1 encoded MP4 from EXR sequence (with alpha channel support)
- **EXR_VP9_WEBM** - VP9 encoded WebM from EXR sequence (with alpha channel and HDR support)

Video encoding features:
- 2-pass encoding for optimal quality
- HDR preservation using HLG transfer function
- Alpha channel preservation (AV1 and VP9 only)
- Automatic detection of source format (EXR or PNG)
- Software encoding (libx264, libaom-av1, libvpx-vp9)
````
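For a feel of the VP9-with-alpha path listed above, here is a hedged, single-pass sketch shelling out to FFmpeg from Go. The real pipeline is 2-pass per the list, and the exact flags used by `internal/runner/encoding` are not shown in this diff; `yuva420p` and `arib-std-b67` are the standard FFmpeg names for an alpha-capable pixel format and the HLG transfer function:

```go
// encode_sketch.go — hedged sketch; flag choices are plausible stand-ins,
// not the flags used by internal/runner/encoding.
package main

import (
	"log"
	"os/exec"
)

func main() {
	// Encode an EXR frame sequence to VP9/WebM, keeping alpha and tagging HLG.
	cmd := exec.Command("ffmpeg",
		"-framerate", "24",
		"-i", "frames/frame_%04d.exr",
		"-c:v", "libvpx-vp9",
		"-pix_fmt", "yuva420p", // alpha-capable pixel format
		"-color_trc", "arib-std-b67", // HLG transfer function
		"-b:v", "0", "-crf", "30", // constant-quality mode
		"out.webm",
	)
	if out, err := cmd.CombinedOutput(); err != nil {
		log.Fatalf("ffmpeg failed: %v\n%s", err, out)
	}
}
```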
````markdown
## Storage Structure

The manager uses a local storage directory (default: `./jiggablend-storage`) with the following structure:

```
jiggablend-storage/
├── blender-versions/ # Bundled Blender versions
│   └── <version>/
├── jobs/             # Job context files
│   └── <job-id>/
│       └── context.tar
├── outputs/          # Rendered outputs
│   └── <job-id>/
├── temp/             # Temporary files
└── uploads/          # Uploaded files
```

## Development

### Running Tests

```bash
make test
# Or directly
go test ./... -timeout 30s
```

### Web UI Development

The web UI is built with React and Vite. To develop the UI:

```bash
cd web
npm install
npm run dev   # Development server
npm run build # Build for production
```

The built files are embedded in the Go binary using `embed.FS`.

## License

MIT
````
## cmd/jiggablend/cmd/manager.go (new file, +177)

```go
package cmd

import (
	"fmt"
	"net/http"
	"os/exec"
	"strings"

	"jiggablend/internal/auth"
	"jiggablend/internal/config"
	"jiggablend/internal/database"
	"jiggablend/internal/logger"
	manager "jiggablend/internal/manager"
	"jiggablend/internal/storage"

	"github.com/spf13/cobra"
	"github.com/spf13/viper"
)

var managerCmd = &cobra.Command{
	Use:   "manager",
	Short: "Start the Jiggablend manager server",
	Long:  `Start the Jiggablend manager server to coordinate render jobs.`,
	Run:   runManager,
}

func init() {
	rootCmd.AddCommand(managerCmd)

	// Flags with env binding via viper
	managerCmd.Flags().StringP("port", "p", "8080", "Server port")
	managerCmd.Flags().String("db", "jiggablend.db", "Database path")
	managerCmd.Flags().String("storage", "./jiggablend-storage", "Storage path")
	managerCmd.Flags().StringP("log-file", "l", "", "Log file path (truncated on start, if not set logs only to stdout)")
	managerCmd.Flags().String("log-level", "info", "Log level (debug, info, warn, error)")
	managerCmd.Flags().BoolP("verbose", "v", false, "Enable verbose logging (same as --log-level=debug)")

	// Bind flags to viper with JIGGABLEND_ prefix
	viper.SetEnvPrefix("JIGGABLEND")
	viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
	viper.AutomaticEnv()

	viper.BindPFlag("port", managerCmd.Flags().Lookup("port"))
	viper.BindPFlag("db", managerCmd.Flags().Lookup("db"))
	viper.BindPFlag("storage", managerCmd.Flags().Lookup("storage"))
	viper.BindPFlag("log_file", managerCmd.Flags().Lookup("log-file"))
	viper.BindPFlag("log_level", managerCmd.Flags().Lookup("log-level"))
	viper.BindPFlag("verbose", managerCmd.Flags().Lookup("verbose"))
}

func runManager(cmd *cobra.Command, args []string) {
	// Get config values (flags take precedence over env vars)
	port := viper.GetString("port")
	dbPath := viper.GetString("db")
	storagePath := viper.GetString("storage")
	logFile := viper.GetString("log_file")
	logLevel := viper.GetString("log_level")
	verbose := viper.GetBool("verbose")

	// Initialize logger
	if logFile != "" {
		if err := logger.InitWithFile(logFile); err != nil {
			logger.Fatalf("Failed to initialize logger: %v", err)
		}
		defer func() {
			if l := logger.GetDefault(); l != nil {
				l.Close()
			}
		}()
	} else {
		logger.InitStdout()
	}

	// Set log level
	if verbose {
		logger.SetLevel(logger.LevelDebug)
	} else {
		logger.SetLevel(logger.ParseLevel(logLevel))
	}

	if logFile != "" {
		logger.Infof("Logging to file: %s", logFile)
	}
	logger.Debugf("Log level: %s", logLevel)

	// Initialize database
	db, err := database.NewDB(dbPath)
	if err != nil {
		logger.Fatalf("Failed to initialize database: %v", err)
	}
	defer db.Close()

	// Initialize config from database
	cfg := config.NewConfig(db)
	if err := cfg.InitializeFromEnv(); err != nil {
		logger.Fatalf("Failed to initialize config: %v", err)
	}
	logger.Info("Configuration loaded from database")

	// Initialize auth
	authHandler, err := auth.NewAuth(db, cfg)
	if err != nil {
		logger.Fatalf("Failed to initialize auth: %v", err)
	}

	// Initialize storage
	storageHandler, err := storage.NewStorage(storagePath)
	if err != nil {
		logger.Fatalf("Failed to initialize storage: %v", err)
	}

	// Check if Blender is available
	if err := checkBlenderAvailable(); err != nil {
		logger.Fatalf("Blender is not available: %v\n"+
			"The manager requires Blender to be installed and in PATH for metadata extraction.\n"+
			"Please install Blender and ensure it's accessible via the 'blender' command.", err)
	}
	logger.Info("Blender is available")

	// Check if ImageMagick is available
	if err := checkImageMagickAvailable(); err != nil {
		logger.Fatalf("ImageMagick is not available: %v\n"+
			"The manager requires ImageMagick to be installed and in PATH for EXR preview conversion.\n"+
			"Please install ImageMagick and ensure 'magick' or 'convert' command is accessible.", err)
	}
	logger.Info("ImageMagick is available")

	// Create manager server
	server, err := manager.NewManager(db, cfg, authHandler, storageHandler)
	if err != nil {
		logger.Fatalf("Failed to create server: %v", err)
	}

	// Start server
	addr := fmt.Sprintf(":%s", port)
	logger.Infof("Starting manager server on %s", addr)
	logger.Infof("Database: %s", dbPath)
	logger.Infof("Storage: %s", storagePath)

	httpServer := &http.Server{
		Addr:           addr,
		Handler:        server,
		MaxHeaderBytes: 1 << 20,
		ReadTimeout:    0,
		WriteTimeout:   0,
	}

	if err := httpServer.ListenAndServe(); err != nil {
		logger.Fatalf("Server failed: %v", err)
	}
}

func checkBlenderAvailable() error {
	cmd := exec.Command("blender", "--version")
	output, err := cmd.CombinedOutput()
	if err != nil {
		return fmt.Errorf("failed to run 'blender --version': %w (output: %s)", err, string(output))
	}
	return nil
}

func checkImageMagickAvailable() error {
	// Try 'magick' first (ImageMagick 7+)
	cmd := exec.Command("magick", "--version")
	output, err := cmd.CombinedOutput()
	if err == nil {
		return nil
	}

	// Fall back to 'convert' (ImageMagick 6 or legacy mode)
	cmd = exec.Command("convert", "--version")
	output, err = cmd.CombinedOutput()
	if err != nil {
		return fmt.Errorf("failed to run 'magick --version' or 'convert --version': %w (output: %s)", err, string(output))
	}
	return nil
}
```
## cmd/jiggablend/cmd/managerconfig.go (new file, +621)

```go
package cmd

import (
	"bufio"
	"crypto/rand"
	"crypto/sha256"
	"database/sql"
	"encoding/hex"
	"fmt"
	"os"
	"strings"

	"jiggablend/internal/config"
	"jiggablend/internal/database"

	"github.com/spf13/cobra"
	"golang.org/x/crypto/bcrypt"
)

var (
	configDBPath string
	configYes    bool // Auto-confirm prompts
	configForce  bool // Force override existing
)

var configCmd = &cobra.Command{
	Use:   "config",
	Short: "Configure the manager",
	Long:  `Configure the Jiggablend manager settings stored in the database.`,
}

// --- Enable/Disable commands ---

var enableCmd = &cobra.Command{
	Use:   "enable",
	Short: "Enable a feature",
}

var disableCmd = &cobra.Command{
	Use:   "disable",
	Short: "Disable a feature",
}

var enableLocalAuthCmd = &cobra.Command{
	Use:   "localauth",
	Short: "Enable local authentication",
	Run: func(cmd *cobra.Command, args []string) {
		withConfig(func(cfg *config.Config, db *database.DB) {
			if err := cfg.SetBool(config.KeyEnableLocalAuth, true); err != nil {
				exitWithError("Failed to enable local auth: %v", err)
			}
			fmt.Println("Local authentication enabled")
		})
	},
}

var disableLocalAuthCmd = &cobra.Command{
	Use:   "localauth",
	Short: "Disable local authentication",
	Run: func(cmd *cobra.Command, args []string) {
		withConfig(func(cfg *config.Config, db *database.DB) {
			if err := cfg.SetBool(config.KeyEnableLocalAuth, false); err != nil {
				exitWithError("Failed to disable local auth: %v", err)
			}
			fmt.Println("Local authentication disabled")
		})
	},
}

var enableRegistrationCmd = &cobra.Command{
	Use:   "registration",
	Short: "Enable user registration",
	Run: func(cmd *cobra.Command, args []string) {
		withConfig(func(cfg *config.Config, db *database.DB) {
			if err := cfg.SetBool(config.KeyRegistrationEnabled, true); err != nil {
				exitWithError("Failed to enable registration: %v", err)
			}
			fmt.Println("User registration enabled")
		})
	},
}

var disableRegistrationCmd = &cobra.Command{
	Use:   "registration",
	Short: "Disable user registration",
	Run: func(cmd *cobra.Command, args []string) {
		withConfig(func(cfg *config.Config, db *database.DB) {
			if err := cfg.SetBool(config.KeyRegistrationEnabled, false); err != nil {
				exitWithError("Failed to disable registration: %v", err)
			}
			fmt.Println("User registration disabled")
		})
	},
}

var enableProductionCmd = &cobra.Command{
	Use:   "production",
	Short: "Enable production mode",
	Run: func(cmd *cobra.Command, args []string) {
		withConfig(func(cfg *config.Config, db *database.DB) {
			if err := cfg.SetBool(config.KeyProductionMode, true); err != nil {
				exitWithError("Failed to enable production mode: %v", err)
			}
			fmt.Println("Production mode enabled")
		})
	},
}

var disableProductionCmd = &cobra.Command{
	Use:   "production",
	Short: "Disable production mode",
	Run: func(cmd *cobra.Command, args []string) {
		withConfig(func(cfg *config.Config, db *database.DB) {
			if err := cfg.SetBool(config.KeyProductionMode, false); err != nil {
				exitWithError("Failed to disable production mode: %v", err)
			}
			fmt.Println("Production mode disabled")
		})
	},
}

// --- Add commands ---

var addCmd = &cobra.Command{
	Use:   "add",
	Short: "Add a resource",
}

var (
	addUserName  string
	addUserAdmin bool
)

var addUserCmd = &cobra.Command{
	Use:   "user <email> <password>",
	Short: "Add a local user",
	Long:  `Add a new local user account to the database.`,
	Args:  cobra.ExactArgs(2),
	Run: func(cmd *cobra.Command, args []string) {
		email := args[0]
		password := args[1]

		name := addUserName
		if name == "" {
			// Use email prefix as name
			if atIndex := strings.Index(email, "@"); atIndex > 0 {
				name = email[:atIndex]
			} else {
				name = email
			}
		}
		if len(password) < 8 {
			exitWithError("Password must be at least 8 characters")
		}

		withConfig(func(cfg *config.Config, db *database.DB) {
			// Check if user exists
			var exists bool
			err := db.With(func(conn *sql.DB) error {
				return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ?)", email).Scan(&exists)
			})
			if err != nil {
				exitWithError("Failed to check user: %v", err)
			}
			isAdmin := addUserAdmin
			if exists {
				if !configForce {
					exitWithError("User with email %s already exists (use -f to override)", email)
				}
				// Confirm override
				if !configYes && !confirm(fmt.Sprintf("User %s already exists. Override?", email)) {
					fmt.Println("Aborted")
					return
				}
				// Update existing user
				hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
				if err != nil {
					exitWithError("Failed to hash password: %v", err)
				}
				err = db.With(func(conn *sql.DB) error {
					_, err := conn.Exec(
						"UPDATE users SET name = ?, password_hash = ?, is_admin = ? WHERE email = ?",
						name, string(hashedPassword), isAdmin, email,
					)
					return err
				})
				if err != nil {
					exitWithError("Failed to update user: %v", err)
				}
				fmt.Printf("Updated user: %s (admin: %v)\n", email, isAdmin)
				return
			}

			// Hash password
			hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
			if err != nil {
				exitWithError("Failed to hash password: %v", err)
			}

			// Check if first user (make admin)
			var userCount int
			db.With(func(conn *sql.DB) error {
				return conn.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
			})
			if userCount == 0 {
				isAdmin = true
			}

			// Confirm creation
			if !configYes && !confirm(fmt.Sprintf("Create user %s (admin: %v)?", email, isAdmin)) {
				fmt.Println("Aborted")
				return
			}

			// Create user
			err = db.With(func(conn *sql.DB) error {
				_, err := conn.Exec(
					"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?)",
					email, name, email, string(hashedPassword), isAdmin,
				)
				return err
			})
			if err != nil {
				exitWithError("Failed to create user: %v", err)
			}

			fmt.Printf("Created user: %s (admin: %v)\n", email, isAdmin)
		})
	},
}

var addAPIKeyScope string

var addAPIKeyCmd = &cobra.Command{
	Use:   "apikey [name]",
	Short: "Add a runner API key",
	Long:  `Generate a new API key for runner authentication.`,
	Args:  cobra.MaximumNArgs(1),
	Run: func(cmd *cobra.Command, args []string) {
		name := "cli-generated"
		if len(args) > 0 {
			name = args[0]
		}

		withConfig(func(cfg *config.Config, db *database.DB) {
			// Check if API key with same name exists
			var exists bool
			err := db.With(func(conn *sql.DB) error {
				return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM runner_api_keys WHERE name = ?)", name).Scan(&exists)
			})
			if err != nil {
				exitWithError("Failed to check API key: %v", err)
			}
			if exists {
				if !configForce {
					exitWithError("API key with name %s already exists (use -f to create another)", name)
				}
				if !configYes && !confirm(fmt.Sprintf("API key named '%s' already exists. Create another?", name)) {
					fmt.Println("Aborted")
					return
				}
			}

			// Confirm creation
			if !configYes && !confirm(fmt.Sprintf("Generate new API key '%s' (scope: %s)?", name, addAPIKeyScope)) {
				fmt.Println("Aborted")
				return
			}

			// Generate API key
			key, keyPrefix, keyHash, err := generateAPIKey()
			if err != nil {
				exitWithError("Failed to generate API key: %v", err)
			}

			// Get first user ID for created_by (or use 0 if no users)
			var createdBy int64
			db.With(func(conn *sql.DB) error {
				return conn.QueryRow("SELECT id FROM users ORDER BY id ASC LIMIT 1").Scan(&createdBy)
			})

			// Store in database
			err = db.With(func(conn *sql.DB) error {
				_, err := conn.Exec(
					`INSERT INTO runner_api_keys (key_prefix, key_hash, name, scope, is_active, created_by)
					 VALUES (?, ?, ?, ?, true, ?)`,
					keyPrefix, keyHash, name, addAPIKeyScope, createdBy,
				)
				return err
			})
			if err != nil {
				exitWithError("Failed to store API key: %v", err)
			}

			fmt.Printf("Generated API key: %s\n", key)
			fmt.Printf("Name: %s, Scope: %s\n", name, addAPIKeyScope)
			fmt.Println("\nSave this key - it cannot be retrieved later!")
		})
	},
}

// --- Set commands ---

var setCmd = &cobra.Command{
	Use:   "set",
	Short: "Set a configuration value",
}

var setFixedAPIKeyCmd = &cobra.Command{
	Use:   "fixed-apikey [key]",
	Short: "Set a fixed API key for testing",
	Args:  cobra.ExactArgs(1),
	Run: func(cmd *cobra.Command, args []string) {
		withConfig(func(cfg *config.Config, db *database.DB) {
			// Check if already set
			existing := cfg.FixedAPIKey()
			if existing != "" && !configForce {
				exitWithError("Fixed API key already set (use -f to override)")
			}
			if existing != "" && !configYes && !confirm("Fixed API key already set. Override?") {
				fmt.Println("Aborted")
				return
			}
			if err := cfg.Set(config.KeyFixedAPIKey, args[0]); err != nil {
				exitWithError("Failed to set fixed API key: %v", err)
			}
			fmt.Println("Fixed API key set")
		})
	},
}

var setAllowedOriginsCmd = &cobra.Command{
	Use:   "allowed-origins [origins]",
	Short: "Set allowed CORS origins (comma-separated)",
	Args:  cobra.ExactArgs(1),
	Run: func(cmd *cobra.Command, args []string) {
		withConfig(func(cfg *config.Config, db *database.DB) {
			if err := cfg.Set(config.KeyAllowedOrigins, args[0]); err != nil {
				exitWithError("Failed to set allowed origins: %v", err)
			}
			fmt.Printf("Allowed origins set to: %s\n", args[0])
		})
	},
}

var setGoogleOAuthRedirectURL string

var setGoogleOAuthCmd = &cobra.Command{
	Use:   "google-oauth <client-id> <client-secret>",
	Short: "Set Google OAuth credentials",
	Args:  cobra.ExactArgs(2),
	Run: func(cmd *cobra.Command, args []string) {
		clientID := args[0]
		clientSecret := args[1]

		withConfig(func(cfg *config.Config, db *database.DB) {
			// Check if already configured
			existing := cfg.GoogleClientID()
			if existing != "" && !configForce {
				exitWithError("Google OAuth already configured (use -f to override)")
			}
			if existing != "" && !configYes && !confirm("Google OAuth already configured. Override?") {
				fmt.Println("Aborted")
				return
			}
			if err := cfg.Set(config.KeyGoogleClientID, clientID); err != nil {
				exitWithError("Failed to set Google client ID: %v", err)
			}
			if err := cfg.Set(config.KeyGoogleClientSecret, clientSecret); err != nil {
				exitWithError("Failed to set Google client secret: %v", err)
			}
			if setGoogleOAuthRedirectURL != "" {
				if err := cfg.Set(config.KeyGoogleRedirectURL, setGoogleOAuthRedirectURL); err != nil {
					exitWithError("Failed to set Google redirect URL: %v", err)
				}
			}
			fmt.Println("Google OAuth configured")
		})
	},
}

var setDiscordOAuthRedirectURL string

var setDiscordOAuthCmd = &cobra.Command{
	Use:   "discord-oauth <client-id> <client-secret>",
	Short: "Set Discord OAuth credentials",
	Args:  cobra.ExactArgs(2),
	Run: func(cmd *cobra.Command, args []string) {
		clientID := args[0]
		clientSecret := args[1]

		withConfig(func(cfg *config.Config, db *database.DB) {
			// Check if already configured
			existing := cfg.DiscordClientID()
			if existing != "" && !configForce {
				exitWithError("Discord OAuth already configured (use -f to override)")
			}
			if existing != "" && !configYes && !confirm("Discord OAuth already configured. Override?") {
				fmt.Println("Aborted")
				return
			}
			if err := cfg.Set(config.KeyDiscordClientID, clientID); err != nil {
				exitWithError("Failed to set Discord client ID: %v", err)
			}
			if err := cfg.Set(config.KeyDiscordClientSecret, clientSecret); err != nil {
				exitWithError("Failed to set Discord client secret: %v", err)
			}
			if setDiscordOAuthRedirectURL != "" {
				if err := cfg.Set(config.KeyDiscordRedirectURL, setDiscordOAuthRedirectURL); err != nil {
					exitWithError("Failed to set Discord redirect URL: %v", err)
				}
			}
			fmt.Println("Discord OAuth configured")
		})
	},
}

// --- Show command ---

var showCmd = &cobra.Command{
	Use:   "show",
	Short: "Show current configuration",
	Run: func(cmd *cobra.Command, args []string) {
		withConfig(func(cfg *config.Config, db *database.DB) {
			all, err := cfg.GetAll()
			if err != nil {
				exitWithError("Failed to get config: %v", err)
			}

			if len(all) == 0 {
				fmt.Println("No configuration stored")
				return
			}

			fmt.Println("Current configuration:")
			fmt.Println("----------------------")
			for key, value := range all {
				// Redact sensitive values
				if strings.Contains(key, "secret") || strings.Contains(key, "api_key") || strings.Contains(key, "password") {
					fmt.Printf("  %s: [REDACTED]\n", key)
				} else {
					fmt.Printf("  %s: %s\n", key, value)
				}
			}
		})
	},
}

// --- List commands ---

var listCmd = &cobra.Command{
	Use:   "list",
	Short: "List resources",
}

var listUsersCmd = &cobra.Command{
	Use:   "users",
	Short: "List all users",
	Run: func(cmd *cobra.Command, args []string) {
		withConfig(func(cfg *config.Config, db *database.DB) {
			var rows *sql.Rows
			err := db.With(func(conn *sql.DB) error {
				var err error
				rows, err = conn.Query("SELECT id, email, name, oauth_provider, is_admin, created_at FROM users ORDER BY id")
				return err
			})
			if err != nil {
				exitWithError("Failed to list users: %v", err)
			}
			defer rows.Close()

			fmt.Printf("%-6s %-30s %-20s %-10s %-6s %s\n", "ID", "Email", "Name", "Provider", "Admin", "Created")
			fmt.Println(strings.Repeat("-", 100))

			for rows.Next() {
				var id int64
				var email, name, provider string
				var isAdmin bool
				var createdAt string
				if err := rows.Scan(&id, &email, &name, &provider, &isAdmin, &createdAt); err != nil {
					continue
				}
				adminStr := "no"
				if isAdmin {
					adminStr = "yes"
				}
				fmt.Printf("%-6d %-30s %-20s %-10s %-6s %s\n", id, email, name, provider, adminStr, createdAt[:19])
			}
		})
	},
}

var listAPIKeysCmd = &cobra.Command{
	Use:   "apikeys",
	Short: "List all API keys",
	Run: func(cmd *cobra.Command, args []string) {
		withConfig(func(cfg *config.Config, db *database.DB) {
			var rows *sql.Rows
			err := db.With(func(conn *sql.DB) error {
				var err error
				rows, err = conn.Query("SELECT id, key_prefix, name, scope, is_active, created_at FROM runner_api_keys ORDER BY id")
				return err
			})
			if err != nil {
				exitWithError("Failed to list API keys: %v", err)
			}
			defer rows.Close()

			fmt.Printf("%-6s %-12s %-20s %-10s %-8s %s\n", "ID", "Prefix", "Name", "Scope", "Active", "Created")
			fmt.Println(strings.Repeat("-", 80))

			for rows.Next() {
				var id int64
				var prefix, name, scope string
				var isActive bool
				var createdAt string
				if err := rows.Scan(&id, &prefix, &name, &scope, &isActive, &createdAt); err != nil {
					continue
				}
				activeStr := "no"
				if isActive {
					activeStr = "yes"
				}
				fmt.Printf("%-6d %-12s %-20s %-10s %-8s %s\n", id, prefix, name, scope, activeStr, createdAt[:19])
			}
		})
	},
}

func init() {
	managerCmd.AddCommand(configCmd)

	// Global config flags
	configCmd.PersistentFlags().StringVar(&configDBPath, "db", "jiggablend.db", "Database path")
	configCmd.PersistentFlags().BoolVarP(&configYes, "yes", "y", false, "Auto-confirm prompts")
	configCmd.PersistentFlags().BoolVarP(&configForce, "force", "f", false, "Force override existing")

	// Enable/Disable
	configCmd.AddCommand(enableCmd)
	configCmd.AddCommand(disableCmd)
	enableCmd.AddCommand(enableLocalAuthCmd)
	enableCmd.AddCommand(enableRegistrationCmd)
	enableCmd.AddCommand(enableProductionCmd)
	disableCmd.AddCommand(disableLocalAuthCmd)
	disableCmd.AddCommand(disableRegistrationCmd)
	disableCmd.AddCommand(disableProductionCmd)

	// Add
	configCmd.AddCommand(addCmd)
	addCmd.AddCommand(addUserCmd)
	addUserCmd.Flags().StringVarP(&addUserName, "name", "n", "", "User display name")
	addUserCmd.Flags().BoolVarP(&addUserAdmin, "admin", "a", false, "Make user an admin")

	addCmd.AddCommand(addAPIKeyCmd)
	addAPIKeyCmd.Flags().StringVarP(&addAPIKeyScope, "scope", "s", "manager", "API key scope (manager or user)")

	// Set
	configCmd.AddCommand(setCmd)
	setCmd.AddCommand(setFixedAPIKeyCmd)
	setCmd.AddCommand(setAllowedOriginsCmd)
	setCmd.AddCommand(setGoogleOAuthCmd)
	setCmd.AddCommand(setDiscordOAuthCmd)

	setGoogleOAuthCmd.Flags().StringVarP(&setGoogleOAuthRedirectURL, "redirect-url", "r", "", "Google OAuth redirect URL")
	setDiscordOAuthCmd.Flags().StringVarP(&setDiscordOAuthRedirectURL, "redirect-url", "r", "", "Discord OAuth redirect URL")

	// Show
	configCmd.AddCommand(showCmd)

	// List
	configCmd.AddCommand(listCmd)
	listCmd.AddCommand(listUsersCmd)
	listCmd.AddCommand(listAPIKeysCmd)
}

// withConfig opens the database and runs the callback with config access
func withConfig(fn func(cfg *config.Config, db *database.DB)) {
	db, err := database.NewDB(configDBPath)
	if err != nil {
		exitWithError("Failed to open database: %v", err)
	}
	defer db.Close()

	cfg := config.NewConfig(db)
	fn(cfg, db)
}

// generateAPIKey generates a new API key
func generateAPIKey() (key, prefix, hash string, err error) {
	randomBytes := make([]byte, 16)
	if _, err := rand.Read(randomBytes); err != nil {
		return "", "", "", err
	}
	randomStr := hex.EncodeToString(randomBytes)

	prefixDigit := make([]byte, 1)
	if _, err := rand.Read(prefixDigit); err != nil {
		return "", "", "", err
	}

	prefix = fmt.Sprintf("jk_r%d", prefixDigit[0]%10)
	key = fmt.Sprintf("%s_%s", prefix, randomStr)

	keyHash := sha256.Sum256([]byte(key))
	hash = hex.EncodeToString(keyHash[:])

	return key, prefix, hash, nil
}

// confirm prompts the user for confirmation
func confirm(prompt string) bool {
	fmt.Printf("%s [y/N]: ", prompt)
	reader := bufio.NewReader(os.Stdin)
	response, err := reader.ReadString('\n')
	if err != nil {
		return false
	}
	response = strings.TrimSpace(strings.ToLower(response))
	return response == "y" || response == "yes"
}
```
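`generateAPIKey` above stores only the key's `jk_rN` prefix and a SHA-256 hex digest, never the key itself, which is why the CLI warns that the key cannot be retrieved later. Verification on the manager side presumably re-hashes the presented key and compares digests; a minimal sketch under that assumption (this diff does not show the manager's actual verifier):

```go
// verify_sketch.go — hedged sketch of API-key verification matching
// generateAPIKey; the lookup against runner_api_keys is assumed, not shown.
package main

import (
	"crypto/sha256"
	"crypto/subtle"
	"encoding/hex"
)

// verifyAPIKey reports whether the presented key matches a stored
// SHA-256 hex digest, using a constant-time comparison.
func verifyAPIKey(presented, storedHash string) bool {
	sum := sha256.Sum256([]byte(presented))
	got := hex.EncodeToString(sum[:])
	return subtle.ConstantTimeCompare([]byte(got), []byte(storedHash)) == 1
}
```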
## cmd/jiggablend/cmd/root.go (new file, +34)

```go
package cmd

import (
	"fmt"
	"os"

	"github.com/spf13/cobra"
)

var rootCmd = &cobra.Command{
	Use:   "jiggablend",
	Short: "Jiggablend - Distributed Blender Render Farm",
	Long: `Jiggablend is a distributed render farm for Blender.

Run 'jiggablend manager' to start the manager server.
Run 'jiggablend runner' to start a render runner.
Run 'jiggablend manager config' to configure the manager.`,
}

// Execute runs the root command
func Execute() error {
	return rootCmd.Execute()
}

func init() {
	// Global flags can be added here if needed
	rootCmd.CompletionOptions.DisableDefaultCmd = true
}

// exitWithError prints an error and exits
func exitWithError(msg string, args ...interface{}) {
	fmt.Fprintf(os.Stderr, "Error: "+msg+"\n", args...)
	os.Exit(1)
}
```
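This compare view does not include `cmd/jiggablend/main.go`, but the README's project tree lists it as the entry point, and with Cobra it would normally just call `Execute()`. A hedged sketch, with the import path inferred from the `jiggablend` module name:

```go
// cmd/jiggablend/main.go — hedged sketch of the entry point, which this
// compare view does not include; the import path is an assumption.
package main

import (
	"os"

	"jiggablend/cmd/jiggablend/cmd"
)

func main() {
	// Delegate to the Cobra root command; subcommands (manager, runner,
	// manager config) register themselves in the cmd package's init funcs.
	if err := cmd.Execute(); err != nil {
		os.Exit(1)
	}
}
```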
208
cmd/jiggablend/cmd/runner.go
Normal file
@@ -0,0 +1,208 @@
package cmd

import (
    "crypto/rand"
    "encoding/hex"
    "fmt"
    "os"
    "os/signal"
    "strings"
    "syscall"
    "time"

    "jiggablend/internal/logger"
    "jiggablend/internal/runner"

    "github.com/spf13/cobra"
    "github.com/spf13/viper"
)

var runnerViper = viper.New()

var runnerCmd = &cobra.Command{
    Use:   "runner",
    Short: "Start the Jiggablend render runner",
    Long:  `Start the Jiggablend render runner that connects to a manager and processes render tasks.`,
    Run:   runRunner,
}

func init() {
    rootCmd.AddCommand(runnerCmd)

    runnerCmd.Flags().StringP("manager", "m", "http://localhost:8080", "Manager URL")
    runnerCmd.Flags().StringP("name", "n", "", "Runner name")
    runnerCmd.Flags().String("hostname", "", "Runner hostname")
    runnerCmd.Flags().StringP("api-key", "k", "", "API key for authentication")
    runnerCmd.Flags().StringP("log-file", "l", "", "Log file path (truncated on start; if not set, logs only to stdout)")
    runnerCmd.Flags().String("log-level", "info", "Log level (debug, info, warn, error)")
    runnerCmd.Flags().BoolP("verbose", "v", false, "Enable verbose logging (same as --log-level=debug)")
    runnerCmd.Flags().Duration("poll-interval", 5*time.Second, "Job polling interval")

    // Bind flags to viper with JIGGABLEND_ prefix
    runnerViper.SetEnvPrefix("JIGGABLEND")
    runnerViper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
    runnerViper.AutomaticEnv()

    runnerViper.BindPFlag("manager", runnerCmd.Flags().Lookup("manager"))
    runnerViper.BindPFlag("name", runnerCmd.Flags().Lookup("name"))
    runnerViper.BindPFlag("hostname", runnerCmd.Flags().Lookup("hostname"))
    runnerViper.BindPFlag("api_key", runnerCmd.Flags().Lookup("api-key"))
    runnerViper.BindPFlag("log_file", runnerCmd.Flags().Lookup("log-file"))
    runnerViper.BindPFlag("log_level", runnerCmd.Flags().Lookup("log-level"))
    runnerViper.BindPFlag("verbose", runnerCmd.Flags().Lookup("verbose"))
    runnerViper.BindPFlag("poll_interval", runnerCmd.Flags().Lookup("poll-interval"))
}

func runRunner(cmd *cobra.Command, args []string) {
    // Get config values (flags take precedence over env vars)
    managerURL := runnerViper.GetString("manager")
    name := runnerViper.GetString("name")
    hostname := runnerViper.GetString("hostname")
    apiKey := runnerViper.GetString("api_key")
    logFile := runnerViper.GetString("log_file")
    logLevel := runnerViper.GetString("log_level")
    verbose := runnerViper.GetBool("verbose")
    pollInterval := runnerViper.GetDuration("poll_interval")

    var r *runner.Runner

    defer func() {
        if rec := recover(); rec != nil {
            logger.Errorf("Runner panicked: %v", rec)
            if r != nil {
                r.Cleanup()
            }
            os.Exit(1)
        }
    }()

    if hostname == "" {
        hostname, _ = os.Hostname()
    }

    // Generate unique runner ID suffix
    runnerIDStr := generateShortID()

    // Generate runner name with ID if not provided
    if name == "" {
        name = fmt.Sprintf("runner-%s-%s", hostname, runnerIDStr)
    } else {
        name = fmt.Sprintf("%s-%s", name, runnerIDStr)
    }

    // Initialize logger
    if logFile != "" {
        if err := logger.InitWithFile(logFile); err != nil {
            logger.Fatalf("Failed to initialize logger: %v", err)
        }
        defer func() {
            if l := logger.GetDefault(); l != nil {
                l.Close()
            }
        }()
    } else {
        logger.InitStdout()
    }

    // Set log level
    if verbose {
        logger.SetLevel(logger.LevelDebug)
    } else {
        logger.SetLevel(logger.ParseLevel(logLevel))
    }

    logger.Info("Runner starting up...")
    logger.Debugf("Generated runner ID suffix: %s", runnerIDStr)
    if logFile != "" {
        logger.Infof("Logging to file: %s", logFile)
    }

    // Create runner
    r = runner.New(managerURL, name, hostname)

    // Check for required tools early to fail fast
    if err := r.CheckRequiredTools(); err != nil {
        logger.Fatalf("Required tool check failed: %v", err)
    }

    // Clean up orphaned workspace directories
    r.Cleanup()

    // Probe capabilities and log them
    logger.Debug("Probing runner capabilities...")
    capabilities := r.ProbeCapabilities()
    capList := []string{}
    for cap, value := range capabilities {
        if enabled, ok := value.(bool); ok && enabled {
            capList = append(capList, cap)
        }
    }
    if len(capList) > 0 {
        logger.Infof("Detected capabilities: %s", strings.Join(capList, ", "))
    } else {
        logger.Warn("No capabilities detected")
    }

    // Register with API key
    if apiKey == "" {
        logger.Fatal("API key required (use --api-key or set JIGGABLEND_API_KEY env var)")
    }

    // Retry registration with exponential backoff
    backoff := 1 * time.Second
    maxBackoff := 30 * time.Second
    maxRetries := 10
    retryCount := 0

    var runnerID int64

    for {
        var err error
        runnerID, err = r.Register(apiKey)
        if err == nil {
            logger.Infof("Registered runner with ID: %d", runnerID)
            break
        }

        errMsg := err.Error()
        if strings.Contains(errMsg, "token error:") {
            logger.Fatalf("Registration failed (token error): %v", err)
        }

        retryCount++
        if retryCount >= maxRetries {
            logger.Fatalf("Failed to register runner after %d attempts: %v", maxRetries, err)
        }

        logger.Warnf("Registration failed (attempt %d/%d): %v, retrying in %v", retryCount, maxRetries, err, backoff)
        time.Sleep(backoff)
        backoff *= 2
        if backoff > maxBackoff {
            backoff = maxBackoff
        }
    }

    // Signal handlers
    sigChan := make(chan os.Signal, 1)
    signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)

    go func() {
        sig := <-sigChan
        logger.Infof("Received signal: %v, killing all processes and cleaning up...", sig)
        r.KillAllProcesses()
        r.Cleanup()
        os.Exit(0)
    }()

    // Start polling for jobs
    logger.Infof("Runner started, polling for jobs (interval: %v)...", pollInterval)
    r.Start(pollInterval)
}

// generateShortID generates a short random ID (8 hex characters), falling back
// to a timestamp/PID-derived value if crypto/rand fails.
func generateShortID() string {
    bytes := make([]byte, 4)
    if _, err := rand.Read(bytes); err != nil {
        return fmt.Sprintf("%x", os.Getpid()^int(time.Now().Unix()))
    }
    return hex.EncodeToString(bytes)
}
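The init function above binds every flag to an environment variable with the JIGGABLEND_ prefix and dashes mapped to underscores, so --poll-interval can also be supplied as JIGGABLEND_POLL_INTERVAL, and an explicitly passed flag wins over the environment. A minimal, self-contained sketch of the same binding pattern (not part of the diff; the demo flag set is hypothetical):

package main

import (
    "fmt"
    "strings"

    "github.com/spf13/pflag"
    "github.com/spf13/viper"
)

func main() {
    v := viper.New()
    fs := pflag.NewFlagSet("demo", pflag.ExitOnError)
    fs.String("poll-interval", "5s", "Job polling interval")

    // Same scheme as runner.go: JIGGABLEND_ prefix, dashes mapped to underscores.
    v.SetEnvPrefix("JIGGABLEND")
    v.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
    v.AutomaticEnv()
    _ = v.BindPFlag("poll_interval", fs.Lookup("poll-interval"))

    _ = fs.Parse([]string{}) // no flag passed, so JIGGABLEND_POLL_INTERVAL (if set) beats the "5s" default
    fmt.Println(v.GetString("poll_interval"))
}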
25
cmd/jiggablend/cmd/version.go
Normal file
@@ -0,0 +1,25 @@
package cmd

import (
    "fmt"

    "jiggablend/version"

    "github.com/spf13/cobra"
)

var versionCmd = &cobra.Command{
    Use:   "version",
    Short: "Print the version information",
    Long:  `Print the version and build date of jiggablend.`,
    Run: func(cmd *cobra.Command, args []string) {
        fmt.Printf("jiggablend version %s\n", version.Version)
        if version.Date != "" {
            fmt.Printf("Build date: %s\n", version.Date)
        }
    },
}

func init() {
    rootCmd.AddCommand(versionCmd)
}
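The jiggablend/version package consumed above is not part of this diff; a minimal hypothetical sketch consistent with how it is used, where both variables are meant to be overridden at link time (the default values here are assumptions):

// Package version holds build metadata injected at link time, e.g.:
//   go build -ldflags "-X jiggablend/version.Version=0.0.2 -X jiggablend/version.Date=2025-01-01"
package version

var (
    Version = "dev" // assumed default when built without ldflags
    Date    = ""    // empty by default; versionCmd above then skips printing it
)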
14
cmd/jiggablend/main.go
Normal file
@@ -0,0 +1,14 @@
package main

import (
    "os"

    "jiggablend/cmd/jiggablend/cmd"
    _ "jiggablend/version"
)

func main() {
    if err := cmd.Execute(); err != nil {
        os.Exit(1)
    }
}
@@ -1,65 +0,0 @@
package main

import (
    "flag"
    "fmt"
    "log"
    "net/http"
    "os"

    "jiggablend/internal/api"
    "jiggablend/internal/auth"
    "jiggablend/internal/database"
    "jiggablend/internal/storage"
)

func main() {
    var (
        port        = flag.String("port", getEnv("PORT", "8080"), "Server port")
        dbPath      = flag.String("db", getEnv("DB_PATH", "jiggablend.db"), "Database path")
        storagePath = flag.String("storage", getEnv("STORAGE_PATH", "./jiggablend-storage"), "Storage path")
    )
    flag.Parse()

    // Initialize database
    db, err := database.NewDB(*dbPath)
    if err != nil {
        log.Fatalf("Failed to initialize database: %v", err)
    }
    defer db.Close()

    // Initialize auth
    authHandler, err := auth.NewAuth(db.DB)
    if err != nil {
        log.Fatalf("Failed to initialize auth: %v", err)
    }

    // Initialize storage
    storageHandler, err := storage.NewStorage(*storagePath)
    if err != nil {
        log.Fatalf("Failed to initialize storage: %v", err)
    }

    // Create API server
    server, err := api.NewServer(db, authHandler, storageHandler)
    if err != nil {
        log.Fatalf("Failed to create server: %v", err)
    }

    // Start server
    addr := fmt.Sprintf(":%s", *port)
    log.Printf("Starting manager server on %s", addr)
    log.Printf("Database: %s", *dbPath)
    log.Printf("Storage: %s", *storagePath)

    if err := http.ListenAndServe(addr, server); err != nil {
        log.Fatalf("Server failed: %v", err)
    }
}

func getEnv(key, defaultValue string) string {
    if value := os.Getenv(key); value != "" {
        return value
    }
    return defaultValue
}
@@ -1,221 +0,0 @@
package main

import (
    "crypto/rand"
    "encoding/hex"
    "encoding/json"
    "flag"
    "fmt"
    "log"
    "os"
    "os/signal"
    "strings"
    "syscall"
    "time"

    "jiggablend/internal/runner"
)

type SecretsFile struct {
    RunnerID      int64  `json:"runner_id"`
    RunnerSecret  string `json:"runner_secret"`
    ManagerSecret string `json:"manager_secret"`
}

func main() {
    var (
        managerURL     = flag.String("manager", getEnv("MANAGER_URL", "http://localhost:8080"), "Manager URL")
        name           = flag.String("name", getEnv("RUNNER_NAME", ""), "Runner name")
        hostname       = flag.String("hostname", getEnv("RUNNER_HOSTNAME", ""), "Runner hostname")
        ipAddress      = flag.String("ip", getEnv("RUNNER_IP", ""), "Runner IP address")
        token          = flag.String("token", getEnv("REGISTRATION_TOKEN", ""), "Registration token")
        secretsFile    = flag.String("secrets-file", getEnv("SECRETS_FILE", ""), "Path to secrets file for persistent storage (default: ./runner-secrets.json, or ./runner-secrets-{id}.json if multiple runners)")
        runnerIDSuffix = flag.String("runner-id", getEnv("RUNNER_ID", ""), "Unique runner ID suffix (auto-generated if not provided)")
    )
    flag.Parse()

    if *hostname == "" {
        *hostname, _ = os.Hostname()
    }
    if *ipAddress == "" {
        *ipAddress = "127.0.0.1"
    }

    // Generate or use provided runner ID suffix
    runnerIDStr := *runnerIDSuffix
    if runnerIDStr == "" {
        runnerIDStr = generateShortID()
    }

    // Generate runner name with ID if not provided
    if *name == "" {
        *name = fmt.Sprintf("runner-%s-%s", *hostname, runnerIDStr)
    } else {
        // Append ID to provided name to ensure uniqueness
        *name = fmt.Sprintf("%s-%s", *name, runnerIDStr)
    }

    // Set default secrets file if not provided - always use current directory
    if *secretsFile == "" {
        if *runnerIDSuffix != "" || getEnv("RUNNER_ID", "") != "" {
            // Multiple runners - use local file with ID
            *secretsFile = fmt.Sprintf("./runner-secrets-%s.json", runnerIDStr)
        } else {
            // Single runner - use local file
            *secretsFile = "./runner-secrets.json"
        }
    }

    client := runner.NewClient(*managerURL, *name, *hostname, *ipAddress)

    // Probe capabilities once at startup (before any registration attempts)
    log.Printf("Probing runner capabilities...")
    client.ProbeCapabilities()
    capabilities := client.GetCapabilities()
    capList := []string{}
    for cap, value := range capabilities {
        // Only show boolean true capabilities and numeric GPU counts
        if enabled, ok := value.(bool); ok && enabled {
            capList = append(capList, cap)
        } else if count, ok := value.(int); ok && count > 0 {
            capList = append(capList, fmt.Sprintf("%s=%d", cap, count))
        } else if count, ok := value.(float64); ok && count > 0 {
            capList = append(capList, fmt.Sprintf("%s=%.0f", cap, count))
        }
    }
    if len(capList) > 0 {
        log.Printf("Detected capabilities: %s", strings.Join(capList, ", "))
    } else {
        log.Printf("Warning: No capabilities detected")
    }

    // Try to load secrets from file
    var runnerID int64
    var runnerSecret, managerSecret string
    if *secretsFile != "" {
        if secrets, err := loadSecrets(*secretsFile); err == nil {
            runnerID = secrets.RunnerID
            runnerSecret = secrets.RunnerSecret
            managerSecret = secrets.ManagerSecret
            client.SetSecrets(runnerID, runnerSecret, managerSecret)
            log.Printf("Loaded secrets from %s", *secretsFile)
        }
    }

    // If no secrets loaded, register with token (with retry logic)
    if runnerID == 0 {
        if *token == "" {
            log.Fatalf("Registration token required (use --token or set REGISTRATION_TOKEN env var)")
        }

        // Retry registration with exponential backoff
        backoff := 1 * time.Second
        maxBackoff := 30 * time.Second
        maxRetries := 10
        retryCount := 0

        for {
            var err error
            runnerID, runnerSecret, managerSecret, err = client.Register(*token)
            if err == nil {
                log.Printf("Registered runner with ID: %d", runnerID)

                // Always save secrets to file (secretsFile is now always set to a default if not provided)
                secrets := SecretsFile{
                    RunnerID:      runnerID,
                    RunnerSecret:  runnerSecret,
                    ManagerSecret: managerSecret,
                }
                if err := saveSecrets(*secretsFile, secrets); err != nil {
                    log.Printf("Warning: Failed to save secrets to %s: %v", *secretsFile, err)
                } else {
                    log.Printf("Saved secrets to %s", *secretsFile)
                }
                break
            }

            // Check if it's a token error (invalid/expired/used token) - shut down immediately
            errMsg := err.Error()
            if strings.Contains(errMsg, "token error:") {
                log.Fatalf("Registration failed (token error): %v", err)
            }

            // Only retry on connection errors or other retryable errors
            retryCount++
            if retryCount >= maxRetries {
                log.Fatalf("Failed to register runner after %d attempts: %v", maxRetries, err)
            }

            log.Printf("Registration failed (attempt %d/%d): %v, retrying in %v", retryCount, maxRetries, err, backoff)
            time.Sleep(backoff)
            backoff *= 2
            if backoff > maxBackoff {
                backoff = maxBackoff
            }
        }
    }

    // Start WebSocket connection with reconnection
    go client.ConnectWebSocketWithReconnect()

    // Start heartbeat loop (for WebSocket ping/pong and HTTP fallback)
    go client.HeartbeatLoop()

    // ProcessTasks is now handled via WebSocket, but kept for HTTP fallback;
    // WebSocket will handle task assignment automatically
    log.Printf("Runner started, connecting to manager via WebSocket...")

    // Set up signal handlers to kill processes on shutdown
    sigChan := make(chan os.Signal, 1)
    signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)

    go func() {
        sig := <-sigChan
        log.Printf("Received signal: %v, killing all processes and shutting down...", sig)
        client.KillAllProcesses()
        os.Exit(0)
    }()

    // Block forever
    select {}
}

func loadSecrets(path string) (*SecretsFile, error) {
    data, err := os.ReadFile(path)
    if err != nil {
        return nil, err
    }

    var secrets SecretsFile
    if err := json.Unmarshal(data, &secrets); err != nil {
        return nil, err
    }

    return &secrets, nil
}

func saveSecrets(path string, secrets SecretsFile) error {
    data, err := json.MarshalIndent(secrets, "", "  ")
    if err != nil {
        return err
    }

    return os.WriteFile(path, data, 0600)
}

func getEnv(key, defaultValue string) string {
    if value := os.Getenv(key); value != "" {
        return value
    }
    return defaultValue
}

// generateShortID generates a short random ID (8 hex characters)
func generateShortID() string {
    bytes := make([]byte, 4)
    if _, err := rand.Read(bytes); err != nil {
        // Fallback to timestamp-based ID if crypto/rand fails
        return fmt.Sprintf("%x", os.Getpid()^int(time.Now().Unix()))
    }
    return hex.EncodeToString(bytes)
}
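For reference, saveSecrets above writes SecretsFile with MarshalIndent and 0600 permissions, so the on-disk runner-secrets.json is a small JSON document of this shape (the values shown are hypothetical):

{
  "runner_id": 42,
  "runner_secret": "<runner secret>",
  "manager_secret": "<manager secret>"
}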
BIN
examples/frame_0800.exr
Normal file
Binary file not shown.
BIN
examples/frame_0800.png
Normal file
Binary file not shown.
After: Size 24 MiB
44
go.mod
@@ -5,35 +5,33 @@ go 1.25.4

require (
    github.com/go-chi/chi/v5 v5.2.3
    github.com/go-chi/cors v1.2.2
    github.com/golang-migrate/migrate/v4 v4.19.0
    github.com/google/uuid v1.6.0
    github.com/gorilla/websocket v1.5.3
    github.com/marcboeker/go-duckdb/v2 v2.4.3
    github.com/mattn/go-sqlite3 v1.14.32
    github.com/spf13/cobra v1.10.1
    github.com/spf13/viper v1.21.0
    golang.org/x/crypto v0.45.0
    golang.org/x/oauth2 v0.33.0
)

require (
    cloud.google.com/go/compute/metadata v0.3.0 // indirect
    github.com/apache/arrow-go/v18 v18.4.1 // indirect
    github.com/duckdb/duckdb-go-bindings v0.1.21 // indirect
    github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.21 // indirect
    github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.21 // indirect
    github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.21 // indirect
    github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.21 // indirect
    github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.21 // indirect
    cloud.google.com/go/compute/metadata v0.5.0 // indirect
    github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
    github.com/fsnotify/fsnotify v1.9.0 // indirect
    github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
    github.com/goccy/go-json v0.10.5 // indirect
    github.com/google/flatbuffers v25.2.10+incompatible // indirect
    github.com/klauspost/compress v1.18.0 // indirect
    github.com/klauspost/cpuid/v2 v2.3.0 // indirect
    github.com/marcboeker/go-duckdb/arrowmapping v0.0.21 // indirect
    github.com/marcboeker/go-duckdb/mapping v0.0.21 // indirect
    github.com/pierrec/lz4/v4 v4.1.22 // indirect
    github.com/zeebo/xxh3 v1.0.2 // indirect
    golang.org/x/crypto v0.45.0 // indirect
    golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 // indirect
    golang.org/x/mod v0.27.0 // indirect
    golang.org/x/sync v0.16.0 // indirect
    github.com/hashicorp/errwrap v1.1.0 // indirect
    github.com/hashicorp/go-multierror v1.1.1 // indirect
    github.com/inconshreveable/mousetrap v1.1.0 // indirect
    github.com/pelletier/go-toml/v2 v2.2.4 // indirect
    github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
    github.com/sagikazarmark/locafero v0.11.0 // indirect
    github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
    github.com/spf13/afero v1.15.0 // indirect
    github.com/spf13/cast v1.10.0 // indirect
    github.com/spf13/pflag v1.0.10 // indirect
    github.com/subosito/gotenv v1.6.0 // indirect
    go.yaml.in/yaml/v3 v3.0.4 // indirect
    golang.org/x/sys v0.38.0 // indirect
    golang.org/x/tools v0.36.0 // indirect
    golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect
    golang.org/x/text v0.31.0 // indirect
)
119
go.sum
@@ -1,88 +1,79 @@
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/apache/arrow-go/v18 v18.4.1 h1:q/jVkBWCJOB9reDgaIZIdruLQUb1kbkvOnOFezVH1C4=
github.com/apache/arrow-go/v18 v18.4.1/go.mod h1:tLyFubsAl17bvFdUAy24bsSvA/6ww95Iqi67fTpGu3E=
github.com/apache/thrift v0.22.0 h1:r7mTJdj51TMDe6RtcmNdQxgn9XcyfGDOzegMDRg47uc=
github.com/apache/thrift v0.22.0/go.mod h1:1e7J/O1Ae6ZQMTYdy9xa3w9k+XHWPfRvdPyJeynQ+/g=
cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY=
cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/duckdb/duckdb-go-bindings v0.1.21 h1:bOb/MXNT4PN5JBZ7wpNg6hrj9+cuDjWDa4ee9UdbVyI=
github.com/duckdb/duckdb-go-bindings v0.1.21/go.mod h1:pBnfviMzANT/9hi4bg+zW4ykRZZPCXlVuvBWEcZofkc=
github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.21 h1:Sjjhf2F/zCjPF53c2VXOSKk0PzieMriSoyr5wfvr9d8=
github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.21/go.mod h1:Ezo7IbAfB8NP7CqPIN8XEHKUg5xdRRQhcPPlCXImXYA=
github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.21 h1:IUk0FFUB6dpWLhlN9hY1mmdPX7Hkn3QpyrAmn8pmS8g=
github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.21/go.mod h1:eS7m/mLnPQgVF4za1+xTyorKRBuK0/BA44Oy6DgrGXI=
github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.21 h1:Qpc7ZE3n6Nwz30KTvaAwI6nGkXjXmMxBTdFpC8zDEYI=
github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.21/go.mod h1:1GOuk1PixiESxLaCGFhag+oFi7aP+9W8byymRAvunBk=
github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.21 h1:eX2DhobAZOgjXkh8lPnKAyrxj8gXd2nm+K71f6KV/mo=
github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.21/go.mod h1:o7crKMpT2eOIi5/FY6HPqaXcvieeLSqdXXaXbruGX7w=
github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.21 h1:hhziFnGV7mpA+v5J5G2JnYQ+UWCCP3NQ+OTvxFX10D8=
github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.21/go.mod h1:IlOhJdVKUJCAPj3QsDszUo8DVdvp1nBFp4TUJVdw99s=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE=
github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
github.com/go-chi/cors v1.2.2 h1:Jmey33TE+b+rB7fT8MUy1u0I4L+NARQlK6LhzKPSyQE=
github.com/go-chi/cors v1.2.2/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58=
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs=
github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/flatbuffers v25.2.10+incompatible h1:F3vclr7C3HpB1k9mxCGRMXq6FdUalZ6H/pNX4FP1v0Q=
github.com/google/flatbuffers v25.2.10+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
github.com/golang-migrate/migrate/v4 v4.19.0 h1:RcjOnCGz3Or6HQYEJ/EEVLfWnmw9KnoigPSjzhCuaSE=
github.com/golang-migrate/migrate/v4 v4.19.0/go.mod h1:9dyEcu+hO+G9hPSw8AIg50yg622pXJsoHItQnDGZkI0=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/klauspost/asmfmt v1.3.2 h1:4Ri7ox3EwapiOjCki+hw14RyKk201CN4rzyCJRFLpK4=
github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j0HLHbNSE=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/marcboeker/go-duckdb/arrowmapping v0.0.21 h1:geHnVjlsAJGczSWEqYigy/7ARuD+eBtjd0kLN80SPJQ=
github.com/marcboeker/go-duckdb/arrowmapping v0.0.21/go.mod h1:flFTc9MSqQCh2Xm62RYvG3Kyj29h7OtsTb6zUx1CdK8=
github.com/marcboeker/go-duckdb/mapping v0.0.21 h1:6woNXZn8EfYdc9Vbv0qR6acnt0TM1s1eFqnrJZVrqEs=
github.com/marcboeker/go-duckdb/mapping v0.0.21/go.mod h1:q3smhpLyv2yfgkQd7gGHMd+H/Z905y+WYIUjrl29vT4=
github.com/marcboeker/go-duckdb/v2 v2.4.3 h1:bHUkphPsAp2Bh/VFEdiprGpUekxBNZiWWtK+Bv/ljRk=
github.com/marcboeker/go-duckdb/v2 v2.4.3/go.mod h1:taim9Hktg2igHdNBmg5vgTfHAlV26z3gBI0QXQOcuyI=
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs=
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcsrzbxHt8iiaC+zU4b1ylILSosueou12R++wfY=
github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8DFdX7uMikMLXX4oubIzJF4kv/wI=
github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3/go.mod h1:RagcQ7I8IeTMnF8JTXieKnO4Z6JCsikNEzj0DwauVzE=
github.com/pierrec/lz4/v4 v4.1.22 h1:cKFw6uJDK+/gfw5BcDL0JL5aBsAFdsIT18eRtLj7VIU=
github.com/pierrec/lz4/v4 v4.1.22/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.11.0 h1:ib4sjIrwZKxE5u/Japgo/7SJV3PvgjGiRNAvTVGqQl8=
github.com/stretchr/testify v1.11.0/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ=
github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0=
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s=
github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0=
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM=
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8=
golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ=
golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc=
golang.org/x/oauth2 v0.33.0 h1:4Q+qn+E5z8gPRJfmRy7C2gGG3T4jIprK6aSYgTXGRpo=
golang.org/x/oauth2 v0.33.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg=
golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s=
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da h1:noIWHXmPHxILtqtCOPIhSt0ABwskkZKjD3bXGnZGpNY=
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
1955
internal/api/jobs.go
File diff suppressed because it is too large
@@ -1,158 +0,0 @@
package api

import (
    "database/sql"
    "encoding/json"
    "fmt"
    "log"
    "net/http"

    "jiggablend/pkg/types"
)

// handleSubmitMetadata handles metadata submission from runner
func (s *Server) handleSubmitMetadata(w http.ResponseWriter, r *http.Request) {
    jobID, err := parseID(r, "jobId")
    if err != nil {
        s.respondError(w, http.StatusBadRequest, err.Error())
        return
    }

    // Get runner ID from context (set by runnerAuthMiddleware)
    runnerID, ok := r.Context().Value(runnerIDContextKey).(int64)
    if !ok {
        s.respondError(w, http.StatusUnauthorized, "runner_id not found in context")
        return
    }

    var metadata types.BlendMetadata
    if err := json.NewDecoder(r.Body).Decode(&metadata); err != nil {
        s.respondError(w, http.StatusBadRequest, "Invalid metadata JSON")
        return
    }

    // Verify job exists
    var jobUserID int64
    err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID)
    if err == sql.ErrNoRows {
        s.respondError(w, http.StatusNotFound, "Job not found")
        return
    }
    if err != nil {
        s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify job: %v", err))
        return
    }

    // Find the metadata extraction task for this job.
    // First try to find a task assigned to this runner, then fall back to any metadata task for this job.
    var taskID int64
    err = s.db.QueryRow(
        `SELECT id FROM tasks WHERE job_id = ? AND task_type = ? AND runner_id = ?`,
        jobID, types.TaskTypeMetadata, runnerID,
    ).Scan(&taskID)
    if err == sql.ErrNoRows {
        // Fall back to any metadata task for this job (in case assignment changed)
        err = s.db.QueryRow(
            `SELECT id FROM tasks WHERE job_id = ? AND task_type = ? ORDER BY created_at DESC LIMIT 1`,
            jobID, types.TaskTypeMetadata,
        ).Scan(&taskID)
        if err == sql.ErrNoRows {
            s.respondError(w, http.StatusNotFound, "Metadata extraction task not found")
            return
        }
        if err != nil {
            s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to find task: %v", err))
            return
        }
        // Update the task to be assigned to this runner if it wasn't already
        s.db.Exec(
            `UPDATE tasks SET runner_id = ? WHERE id = ? AND runner_id IS NULL`,
            runnerID, taskID,
        )
    } else if err != nil {
        s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to find task: %v", err))
        return
    }

    // Convert metadata to JSON
    metadataJSON, err := json.Marshal(metadata)
    if err != nil {
        s.respondError(w, http.StatusInternalServerError, "Failed to marshal metadata")
        return
    }

    // Update job with metadata
    _, err = s.db.Exec(
        `UPDATE jobs SET blend_metadata = ? WHERE id = ?`,
        string(metadataJSON), jobID,
    )
    if err != nil {
        s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to update job metadata: %v", err))
        return
    }

    // Mark task as completed
    _, err = s.db.Exec(
        `UPDATE tasks SET status = ?, completed_at = CURRENT_TIMESTAMP WHERE id = ?`,
        types.TaskStatusCompleted, taskID,
    )
    if err != nil {
        log.Printf("Failed to mark metadata task as completed: %v", err)
    } else {
        // Update job status and progress after metadata task completes
        s.updateJobStatusFromTasks(jobID)
    }

    log.Printf("Metadata extracted for job %d: frame_start=%d, frame_end=%d", jobID, metadata.FrameStart, metadata.FrameEnd)

    s.respondJSON(w, http.StatusOK, map[string]string{"message": "Metadata submitted successfully"})
}

// handleGetJobMetadata retrieves metadata for a job
func (s *Server) handleGetJobMetadata(w http.ResponseWriter, r *http.Request) {
    userID, err := getUserID(r)
    if err != nil {
        s.respondError(w, http.StatusUnauthorized, err.Error())
        return
    }

    jobID, err := parseID(r, "id")
    if err != nil {
        s.respondError(w, http.StatusBadRequest, err.Error())
        return
    }

    // Verify job belongs to user
    var jobUserID int64
    var blendMetadataJSON sql.NullString
    err = s.db.QueryRow(
        `SELECT user_id, blend_metadata FROM jobs WHERE id = ?`,
        jobID,
    ).Scan(&jobUserID, &blendMetadataJSON)
    if err == sql.ErrNoRows {
        s.respondError(w, http.StatusNotFound, "Job not found")
        return
    }
    if err != nil {
        s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query job: %v", err))
        return
    }
    if jobUserID != userID {
        s.respondError(w, http.StatusForbidden, "Access denied")
        return
    }

    if !blendMetadataJSON.Valid || blendMetadataJSON.String == "" {
        s.respondJSON(w, http.StatusOK, nil)
        return
    }

    var metadata types.BlendMetadata
    if err := json.Unmarshal([]byte(blendMetadataJSON.String), &metadata); err != nil {
        s.respondError(w, http.StatusInternalServerError, "Failed to parse metadata")
        return
    }

    s.respondJSON(w, http.StatusOK, metadata)
}
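types.BlendMetadata itself is not shown in this diff; judging only from the frame_start and frame_end values logged in handleSubmitMetadata above, a minimal hypothetical shape would be (the real struct almost certainly carries more fields):

// Hypothetical minimal sketch of types.BlendMetadata, inferred from the
// fields referenced above; not the actual definition.
package types

type BlendMetadata struct {
    FrameStart int `json:"frame_start"`
    FrameEnd   int `json:"frame_end"`
}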
File diff suppressed because it is too large
@@ -1,658 +0,0 @@
package api

import (
    "database/sql"
    "encoding/json"
    "fmt"
    "log"
    "net/http"
    "strconv"
    "sync"
    "time"

    authpkg "jiggablend/internal/auth"
    "jiggablend/internal/database"
    "jiggablend/internal/storage"
    "jiggablend/pkg/types"

    "github.com/go-chi/chi/v5"
    "github.com/go-chi/chi/v5/middleware"
    "github.com/go-chi/cors"
    "github.com/gorilla/websocket"
)

// Server represents the API server
type Server struct {
    db      *database.DB
    auth    *authpkg.Auth
    secrets *authpkg.Secrets
    storage *storage.Storage
    router  *chi.Mux

    // WebSocket connections
    wsUpgrader      websocket.Upgrader
    runnerConns     map[int64]*websocket.Conn
    runnerConnsMu   sync.RWMutex
    frontendConns   map[string]*websocket.Conn // key: "jobId:taskId"
    frontendConnsMu sync.RWMutex
    // Mutexes for each frontend connection to serialize writes
    frontendConnsWriteMu   map[string]*sync.Mutex // key: "jobId:taskId"
    frontendConnsWriteMuMu sync.RWMutex
    // Throttling for progress updates (per job)
    progressUpdateTimes   map[int64]time.Time // key: jobID
    progressUpdateTimesMu sync.RWMutex
}

// NewServer creates a new API server
func NewServer(db *database.DB, auth *authpkg.Auth, storage *storage.Storage) (*Server, error) {
    secrets, err := authpkg.NewSecrets(db.DB)
    if err != nil {
        return nil, fmt.Errorf("failed to initialize secrets: %w", err)
    }

    s := &Server{
        db:      db,
        auth:    auth,
        secrets: secrets,
        storage: storage,
        router:  chi.NewRouter(),
        wsUpgrader: websocket.Upgrader{
            CheckOrigin: func(r *http.Request) bool {
                return true // Allow all origins for now
            },
            ReadBufferSize:  1024,
            WriteBufferSize: 1024,
        },
        runnerConns:          make(map[int64]*websocket.Conn),
        frontendConns:        make(map[string]*websocket.Conn),
        frontendConnsWriteMu: make(map[string]*sync.Mutex),
        progressUpdateTimes:  make(map[int64]time.Time),
    }

    s.setupMiddleware()
    s.setupRoutes()
    s.StartBackgroundTasks()

    return s, nil
}

// setupMiddleware configures middleware
func (s *Server) setupMiddleware() {
    s.router.Use(middleware.Logger)
    s.router.Use(middleware.Recoverer)
    // Note: Timeout middleware is NOT applied globally to avoid conflicts with WebSocket connections.
    // WebSocket connections are long-lived and should not have HTTP timeouts.

    s.router.Use(cors.Handler(cors.Options{
        AllowedOrigins:   []string{"*"},
        AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
        AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type", "Range"},
        ExposedHeaders:   []string{"Link", "Content-Range", "Accept-Ranges", "Content-Length"},
        AllowCredentials: true,
        MaxAge:           300,
    }))
}
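The per-connection write mutexes in the Server struct above exist because gorilla/websocket permits at most one concurrent writer per connection. The helper that uses them is not shown in this diff; a hypothetical sketch of how such serialized writes might look:

// Hypothetical helper (not in the diff): serializes writes to one frontend
// WebSocket connection using the per-key mutex from frontendConnsWriteMu.
func (s *Server) writeToFrontend(key string, msgType int, data []byte) error {
    s.frontendConnsWriteMuMu.RLock()
    mu := s.frontendConnsWriteMu[key]
    s.frontendConnsWriteMuMu.RUnlock()
    if mu == nil {
        return fmt.Errorf("no write mutex for %s", key)
    }

    s.frontendConnsMu.RLock()
    conn := s.frontendConns[key]
    s.frontendConnsMu.RUnlock()
    if conn == nil {
        return fmt.Errorf("no frontend connection for %s", key)
    }

    mu.Lock()
    defer mu.Unlock()
    return conn.WriteMessage(msgType, data)
}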
// setupRoutes configures routes
func (s *Server) setupRoutes() {
    // Public routes
    s.router.Route("/api/auth", func(r chi.Router) {
        r.Get("/providers", s.handleGetAuthProviders)
        r.Get("/google/login", s.handleGoogleLogin)
        r.Get("/google/callback", s.handleGoogleCallback)
        r.Get("/discord/login", s.handleDiscordLogin)
        r.Get("/discord/callback", s.handleDiscordCallback)
        r.Get("/local/available", s.handleLocalLoginAvailable)
        r.Post("/local/register", s.handleLocalRegister)
        r.Post("/local/login", s.handleLocalLogin)
        r.Post("/logout", s.handleLogout)
        r.Get("/me", s.handleGetMe)
        r.Post("/change-password", s.handleChangePassword)
    })

    // Protected routes
    s.router.Route("/api/jobs", func(r chi.Router) {
        r.Use(func(next http.Handler) http.Handler {
            return http.HandlerFunc(s.auth.Middleware(next.ServeHTTP))
        })
        r.Post("/", s.handleCreateJob)
        r.Get("/", s.handleListJobs)
        r.Get("/{id}", s.handleGetJob)
        r.Delete("/{id}", s.handleCancelJob)
        r.Post("/{id}/delete", s.handleDeleteJob)
        r.Post("/{id}/upload", s.handleUploadJobFile)
        r.Get("/{id}/files", s.handleListJobFiles)
        r.Get("/{id}/files/{fileId}/download", s.handleDownloadJobFile)
        r.Get("/{id}/video", s.handleStreamVideo)
        r.Get("/{id}/metadata", s.handleGetJobMetadata)
        r.Get("/{id}/tasks", s.handleListJobTasks)
        r.Get("/{id}/tasks/{taskId}/logs", s.handleGetTaskLogs)
        // WebSocket route - no timeout middleware (long-lived connection)
        r.With(func(next http.Handler) http.Handler {
            return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
                // Remove timeout middleware for WebSocket
                next.ServeHTTP(w, r)
            })
        }).Get("/{id}/tasks/{taskId}/logs/ws", s.handleStreamTaskLogsWebSocket)
        r.Get("/{id}/tasks/{taskId}/steps", s.handleGetTaskSteps)
        r.Post("/{id}/tasks/{taskId}/retry", s.handleRetryTask)
    })

    // Admin routes
    s.router.Route("/api/admin", func(r chi.Router) {
        r.Use(func(next http.Handler) http.Handler {
            return http.HandlerFunc(s.auth.AdminMiddleware(next.ServeHTTP))
        })
        r.Route("/runners", func(r chi.Router) {
            r.Route("/tokens", func(r chi.Router) {
                r.Post("/", s.handleGenerateRegistrationToken)
                r.Get("/", s.handleListRegistrationTokens)
                r.Delete("/{id}", s.handleRevokeRegistrationToken)
            })
            r.Get("/", s.handleListRunnersAdmin)
            r.Post("/{id}/verify", s.handleVerifyRunner)
            r.Delete("/{id}", s.handleDeleteRunner)
        })
        r.Route("/users", func(r chi.Router) {
            r.Get("/", s.handleListUsers)
            r.Get("/{id}/jobs", s.handleGetUserJobs)
            r.Post("/{id}/admin", s.handleSetUserAdminStatus)
        })
        r.Route("/settings", func(r chi.Router) {
            r.Get("/registration", s.handleGetRegistrationEnabled)
            r.Post("/registration", s.handleSetRegistrationEnabled)
        })
    })

    // Runner API
    s.router.Route("/api/runner", func(r chi.Router) {
        // Registration doesn't require auth (uses token)
        r.With(middleware.Timeout(60*time.Second)).Post("/register", s.handleRegisterRunner)

        // WebSocket endpoint (auth handled in handler) - no timeout middleware
        r.Get("/ws", s.handleRunnerWebSocket)

        // File operations still use HTTP (WebSocket not suitable for large files)
        r.Group(func(r chi.Router) {
            r.Use(func(next http.Handler) http.Handler {
                return http.HandlerFunc(s.runnerAuthMiddleware(next.ServeHTTP))
            })
            r.Post("/tasks/{id}/progress", s.handleUpdateTaskProgress)
            r.Post("/tasks/{id}/steps", s.handleUpdateTaskStep)
            r.Get("/files/{jobId}/*", s.handleDownloadFileForRunner)
            r.Post("/files/{jobId}/upload", s.handleUploadFileFromRunner)
            r.Get("/jobs/{jobId}/status", s.handleGetJobStatusForRunner)
            r.Get("/jobs/{jobId}/files", s.handleGetJobFilesForRunner)
            r.Post("/jobs/{jobId}/metadata", s.handleSubmitMetadata)
        })
    })

    // Serve static files (built React app)
    s.router.Handle("/*", http.FileServer(http.Dir("./web/dist")))
}

// ServeHTTP implements http.Handler
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    s.router.ServeHTTP(w, r)
}
// JSON response helpers
func (s *Server) respondJSON(w http.ResponseWriter, status int, data interface{}) {
    w.Header().Set("Content-Type", "application/json")
    w.WriteHeader(status)
    if err := json.NewEncoder(w).Encode(data); err != nil {
        log.Printf("Failed to encode JSON response: %v", err)
    }
}

func (s *Server) respondError(w http.ResponseWriter, status int, message string) {
    s.respondJSON(w, status, map[string]string{"error": message})
}

// Auth handlers
func (s *Server) handleGoogleLogin(w http.ResponseWriter, r *http.Request) {
    url, err := s.auth.GoogleLoginURL()
    if err != nil {
        s.respondError(w, http.StatusInternalServerError, err.Error())
        return
    }
    http.Redirect(w, r, url, http.StatusFound)
}

func (s *Server) handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
    code := r.URL.Query().Get("code")
    if code == "" {
        s.respondError(w, http.StatusBadRequest, "Missing code parameter")
        return
    }

    session, err := s.auth.GoogleCallback(r.Context(), code)
    if err != nil {
        // If registration is disabled, redirect back to login with error
        if err.Error() == "registration is disabled" {
            http.Redirect(w, r, "/?error=registration_disabled", http.StatusFound)
            return
        }
        s.respondError(w, http.StatusInternalServerError, err.Error())
        return
    }

    sessionID := s.auth.CreateSession(session)
    http.SetCookie(w, &http.Cookie{
        Name:     "session_id",
        Value:    sessionID,
        Path:     "/",
        MaxAge:   86400,
        HttpOnly: true,
        SameSite: http.SameSiteLaxMode,
    })

    http.Redirect(w, r, "/", http.StatusFound)
}

func (s *Server) handleDiscordLogin(w http.ResponseWriter, r *http.Request) {
    url, err := s.auth.DiscordLoginURL()
    if err != nil {
        s.respondError(w, http.StatusInternalServerError, err.Error())
        return
    }
    http.Redirect(w, r, url, http.StatusFound)
}

func (s *Server) handleDiscordCallback(w http.ResponseWriter, r *http.Request) {
    code := r.URL.Query().Get("code")
    if code == "" {
        s.respondError(w, http.StatusBadRequest, "Missing code parameter")
        return
    }

    session, err := s.auth.DiscordCallback(r.Context(), code)
    if err != nil {
        // If registration is disabled, redirect back to login with error
        if err.Error() == "registration is disabled" {
            http.Redirect(w, r, "/?error=registration_disabled", http.StatusFound)
            return
        }
        s.respondError(w, http.StatusInternalServerError, err.Error())
        return
    }

    sessionID := s.auth.CreateSession(session)
    http.SetCookie(w, &http.Cookie{
        Name:     "session_id",
        Value:    sessionID,
        Path:     "/",
        MaxAge:   86400,
        HttpOnly: true,
        SameSite: http.SameSiteLaxMode,
    })

    http.Redirect(w, r, "/", http.StatusFound)
}

func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
    cookie, err := r.Cookie("session_id")
    if err == nil {
        s.auth.DeleteSession(cookie.Value)
    }
    http.SetCookie(w, &http.Cookie{
        Name:     "session_id",
        Value:    "",
        Path:     "/",
        MaxAge:   -1,
        HttpOnly: true,
    })
    s.respondJSON(w, http.StatusOK, map[string]string{"message": "Logged out"})
}

func (s *Server) handleGetMe(w http.ResponseWriter, r *http.Request) {
    cookie, err := r.Cookie("session_id")
    if err != nil {
        s.respondError(w, http.StatusUnauthorized, "Not authenticated")
        return
    }

    session, ok := s.auth.GetSession(cookie.Value)
    if !ok {
        s.respondError(w, http.StatusUnauthorized, "Invalid session")
        return
    }

    s.respondJSON(w, http.StatusOK, map[string]interface{}{
        "id":       session.UserID,
        "email":    session.Email,
        "name":     session.Name,
        "is_admin": session.IsAdmin,
    })
}

func (s *Server) handleGetAuthProviders(w http.ResponseWriter, r *http.Request) {
    s.respondJSON(w, http.StatusOK, map[string]bool{
        "google":  s.auth.IsGoogleOAuthConfigured(),
        "discord": s.auth.IsDiscordOAuthConfigured(),
        "local":   s.auth.IsLocalLoginEnabled(),
    })
}

func (s *Server) handleLocalLoginAvailable(w http.ResponseWriter, r *http.Request) {
    s.respondJSON(w, http.StatusOK, map[string]bool{
        "available": s.auth.IsLocalLoginEnabled(),
    })
}
func (s *Server) handleLocalRegister(w http.ResponseWriter, r *http.Request) {
    var req struct {
        Email    string `json:"email"`
        Name     string `json:"name"`
        Password string `json:"password"`
    }

    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
        s.respondError(w, http.StatusBadRequest, "Invalid request body")
        return
    }

    if req.Email == "" || req.Name == "" || req.Password == "" {
        s.respondError(w, http.StatusBadRequest, "Email, name, and password are required")
        return
    }

    if len(req.Password) < 8 {
        s.respondError(w, http.StatusBadRequest, "Password must be at least 8 characters long")
        return
    }

    session, err := s.auth.RegisterLocalUser(req.Email, req.Name, req.Password)
    if err != nil {
        s.respondError(w, http.StatusBadRequest, err.Error())
        return
    }

    sessionID := s.auth.CreateSession(session)
    http.SetCookie(w, &http.Cookie{
        Name:     "session_id",
        Value:    sessionID,
        Path:     "/",
        MaxAge:   86400,
        HttpOnly: true,
        SameSite: http.SameSiteLaxMode,
    })

    s.respondJSON(w, http.StatusCreated, map[string]interface{}{
        "message": "Registration successful",
        "user": map[string]interface{}{
            "id":       session.UserID,
            "email":    session.Email,
            "name":     session.Name,
            "is_admin": session.IsAdmin,
        },
    })
}

func (s *Server) handleLocalLogin(w http.ResponseWriter, r *http.Request) {
    var req struct {
        Username string `json:"username"`
        Password string `json:"password"`
    }

    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
        s.respondError(w, http.StatusBadRequest, "Invalid request body")
        return
    }

    if req.Username == "" || req.Password == "" {
        s.respondError(w, http.StatusBadRequest, "Username and password are required")
        return
    }

    session, err := s.auth.LocalLogin(req.Username, req.Password)
    if err != nil {
        s.respondError(w, http.StatusUnauthorized, "Invalid credentials")
        return
    }

    sessionID := s.auth.CreateSession(session)
    http.SetCookie(w, &http.Cookie{
        Name:     "session_id",
        Value:    sessionID,
        Path:     "/",
        MaxAge:   86400,
        HttpOnly: true,
        SameSite: http.SameSiteLaxMode,
    })

    s.respondJSON(w, http.StatusOK, map[string]interface{}{
        "message": "Login successful",
        "user": map[string]interface{}{
            "id":       session.UserID,
            "email":    session.Email,
            "name":     session.Name,
            "is_admin": session.IsAdmin,
        },
    })
}
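For orientation, a hedged sketch of exercising the local-login endpoint above from Go; the credentials and port are hypothetical, and the cookie jar retains the session_id cookie the handler sets:

package main

import (
    "bytes"
    "encoding/json"
    "log"
    "net/http"
    "net/http/cookiejar"
)

func main() {
    // The jar keeps the session_id cookie set by handleLocalLogin.
    jar, _ := cookiejar.New(nil)
    client := &http.Client{Jar: jar}

    body, _ := json.Marshal(map[string]string{
        "username": "alice",                 // hypothetical credentials
        "password": "correct-horse-battery", // at least 8 characters, per the register handler
    })
    resp, err := client.Post("http://localhost:8080/api/auth/local/login",
        "application/json", bytes.NewReader(body))
    if err != nil {
        log.Fatal(err)
    }
    defer resp.Body.Close()
    log.Println("login status:", resp.Status)
}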
func (s *Server) handleChangePassword(w http.ResponseWriter, r *http.Request) {
|
||||
userID, err := getUserID(r)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
OldPassword string `json:"old_password"`
|
||||
NewPassword string `json:"new_password"`
|
||||
TargetUserID *int64 `json:"target_user_id,omitempty"` // For admin to change other users' passwords
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.NewPassword == "" {
|
||||
s.respondError(w, http.StatusBadRequest, "New password is required")
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.NewPassword) < 8 {
|
||||
s.respondError(w, http.StatusBadRequest, "Password must be at least 8 characters long")
|
||||
return
|
||||
}
|
||||
|
||||
isAdmin := authpkg.IsAdmin(r.Context())
|
||||
|
||||
// If target_user_id is provided and user is admin, allow changing other user's password
|
||||
if req.TargetUserID != nil && isAdmin {
|
||||
if err := s.auth.AdminChangePassword(*req.TargetUserID, req.NewPassword); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Password changed successfully"})
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, user is changing their own password (requires old password)
|
||||
if req.OldPassword == "" {
|
||||
s.respondError(w, http.StatusBadRequest, "Old password is required")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.auth.ChangePassword(userID, req.OldPassword, req.NewPassword); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Password changed successfully"})
|
||||
}
|
||||
|
||||
// Helper to get user ID from context
func getUserID(r *http.Request) (int64, error) {
	userID, ok := authpkg.GetUserID(r.Context())
	if !ok {
		return 0, fmt.Errorf("user ID not found in context")
	}
	return userID, nil
}

// Helper to parse ID from URL
func parseID(r *http.Request, param string) (int64, error) {
	idStr := chi.URLParam(r, param)
	id, err := strconv.ParseInt(idStr, 10, 64)
	if err != nil {
		return 0, fmt.Errorf("invalid ID: %s", idStr)
	}
	return id, nil
}

// StartBackgroundTasks starts background goroutines for error recovery
func (s *Server) StartBackgroundTasks() {
	go s.recoverStuckTasks()
	go s.cleanupOldMetadataJobs()
}

// recoverStuckTasks periodically checks for dead runners and stuck tasks
func (s *Server) recoverStuckTasks() {
	ticker := time.NewTicker(10 * time.Second)
	defer ticker.Stop()

	// Also distribute tasks every 10 seconds (reduced frequency since we have event-driven distribution)
	distributeTicker := time.NewTicker(10 * time.Second)
	defer distributeTicker.Stop()

	go func() {
		for range distributeTicker.C {
			s.distributeTasksToRunners()
		}
	}()

	for range ticker.C {
		func() {
			defer func() {
				if r := recover(); r != nil {
					log.Printf("Panic in recoverStuckTasks: %v", r)
				}
			}()

			// Find dead runners (no heartbeat for 90 seconds),
			// but only mark them dead if they're not actually connected via WebSocket
			rows, err := s.db.Query(
				`SELECT id FROM runners
				WHERE last_heartbeat < CURRENT_TIMESTAMP - INTERVAL '90 seconds'
				AND status = ?`,
				types.RunnerStatusOnline,
			)
			if err != nil {
				log.Printf("Failed to query dead runners: %v", err)
				return
			}
			defer rows.Close()

			var deadRunnerIDs []int64
			s.runnerConnsMu.RLock()
			for rows.Next() {
				var runnerID int64
				if err := rows.Scan(&runnerID); err == nil {
					// Only mark as dead if not actually connected via WebSocket.
					// The WebSocket connection is the source of truth.
					if _, stillConnected := s.runnerConns[runnerID]; !stillConnected {
						deadRunnerIDs = append(deadRunnerIDs, runnerID)
					}
					// If still connected, the heartbeat should be updated by the pong handler
					// or a heartbeat message. No need to update it manually here - if it's
					// stale, the pong handler isn't working.
				}
			}
			s.runnerConnsMu.RUnlock()
			rows.Close()

			if len(deadRunnerIDs) == 0 {
				// Check for task timeouts
				s.recoverTaskTimeouts()
				return
			}

			// Reset tasks assigned to dead runners
			for _, runnerID := range deadRunnerIDs {
				s.redistributeRunnerTasks(runnerID)

				// Mark runner as offline
				_, _ = s.db.Exec(
					`UPDATE runners SET status = ? WHERE id = ?`,
					types.RunnerStatusOffline, runnerID,
				)
			}

			// Check for task timeouts
			s.recoverTaskTimeouts()

			// Distribute newly recovered tasks
			s.distributeTasksToRunners()
		}()
	}
}

// recoverTaskTimeouts handles tasks that have exceeded their timeout
func (s *Server) recoverTaskTimeouts() {
	// Find tasks running longer than their timeout
	rows, err := s.db.Query(
		`SELECT t.id, t.runner_id, t.retry_count, t.max_retries, t.timeout_seconds, t.started_at
		FROM tasks t
		WHERE t.status = ?
		AND t.started_at IS NOT NULL
		AND (t.timeout_seconds IS NULL OR
			t.started_at + INTERVAL (t.timeout_seconds || ' seconds') < CURRENT_TIMESTAMP)`,
		types.TaskStatusRunning,
	)
	if err != nil {
		log.Printf("Failed to query timed out tasks: %v", err)
		return
	}
	defer rows.Close()

	for rows.Next() {
		var taskID int64
		var runnerID sql.NullInt64
		var retryCount, maxRetries int
		var timeoutSeconds sql.NullInt64
		var startedAt time.Time

		err := rows.Scan(&taskID, &runnerID, &retryCount, &maxRetries, &timeoutSeconds, &startedAt)
		if err != nil {
			continue
		}

		// Use default timeout if not set (5 minutes for frame tasks, 24 hours for FFmpeg)
		timeout := 300 // 5 minutes default
		if timeoutSeconds.Valid {
			timeout = int(timeoutSeconds.Int64)
		}

		// Check if actually timed out
		if time.Since(startedAt).Seconds() < float64(timeout) {
			continue
		}

		if retryCount >= maxRetries {
			// Mark as failed
			_, err = s.db.Exec(
				`UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL
				WHERE id = ?`,
				types.TaskStatusFailed, "Task timeout exceeded, max retries reached", taskID,
			)
			if err != nil {
				log.Printf("Failed to mark task %d as failed: %v", taskID, err)
			}
		} else {
			// Reset to pending
			_, err = s.db.Exec(
				`UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL,
				retry_count = retry_count + 1 WHERE id = ?`,
				types.TaskStatusPending, taskID,
			)
			if err == nil {
				// Add log entry using the helper function
				s.logTaskEvent(taskID, nil, types.LogLevelWarn, fmt.Sprintf("Task timeout exceeded, resetting (retry %d/%d)", retryCount+1, maxRetries), "")
			}
		}
	}
}

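The retry policy in recoverTaskTimeouts is easier to see in isolation. A standalone sketch of the same decision (illustrative names, not part of the codebase):

package main

import (
	"fmt"
	"time"
)

// decide mirrors the policy above: a task is left alone until it outlives
// its timeout; after that it is retried until max retries, then failed.
func decide(startedAt time.Time, timeoutSec, retryCount, maxRetries int) string {
	if time.Since(startedAt).Seconds() < float64(timeoutSec) {
		return "keep" // still inside its timeout window
	}
	if retryCount >= maxRetries {
		return "fail" // timeout exceeded, max retries reached
	}
	return "retry" // reset to pending, retry_count + 1
}

func main() {
	started := time.Now().Add(-10 * time.Minute)
	fmt.Println(decide(started, 300, 1, 3))    // retry
	fmt.Println(decide(started, 300, 3, 3))    // fail
	fmt.Println(decide(time.Now(), 300, 0, 3)) // keep
}
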
@@ -5,10 +5,13 @@ import (
 	"database/sql"
 	"encoding/json"
 	"fmt"
+	"jiggablend/internal/config"
+	"jiggablend/internal/database"
 	"log"
 	"net/http"
 	"os"
 	"strings"
+	"sync"
 	"time"

 	"github.com/google/uuid"
@@ -17,12 +20,31 @@ import (
 	"golang.org/x/oauth2/google"
 )

+// Context key types to avoid collisions (typed keys are safer than string keys)
+type contextKey int
+
+const (
+	contextKeyUserID contextKey = iota
+	contextKeyUserEmail
+	contextKeyUserName
+	contextKeyIsAdmin
+)
+
+// Configuration constants
+const (
+	SessionDuration        = 24 * time.Hour
+	SessionCleanupInterval = 1 * time.Hour
+)
+
 // Auth handles authentication
 type Auth struct {
-	db            *sql.DB
+	db            *database.DB
+	cfg           *config.Config
 	googleConfig  *oauth2.Config
 	discordConfig *oauth2.Config
-	sessionStore  map[string]*Session
+	sessionCache  map[string]*Session // In-memory cache for performance
+	cacheMu       sync.RWMutex
+	stopCleanup   chan struct{}
 }

 // Session represents a user session
@@ -35,41 +57,53 @@ type Session struct {
 }

 // NewAuth creates a new auth instance
-func NewAuth(db *sql.DB) (*Auth, error) {
+func NewAuth(db *database.DB, cfg *config.Config) (*Auth, error) {
 	auth := &Auth{
 		db:           db,
-		sessionStore: make(map[string]*Session),
+		cfg:          cfg,
+		sessionCache: make(map[string]*Session),
+		stopCleanup:  make(chan struct{}),
 	}

-	// Initialize Google OAuth
-	googleClientID := os.Getenv("GOOGLE_CLIENT_ID")
-	googleClientSecret := os.Getenv("GOOGLE_CLIENT_SECRET")
+	// Initialize Google OAuth from database config
+	googleClientID := cfg.GoogleClientID()
+	googleClientSecret := cfg.GoogleClientSecret()
 	if googleClientID != "" && googleClientSecret != "" {
 		auth.googleConfig = &oauth2.Config{
 			ClientID:     googleClientID,
 			ClientSecret: googleClientSecret,
-			RedirectURL:  os.Getenv("GOOGLE_REDIRECT_URL"),
+			RedirectURL:  cfg.GoogleRedirectURL(),
 			Scopes:       []string{"openid", "profile", "email"},
 			Endpoint:     google.Endpoint,
 		}
 		log.Printf("Google OAuth configured")
 	}

-	// Initialize Discord OAuth
-	discordClientID := os.Getenv("DISCORD_CLIENT_ID")
-	discordClientSecret := os.Getenv("DISCORD_CLIENT_SECRET")
+	// Initialize Discord OAuth from database config
+	discordClientID := cfg.DiscordClientID()
+	discordClientSecret := cfg.DiscordClientSecret()
 	if discordClientID != "" && discordClientSecret != "" {
 		auth.discordConfig = &oauth2.Config{
 			ClientID:     discordClientID,
 			ClientSecret: discordClientSecret,
-			RedirectURL:  os.Getenv("DISCORD_REDIRECT_URL"),
+			RedirectURL:  cfg.DiscordRedirectURL(),
 			Scopes:       []string{"identify", "email"},
 			Endpoint: oauth2.Endpoint{
 				AuthURL:  "https://discord.com/api/oauth2/authorize",
 				TokenURL: "https://discord.com/api/oauth2/token",
 			},
 		}
 		log.Printf("Discord OAuth configured")
 	}

+	// Load existing sessions from database into cache
+	if err := auth.loadSessionsFromDB(); err != nil {
+		log.Printf("Warning: Failed to load sessions from database: %v", err)
+	}
+
+	// Start background cleanup goroutine
+	go auth.cleanupExpiredSessions()
+
 	// Initialize admin settings on startup to ensure they persist between boots
 	if err := auth.initializeSettings(); err != nil {
 		log.Printf("Warning: Failed to initialize admin settings: %v", err)
@@ -85,19 +119,119 @@ func NewAuth(db *sql.DB) (*Auth, error) {
 	return auth, nil
 }

+// Close stops background goroutines
+func (a *Auth) Close() {
+	close(a.stopCleanup)
+}
+
+// loadSessionsFromDB loads all valid sessions from database into cache
+func (a *Auth) loadSessionsFromDB() error {
+	var sessions []struct {
+		sessionID string
+		session   Session
+	}
+
+	err := a.db.With(func(conn *sql.DB) error {
+		rows, err := conn.Query(
+			`SELECT session_id, user_id, email, name, is_admin, expires_at
+			FROM sessions WHERE expires_at > CURRENT_TIMESTAMP`,
+		)
+		if err != nil {
+			return fmt.Errorf("failed to query sessions: %w", err)
+		}
+		defer rows.Close()
+
+		for rows.Next() {
+			var s struct {
+				sessionID string
+				session   Session
+			}
+			err := rows.Scan(&s.sessionID, &s.session.UserID, &s.session.Email, &s.session.Name, &s.session.IsAdmin, &s.session.ExpiresAt)
+			if err != nil {
+				log.Printf("Warning: Failed to scan session row: %v", err)
+				continue
+			}
+			sessions = append(sessions, s)
+		}
+		return nil
+	})
+	if err != nil {
+		return err
+	}
+
+	a.cacheMu.Lock()
+	defer a.cacheMu.Unlock()
+
+	for _, s := range sessions {
+		a.sessionCache[s.sessionID] = &s.session
+	}
+
+	if len(sessions) > 0 {
+		log.Printf("Loaded %d active sessions from database", len(sessions))
+	}
+	return nil
+}
+
+// cleanupExpiredSessions periodically removes expired sessions from database and cache
+func (a *Auth) cleanupExpiredSessions() {
+	ticker := time.NewTicker(SessionCleanupInterval)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ticker.C:
+			// Delete expired sessions from database
+			var deleted int64
+			err := a.db.With(func(conn *sql.DB) error {
+				result, err := conn.Exec(`DELETE FROM sessions WHERE expires_at < CURRENT_TIMESTAMP`)
+				if err != nil {
+					return err
+				}
+				deleted, _ = result.RowsAffected()
+				return nil
+			})
+			if err != nil {
+				log.Printf("Warning: Failed to cleanup expired sessions: %v", err)
+				continue
+			}
+
+			// Clean up cache
+			a.cacheMu.Lock()
+			now := time.Now()
+			for sessionID, session := range a.sessionCache {
+				if now.After(session.ExpiresAt) {
+					delete(a.sessionCache, sessionID)
+				}
+			}
+			a.cacheMu.Unlock()
+
+			if deleted > 0 {
+				log.Printf("Cleaned up %d expired sessions", deleted)
+			}
+		case <-a.stopCleanup:
+			return
+		}
+	}
+}
+
 // initializeSettings ensures all admin settings are initialized with defaults if they don't exist
 func (a *Auth) initializeSettings() error {
 	// Initialize registration_enabled setting (default: true) if it doesn't exist
 	var settingCount int
-	err := a.db.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount)
+	err := a.db.With(func(conn *sql.DB) error {
+		return conn.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount)
+	})
 	if err != nil {
 		return fmt.Errorf("failed to check registration_enabled setting: %w", err)
 	}
 	if settingCount == 0 {
-		_, err = a.db.Exec(
+		err = a.db.With(func(conn *sql.DB) error {
+			_, err := conn.Exec(
 				`INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)`,
 				"registration_enabled", "true",
 			)
+			return err
+		})
 		if err != nil {
 			return fmt.Errorf("failed to initialize registration_enabled setting: %w", err)
 		}
@@ -118,7 +252,9 @@ func (a *Auth) initializeTestUser() error {

 	// Check if user already exists
 	var exists bool
-	err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND oauth_provider = 'local')", testEmail).Scan(&exists)
+	err := a.db.With(func(conn *sql.DB) error {
+		return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND oauth_provider = 'local')", testEmail).Scan(&exists)
+	})
 	if err != nil {
 		return fmt.Errorf("failed to check if test user exists: %w", err)
 	}
@@ -137,7 +273,12 @@ func (a *Auth) initializeTestUser() error {

 	// Check if this is the first user (make them admin)
 	var userCount int
-	a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
+	err = a.db.With(func(conn *sql.DB) error {
+		return conn.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
+	})
+	if err != nil {
+		return fmt.Errorf("failed to check user count: %w", err)
+	}
 	isAdmin := userCount == 0

 	// Create test user (use email as name if no name is provided)
@@ -147,10 +288,13 @@ func (a *Auth) initializeTestUser() error {
 	}

 	// Create test user
-	_, err = a.db.Exec(
+	err = a.db.With(func(conn *sql.DB) error {
+		_, err := conn.Exec(
 			"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?)",
 			testEmail, testName, testEmail, string(hashedPassword), isAdmin,
 		)
+		return err
+	})
 	if err != nil {
 		return fmt.Errorf("failed to create test user: %w", err)
 	}
@@ -242,7 +386,9 @@ func (a *Auth) DiscordCallback(ctx context.Context, code string) (*Session, error) {
 // IsRegistrationEnabled checks if new user registration is enabled
 func (a *Auth) IsRegistrationEnabled() (bool, error) {
 	var value string
-	err := a.db.QueryRow("SELECT value FROM settings WHERE key = ?", "registration_enabled").Scan(&value)
+	err := a.db.With(func(conn *sql.DB) error {
+		return conn.QueryRow("SELECT value FROM settings WHERE key = ?", "registration_enabled").Scan(&value)
+	})
 	if err == sql.ErrNoRows {
 		// Default to enabled if setting doesn't exist
 		return true, nil
@@ -262,29 +408,31 @@ func (a *Auth) SetRegistrationEnabled(enabled bool) error {

 	// Check if setting exists
 	var exists bool
-	err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM settings WHERE key = ?)", "registration_enabled").Scan(&exists)
+	err := a.db.With(func(conn *sql.DB) error {
+		return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM settings WHERE key = ?)", "registration_enabled").Scan(&exists)
+	})
 	if err != nil {
 		return fmt.Errorf("failed to check if setting exists: %w", err)
 	}

+	err = a.db.With(func(conn *sql.DB) error {
 		if exists {
 			// Update existing setting
-			_, err = a.db.Exec(
+			_, err = conn.Exec(
 				"UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?",
 				value, "registration_enabled",
 			)
 			if err != nil {
 				return fmt.Errorf("failed to update setting: %w", err)
 			}
 		} else {
 			// Insert new setting
-			_, err = a.db.Exec(
+			_, err = conn.Exec(
 				"INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)",
 				"registration_enabled", value,
 			)
 			if err != nil {
 				return fmt.Errorf("failed to insert setting: %w", err)
 			}
 		}
+		return err
+	})
+	if err != nil {
+		return fmt.Errorf("failed to set registration_enabled: %w", err)
+	}

 	return nil
@@ -299,17 +447,21 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session, error) {
 	var dbProvider, dbOAuthID string

 	// First, try to find by provider + oauth_id
-	err := a.db.QueryRow(
+	err := a.db.With(func(conn *sql.DB) error {
+		return conn.QueryRow(
 			"SELECT id, email, name, is_admin, oauth_provider, oauth_id FROM users WHERE oauth_provider = ? AND oauth_id = ?",
 			provider, oauthID,
 		).Scan(&userID, &dbEmail, &dbName, &isAdmin, &dbProvider, &dbOAuthID)
+	})

 	if err == sql.ErrNoRows {
 		// Not found by provider+oauth_id, check by email for account linking
-		err = a.db.QueryRow(
+		err = a.db.With(func(conn *sql.DB) error {
+			return conn.QueryRow(
 				"SELECT id, email, name, is_admin, oauth_provider, oauth_id FROM users WHERE email = ?",
 				email,
 			).Scan(&userID, &dbEmail, &dbName, &isAdmin, &dbProvider, &dbOAuthID)
+		})

 		if err == sql.ErrNoRows {
 			// User doesn't exist, check if registration is enabled
@@ -323,14 +475,26 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session, error) {

 			// Check if this is the first user
 			var userCount int
-			a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
+			err = a.db.With(func(conn *sql.DB) error {
+				return conn.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
+			})
+			if err != nil {
+				return nil, fmt.Errorf("failed to check user count: %w", err)
+			}
 			isAdmin = userCount == 0

 			// Create new user
-			err = a.db.QueryRow(
-				"INSERT INTO users (email, name, oauth_provider, oauth_id, is_admin) VALUES (?, ?, ?, ?, ?) RETURNING id",
+			err = a.db.With(func(conn *sql.DB) error {
+				result, err := conn.Exec(
+					"INSERT INTO users (email, name, oauth_provider, oauth_id, is_admin) VALUES (?, ?, ?, ?, ?)",
 					email, name, provider, oauthID, isAdmin,
-			).Scan(&userID)
+				)
+				if err != nil {
+					return err
+				}
+				userID, err = result.LastInsertId()
+				return err
+			})
 			if err != nil {
 				return nil, fmt.Errorf("failed to create user: %w", err)
 			}
@@ -339,10 +503,13 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session, error) {
 		} else {
 			// User exists with same email but different provider - link accounts by updating provider info.
 			// This allows the user to log in with any provider that has the same email.
-			_, err = a.db.Exec(
+			err = a.db.With(func(conn *sql.DB) error {
+				_, err = conn.Exec(
 					"UPDATE users SET oauth_provider = ?, oauth_id = ?, name = ? WHERE id = ?",
 					provider, oauthID, name, userID,
 				)
+				return err
+			})
 			if err != nil {
 				return nil, fmt.Errorf("failed to link account: %w", err)
 			}
@@ -353,10 +520,13 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session, error) {
 	} else {
 		// User found by provider+oauth_id, update info if changed
 		if dbEmail != email || dbName != name {
-			_, err = a.db.Exec(
+			err = a.db.With(func(conn *sql.DB) error {
+				_, err = conn.Exec(
 					"UPDATE users SET email = ?, name = ? WHERE id = ?",
 					email, name, userID,
 				)
+				return err
+			})
 			if err != nil {
 				return nil, fmt.Errorf("failed to update user: %w", err)
 			}
@@ -368,41 +538,134 @@ func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session, error) {
 		Email:     email,
 		Name:      name,
 		IsAdmin:   isAdmin,
-		ExpiresAt: time.Now().Add(24 * time.Hour),
+		ExpiresAt: time.Now().Add(SessionDuration),
 	}

 	return session, nil
 }

 // CreateSession creates a new session and returns a session ID
+// Sessions are persisted to database and cached in memory
 func (a *Auth) CreateSession(session *Session) string {
 	sessionID := uuid.New().String()
-	a.sessionStore[sessionID] = session
+
+	// Store in database first
+	err := a.db.With(func(conn *sql.DB) error {
+		_, err := conn.Exec(
+			`INSERT INTO sessions (session_id, user_id, email, name, is_admin, expires_at)
+			VALUES (?, ?, ?, ?, ?, ?)`,
+			sessionID, session.UserID, session.Email, session.Name, session.IsAdmin, session.ExpiresAt,
+		)
+		return err
+	})
+	if err != nil {
+		log.Printf("Warning: Failed to persist session to database: %v", err)
+		// Continue anyway - session will work from cache but won't survive restart
+	}
+
+	// Store in cache
+	a.cacheMu.Lock()
+	a.sessionCache[sessionID] = session
+	a.cacheMu.Unlock()
+
 	return sessionID
 }

 // GetSession retrieves a session by ID
+// First checks cache, then database if not found
 func (a *Auth) GetSession(sessionID string) (*Session, bool) {
-	session, ok := a.sessionStore[sessionID]
-	if !ok {
+	// Check cache first
+	a.cacheMu.RLock()
+	session, ok := a.sessionCache[sessionID]
+	a.cacheMu.RUnlock()
+
+	if ok {
+		if time.Now().After(session.ExpiresAt) {
+			a.DeleteSession(sessionID)
+			return nil, false
+		}
+		// Refresh admin status from database
+		var isAdmin bool
+		err := a.db.With(func(conn *sql.DB) error {
+			return conn.QueryRow("SELECT is_admin FROM users WHERE id = ?", session.UserID).Scan(&isAdmin)
+		})
+		if err == nil {
+			session.IsAdmin = isAdmin
+		}
+		return session, true
+	}
+
+	// Not in cache, check database
+	session = &Session{}
+	err := a.db.With(func(conn *sql.DB) error {
+		return conn.QueryRow(
+			`SELECT user_id, email, name, is_admin, expires_at
+			FROM sessions WHERE session_id = ?`,
+			sessionID,
+		).Scan(&session.UserID, &session.Email, &session.Name, &session.IsAdmin, &session.ExpiresAt)
+	})
+
+	if err == sql.ErrNoRows {
 		return nil, false
 	}
+	if err != nil {
+		log.Printf("Warning: Failed to query session from database: %v", err)
+		return nil, false
+	}

 	if time.Now().After(session.ExpiresAt) {
-		delete(a.sessionStore, sessionID)
+		a.DeleteSession(sessionID)
 		return nil, false
 	}

 	// Refresh admin status from database
 	var isAdmin bool
-	err := a.db.QueryRow("SELECT is_admin FROM users WHERE id = ?", session.UserID).Scan(&isAdmin)
+	err = a.db.With(func(conn *sql.DB) error {
+		return conn.QueryRow("SELECT is_admin FROM users WHERE id = ?", session.UserID).Scan(&isAdmin)
+	})
 	if err == nil {
 		session.IsAdmin = isAdmin
 	}

+	// Add to cache
+	a.cacheMu.Lock()
+	a.sessionCache[sessionID] = session
+	a.cacheMu.Unlock()
+
 	return session, true
 }

-// DeleteSession deletes a session
+// DeleteSession deletes a session from both cache and database
 func (a *Auth) DeleteSession(sessionID string) {
-	delete(a.sessionStore, sessionID)
+	// Delete from cache
+	a.cacheMu.Lock()
+	delete(a.sessionCache, sessionID)
+	a.cacheMu.Unlock()
+
+	// Delete from database
+	err := a.db.With(func(conn *sql.DB) error {
+		_, err := conn.Exec("DELETE FROM sessions WHERE session_id = ?", sessionID)
+		return err
+	})
+	if err != nil {
+		log.Printf("Warning: Failed to delete session from database: %v", err)
+	}
 }

 // IsProductionMode returns true if running in production mode
+// This is a package-level function that checks the environment variable.
+// For config-based checks, use Config.IsProductionMode()
 func IsProductionMode() bool {
 	// Check environment variable first for backwards compatibility
 	if os.Getenv("PRODUCTION") == "true" {
 		return true
 	}
 	return false
 }

+// IsProductionModeFromConfig returns true if production mode is enabled in config
+func (a *Auth) IsProductionModeFromConfig() bool {
+	return a.cfg.IsProductionMode()
+}
+
 // Middleware creates an authentication middleware
@@ -410,6 +673,7 @@ func (a *Auth) Middleware(next http.HandlerFunc) http.HandlerFunc {
 	return func(w http.ResponseWriter, r *http.Request) {
 		cookie, err := r.Cookie("session_id")
 		if err != nil {
+			log.Printf("Authentication failed: missing session cookie for %s %s", r.Method, r.URL.Path)
 			w.Header().Set("Content-Type", "application/json")
 			w.WriteHeader(http.StatusUnauthorized)
 			json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
@@ -418,30 +682,31 @@ func (a *Auth) Middleware(next http.HandlerFunc) http.HandlerFunc {

 		session, ok := a.GetSession(cookie.Value)
 		if !ok {
+			log.Printf("Authentication failed: invalid session cookie for %s %s", r.Method, r.URL.Path)
 			w.Header().Set("Content-Type", "application/json")
 			w.WriteHeader(http.StatusUnauthorized)
 			json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
 			return
 		}

-		// Add user info to request context
-		ctx := context.WithValue(r.Context(), "user_id", session.UserID)
-		ctx = context.WithValue(ctx, "user_email", session.Email)
-		ctx = context.WithValue(ctx, "user_name", session.Name)
-		ctx = context.WithValue(ctx, "is_admin", session.IsAdmin)
+		// Add user info to request context using typed keys
+		ctx := context.WithValue(r.Context(), contextKeyUserID, session.UserID)
+		ctx = context.WithValue(ctx, contextKeyUserEmail, session.Email)
+		ctx = context.WithValue(ctx, contextKeyUserName, session.Name)
+		ctx = context.WithValue(ctx, contextKeyIsAdmin, session.IsAdmin)
 		next(w, r.WithContext(ctx))
 	}
 }

 // GetUserID gets the user ID from context
 func GetUserID(ctx context.Context) (int64, bool) {
-	userID, ok := ctx.Value("user_id").(int64)
+	userID, ok := ctx.Value(contextKeyUserID).(int64)
 	return userID, ok
 }

 // IsAdmin checks if the user in context is an admin
 func IsAdmin(ctx context.Context) bool {
-	isAdmin, ok := ctx.Value("is_admin").(bool)
+	isAdmin, ok := ctx.Value(contextKeyIsAdmin).(bool)
 	return ok && isAdmin
 }

@@ -451,6 +716,7 @@ func (a *Auth) AdminMiddleware(next http.HandlerFunc) http.HandlerFunc {
 		// First check authentication
 		cookie, err := r.Cookie("session_id")
 		if err != nil {
+			log.Printf("Admin authentication failed: missing session cookie for %s %s", r.Method, r.URL.Path)
 			w.Header().Set("Content-Type", "application/json")
 			w.WriteHeader(http.StatusUnauthorized)
 			json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
@@ -459,6 +725,7 @@ func (a *Auth) AdminMiddleware(next http.HandlerFunc) http.HandlerFunc {

 		session, ok := a.GetSession(cookie.Value)
 		if !ok {
+			log.Printf("Admin authentication failed: invalid session cookie for %s %s", r.Method, r.URL.Path)
 			w.Header().Set("Content-Type", "application/json")
 			w.WriteHeader(http.StatusUnauthorized)
 			json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
@@ -467,25 +734,26 @@ func (a *Auth) AdminMiddleware(next http.HandlerFunc) http.HandlerFunc {

 		// Then check admin status
 		if !session.IsAdmin {
+			log.Printf("Admin access denied: user %d (email: %s) attempted to access admin endpoint %s %s", session.UserID, session.Email, r.Method, r.URL.Path)
 			w.Header().Set("Content-Type", "application/json")
 			w.WriteHeader(http.StatusForbidden)
 			json.NewEncoder(w).Encode(map[string]string{"error": "Forbidden: Admin access required"})
 			return
 		}

-		// Add user info to request context
-		ctx := context.WithValue(r.Context(), "user_id", session.UserID)
-		ctx = context.WithValue(ctx, "user_email", session.Email)
-		ctx = context.WithValue(ctx, "user_name", session.Name)
-		ctx = context.WithValue(ctx, "is_admin", session.IsAdmin)
+		// Add user info to request context using typed keys
+		ctx := context.WithValue(r.Context(), contextKeyUserID, session.UserID)
+		ctx = context.WithValue(ctx, contextKeyUserEmail, session.Email)
+		ctx = context.WithValue(ctx, contextKeyUserName, session.Name)
+		ctx = context.WithValue(ctx, contextKeyIsAdmin, session.IsAdmin)
 		next(w, r.WithContext(ctx))
 	}
 }

 // IsLocalLoginEnabled returns whether local login is enabled
-// Local login is enabled when ENABLE_LOCAL_AUTH environment variable is set to "true"
+// Checks database config first, falls back to environment variable
 func (a *Auth) IsLocalLoginEnabled() bool {
-	return os.Getenv("ENABLE_LOCAL_AUTH") == "true"
+	return a.cfg.IsLocalAuthEnabled()
 }

 // IsGoogleOAuthConfigured returns whether Google OAuth is configured
@@ -506,10 +774,12 @@ func (a *Auth) LocalLogin(username, password string) (*Session, error) {
 	var dbEmail, dbName, passwordHash string
 	var isAdmin bool

-	err := a.db.QueryRow(
+	err := a.db.With(func(conn *sql.DB) error {
+		return conn.QueryRow(
 			"SELECT id, email, name, password_hash, is_admin FROM users WHERE email = ? AND oauth_provider = 'local'",
 			email,
 		).Scan(&userID, &dbEmail, &dbName, &passwordHash, &isAdmin)
+	})

 	if err == sql.ErrNoRows {
 		return nil, fmt.Errorf("invalid credentials")
@@ -534,7 +804,7 @@ func (a *Auth) LocalLogin(username, password string) (*Session, error) {
 		Email:     dbEmail,
 		Name:      dbName,
 		IsAdmin:   isAdmin,
-		ExpiresAt: time.Now().Add(24 * time.Hour),
+		ExpiresAt: time.Now().Add(SessionDuration),
 	}

 	return session, nil
@@ -553,7 +823,9 @@ func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error) {

 	// Check if user already exists
 	var exists bool
-	err = a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ?)", email).Scan(&exists)
+	err = a.db.With(func(conn *sql.DB) error {
+		return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE email = ?)", email).Scan(&exists)
+	})
 	if err != nil {
 		return nil, fmt.Errorf("failed to check if user exists: %w", err)
 	}
@@ -569,15 +841,27 @@ func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error) {

 	// Check if this is the first user (make them admin)
 	var userCount int
-	a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
+	err = a.db.With(func(conn *sql.DB) error {
+		return conn.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
+	})
+	if err != nil {
+		return nil, fmt.Errorf("failed to check user count: %w", err)
+	}
 	isAdmin := userCount == 0

 	// Create user
 	var userID int64
-	err = a.db.QueryRow(
-		"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?) RETURNING id",
+	err = a.db.With(func(conn *sql.DB) error {
+		result, err := conn.Exec(
+			"INSERT INTO users (email, name, oauth_provider, oauth_id, password_hash, is_admin) VALUES (?, ?, 'local', ?, ?, ?)",
 			email, name, email, string(hashedPassword), isAdmin,
-	).Scan(&userID)
+		)
+		if err != nil {
+			return err
+		}
+		userID, err = result.LastInsertId()
+		return err
+	})
 	if err != nil {
 		return nil, fmt.Errorf("failed to create user: %w", err)
 	}
@@ -588,7 +872,7 @@ func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error) {
 		Email:     email,
 		Name:      name,
 		IsAdmin:   isAdmin,
-		ExpiresAt: time.Now().Add(24 * time.Hour),
+		ExpiresAt: time.Now().Add(SessionDuration),
 	}

 	return session, nil
@@ -598,7 +882,9 @@ func (a *Auth) RegisterLocalUser(email, name, password string) (*Session, error) {
 func (a *Auth) ChangePassword(userID int64, oldPassword, newPassword string) error {
 	// Get current password hash
 	var passwordHash string
-	err := a.db.QueryRow("SELECT password_hash FROM users WHERE id = ? AND oauth_provider = 'local'", userID).Scan(&passwordHash)
+	err := a.db.With(func(conn *sql.DB) error {
+		return conn.QueryRow("SELECT password_hash FROM users WHERE id = ? AND oauth_provider = 'local'", userID).Scan(&passwordHash)
+	})
 	if err == sql.ErrNoRows {
 		return fmt.Errorf("user not found or not a local user")
 	}
@@ -623,7 +909,10 @@ func (a *Auth) ChangePassword(userID int64, oldPassword, newPassword string) error {
 	}

 	// Update password
-	_, err = a.db.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), userID)
+	err = a.db.With(func(conn *sql.DB) error {
+		_, err := conn.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), userID)
+		return err
+	})
 	if err != nil {
 		return fmt.Errorf("failed to update password: %w", err)
 	}
@@ -635,7 +924,9 @@ func (a *Auth) ChangePassword(userID int64, oldPassword, newPassword string) error {
 func (a *Auth) AdminChangePassword(targetUserID int64, newPassword string) error {
 	// Verify user exists and is a local user
 	var exists bool
-	err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ? AND oauth_provider = 'local')", targetUserID).Scan(&exists)
+	err := a.db.With(func(conn *sql.DB) error {
+		return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ? AND oauth_provider = 'local')", targetUserID).Scan(&exists)
+	})
 	if err != nil {
 		return fmt.Errorf("failed to check if user exists: %w", err)
 	}
@@ -650,7 +941,10 @@ func (a *Auth) AdminChangePassword(targetUserID int64, newPassword string) error {
 	}

 	// Update password
-	_, err = a.db.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), targetUserID)
+	err = a.db.With(func(conn *sql.DB) error {
+		_, err := conn.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hashedPassword), targetUserID)
+		return err
+	})
 	if err != nil {
 		return fmt.Errorf("failed to update password: %w", err)
 	}
@@ -661,7 +955,9 @@ func (a *Auth) AdminChangePassword(targetUserID int64, newPassword string) error {
 // GetFirstUserID returns the ID of the first user (user with the lowest ID)
 func (a *Auth) GetFirstUserID() (int64, error) {
 	var firstUserID int64
-	err := a.db.QueryRow("SELECT id FROM users ORDER BY id ASC LIMIT 1").Scan(&firstUserID)
+	err := a.db.With(func(conn *sql.DB) error {
+		return conn.QueryRow("SELECT id FROM users ORDER BY id ASC LIMIT 1").Scan(&firstUserID)
+	})
 	if err == sql.ErrNoRows {
 		return 0, fmt.Errorf("no users found")
 	}
@@ -675,7 +971,9 @@ func (a *Auth) GetFirstUserID() (int64, error) {
 func (a *Auth) SetUserAdminStatus(targetUserID int64, isAdmin bool) error {
 	// Verify user exists
 	var exists bool
-	err := a.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", targetUserID).Scan(&exists)
+	err := a.db.With(func(conn *sql.DB) error {
+		return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", targetUserID).Scan(&exists)
+	})
 	if err != nil {
 		return fmt.Errorf("failed to check if user exists: %w", err)
 	}
@@ -693,7 +991,10 @@ func (a *Auth) SetUserAdminStatus(targetUserID int64, isAdmin bool) error {
 	}

 	// Update admin status
-	_, err = a.db.Exec("UPDATE users SET is_admin = ? WHERE id = ?", isAdmin, targetUserID)
+	err = a.db.With(func(conn *sql.DB) error {
+		_, err := conn.Exec("UPDATE users SET is_admin = ? WHERE id = ?", isAdmin, targetUserID)
+		return err
+	})
 	if err != nil {
 		return fmt.Errorf("failed to update admin status: %w", err)
 	}

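The switch above from string context keys to the unexported contextKey type closes off key collisions between packages. A self-contained sketch of why the typed key is safer (independent of this codebase):

package main

import (
	"context"
	"fmt"
)

// An unexported key type: no other package can construct this key,
// so no other package can accidentally overwrite or read the value.
type ctxKey int

const keyUserID ctxKey = iota

func main() {
	ctx := context.WithValue(context.Background(), keyUserID, int64(42))
	// A plain string key used elsewhere cannot collide with keyUserID,
	// even if it spells the same name, because the types differ.
	ctx = context.WithValue(ctx, "user_id", "unrelated") // vet warns on string keys; shown only for contrast
	id, ok := ctx.Value(keyUserID).(int64)
	fmt.Println(id, ok) // 42 true
}
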
115
internal/auth/jobtoken.go
Normal file
@@ -0,0 +1,115 @@
package auth

import (
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"time"
)

// JobTokenDuration is the validity period for job tokens
const JobTokenDuration = 1 * time.Hour

// JobTokenClaims represents the claims in a job token
type JobTokenClaims struct {
	JobID    int64 `json:"job_id"`
	RunnerID int64 `json:"runner_id"`
	TaskID   int64 `json:"task_id"`
	Exp      int64 `json:"exp"` // Unix timestamp
}

// jobTokenSecret is the secret used to sign job tokens.
// Generated once at startup and kept in memory.
var jobTokenSecret []byte

func init() {
	// Generate a random secret for signing job tokens.
	// This means tokens are invalidated on server restart, which is acceptable
	// for short-lived job tokens.
	jobTokenSecret = make([]byte, 32)
	if _, err := rand.Read(jobTokenSecret); err != nil {
		panic(fmt.Sprintf("failed to generate job token secret: %v", err))
	}
}

// GenerateJobToken creates a new job token for a specific job/runner/task combination
func GenerateJobToken(jobID, runnerID, taskID int64) (string, error) {
	claims := JobTokenClaims{
		JobID:    jobID,
		RunnerID: runnerID,
		TaskID:   taskID,
		Exp:      time.Now().Add(JobTokenDuration).Unix(),
	}

	// Encode claims to JSON
	claimsJSON, err := json.Marshal(claims)
	if err != nil {
		return "", fmt.Errorf("failed to marshal claims: %w", err)
	}

	// Create HMAC signature
	h := hmac.New(sha256.New, jobTokenSecret)
	h.Write(claimsJSON)
	signature := h.Sum(nil)

	// Combine claims and signature: base64(claims).base64(signature)
	token := base64.RawURLEncoding.EncodeToString(claimsJSON) + "." +
		base64.RawURLEncoding.EncodeToString(signature)

	return token, nil
}

// ValidateJobToken validates a job token and returns the claims if valid
func ValidateJobToken(token string) (*JobTokenClaims, error) {
	// Split token into claims and signature
	var claimsB64, sigB64 string
	dotIdx := -1
	for i := len(token) - 1; i >= 0; i-- {
		if token[i] == '.' {
			dotIdx = i
			break
		}
	}
	if dotIdx == -1 {
		return nil, fmt.Errorf("invalid token format")
	}
	claimsB64 = token[:dotIdx]
	sigB64 = token[dotIdx+1:]

	// Decode claims
	claimsJSON, err := base64.RawURLEncoding.DecodeString(claimsB64)
	if err != nil {
		return nil, fmt.Errorf("invalid token encoding: %w", err)
	}

	// Decode signature
	signature, err := base64.RawURLEncoding.DecodeString(sigB64)
	if err != nil {
		return nil, fmt.Errorf("invalid signature encoding: %w", err)
	}

	// Verify signature
	h := hmac.New(sha256.New, jobTokenSecret)
	h.Write(claimsJSON)
	expectedSig := h.Sum(nil)
	if !hmac.Equal(signature, expectedSig) {
		return nil, fmt.Errorf("invalid signature")
	}

	// Parse claims
	var claims JobTokenClaims
	if err := json.Unmarshal(claimsJSON, &claims); err != nil {
		return nil, fmt.Errorf("invalid claims: %w", err)
	}

	// Check expiration
	if time.Now().Unix() > claims.Exp {
		return nil, fmt.Errorf("token expired")
	}

	return &claims, nil
}

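Since jobtoken.go is self-contained, a round trip shows the intended flow. This is a sketch: in the server the token travels with a task assignment, and every token dies within an hour or on restart, because the signing secret lives only in memory.

package main

import (
	"fmt"
	"log"

	"jiggablend/internal/auth"
)

func main() {
	token, err := auth.GenerateJobToken(42, 7, 1001) // job 42, runner 7, task 1001
	if err != nil {
		log.Fatal(err)
	}
	claims, err := auth.ValidateJobToken(token)
	if err != nil {
		log.Fatal(err) // forged signature, malformed token, or expiry
	}
	fmt.Println(claims.JobID, claims.RunnerID, claims.TaskID) // 42 7 1001
}
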
@@ -1,276 +1,236 @@
 package auth

 import (
 	"crypto/hmac"
 	"crypto/rand"
 	"crypto/sha256"
 	"database/sql"
 	"encoding/hex"
 	"fmt"
 	"io"
 	"log"
 	"net/http"
-	"os"
+	"jiggablend/internal/config"
+	"jiggablend/internal/database"
 	"strings"
 	"sync"
 	"time"
 )

-// Secrets handles secret and token management
+// Secrets handles API key management
 type Secrets struct {
-	db                     *sql.DB
-	fixedRegistrationToken string // Fixed token from environment variable (reusable, never expires)
+	db             *database.DB
+	cfg            *config.Config
 	RegistrationMu sync.Mutex // Protects concurrent runner registrations
 }

 // NewSecrets creates a new secrets manager
-func NewSecrets(db *sql.DB) (*Secrets, error) {
-	s := &Secrets{db: db}
-
-	// Check for fixed registration token from environment
-	fixedToken := os.Getenv("FIXED_REGISTRATION_TOKEN")
-	if fixedToken != "" {
-		s.fixedRegistrationToken = fixedToken
-		log.Printf("Fixed registration token enabled (from FIXED_REGISTRATION_TOKEN env var)")
-		log.Printf("WARNING: Fixed registration token is reusable and never expires - use only for testing/development!")
-	}
-
-	// Ensure manager secret exists
-	if err := s.ensureManagerSecret(); err != nil {
-		return nil, fmt.Errorf("failed to ensure manager secret: %w", err)
-	}
-
-	return s, nil
+func NewSecrets(db *database.DB, cfg *config.Config) (*Secrets, error) {
+	return &Secrets{db: db, cfg: cfg}, nil
 }

-// ensureManagerSecret ensures a manager secret exists in the database
-func (s *Secrets) ensureManagerSecret() error {
-	var count int
-	err := s.db.QueryRow("SELECT COUNT(*) FROM manager_secrets").Scan(&count)
-	if err != nil {
-		return fmt.Errorf("failed to check manager secrets: %w", err)
-	}
-
-	if count == 0 {
-		// Generate new manager secret
-		secret, err := generateSecret(32)
-		if err != nil {
-			return fmt.Errorf("failed to generate manager secret: %w", err)
-		}
-
-		_, err = s.db.Exec("INSERT INTO manager_secrets (secret) VALUES (?)", secret)
-		if err != nil {
-			return fmt.Errorf("failed to store manager secret: %w", err)
-		}
-	}
-
-	return nil
-}
-
-// GetManagerSecret retrieves the current manager secret
-func (s *Secrets) GetManagerSecret() (string, error) {
-	var secret string
-	err := s.db.QueryRow("SELECT secret FROM manager_secrets ORDER BY created_at DESC LIMIT 1").Scan(&secret)
-	if err != nil {
-		return "", fmt.Errorf("failed to get manager secret: %w", err)
-	}
-	return secret, nil
-}
-
-// GenerateRegistrationToken generates a new registration token
-func (s *Secrets) GenerateRegistrationToken(createdBy int64, expiresIn time.Duration) (string, error) {
-	token, err := generateSecret(32)
-	if err != nil {
-		return "", fmt.Errorf("failed to generate token: %w", err)
-	}
-
-	expiresAt := time.Now().Add(expiresIn)
-
-	_, err = s.db.Exec(
-		"INSERT INTO registration_tokens (token, expires_at, created_by) VALUES (?, ?, ?)",
-		token, expiresAt, createdBy,
-	)
-	if err != nil {
-		return "", fmt.Errorf("failed to store registration token: %w", err)
-	}
-
-	return token, nil
-}
-
-// TokenValidationResult represents the result of token validation
-type TokenValidationResult struct {
-	Valid  bool
-	Reason string // "valid", "not_found", "already_used", "expired"
-	Error  error
-}
-
-// ValidateRegistrationToken validates a registration token
-func (s *Secrets) ValidateRegistrationToken(token string) (bool, error) {
-	result, err := s.ValidateRegistrationTokenDetailed(token)
-	if err != nil {
-		return false, err
-	}
-	// For backward compatibility, return just the valid boolean
-	return result.Valid, nil
-}
-
-// ValidateRegistrationTokenDetailed validates a registration token and returns detailed result
-func (s *Secrets) ValidateRegistrationTokenDetailed(token string) (*TokenValidationResult, error) {
-	// Check fixed token first (if set) - it's reusable and never expires
-	if s.fixedRegistrationToken != "" && token == s.fixedRegistrationToken {
-		log.Printf("Fixed registration token used (from FIXED_REGISTRATION_TOKEN env var)")
-		return &TokenValidationResult{Valid: true, Reason: "valid"}, nil
-	}
-
-	// Check database tokens
-	var used bool
-	var expiresAt time.Time
-	var id int64
-
-	err := s.db.QueryRow(
-		"SELECT id, expires_at, used FROM registration_tokens WHERE token = ?",
-		token,
-	).Scan(&id, &expiresAt, &used)
-	if err == sql.ErrNoRows {
-		return &TokenValidationResult{Valid: false, Reason: "not_found"}, nil
-	}
-	if err != nil {
-		return nil, fmt.Errorf("failed to query token: %w", err)
-	}
-
-	if used {
-		return &TokenValidationResult{Valid: false, Reason: "already_used"}, nil
-	}
-
-	if time.Now().After(expiresAt) {
-		return &TokenValidationResult{Valid: false, Reason: "expired"}, nil
-	}
-
-	// Mark token as used
-	_, err = s.db.Exec("UPDATE registration_tokens SET used = 1 WHERE id = ?", id)
-	if err != nil {
-		return nil, fmt.Errorf("failed to mark token as used: %w", err)
-	}
-
-	return &TokenValidationResult{Valid: true, Reason: "valid"}, nil
-}
-
-// ListRegistrationTokens lists all registration tokens
-func (s *Secrets) ListRegistrationTokens() ([]map[string]interface{}, error) {
-	rows, err := s.db.Query(
-		`SELECT id, token, expires_at, used, created_at, created_by
-		FROM registration_tokens
-		ORDER BY created_at DESC`,
-	)
-	if err != nil {
-		return nil, fmt.Errorf("failed to query tokens: %w", err)
-	}
-	defer rows.Close()
-
-	var tokens []map[string]interface{}
-	for rows.Next() {
-		var id, createdBy sql.NullInt64
-		var token string
-		var expiresAt, createdAt time.Time
-		var used bool
-
-		err := rows.Scan(&id, &token, &expiresAt, &used, &createdAt, &createdBy)
-		if err != nil {
-			continue
-		}
-
-		tokens = append(tokens, map[string]interface{}{
-			"id":         id.Int64,
-			"token":      token,
-			"expires_at": expiresAt,
-			"used":       used,
-			"created_at": createdAt,
-			"created_by": createdBy.Int64,
-		})
-	}
-
-	return tokens, nil
-}
-
-// RevokeRegistrationToken revokes a registration token
-func (s *Secrets) RevokeRegistrationToken(tokenID int64) error {
-	_, err := s.db.Exec("UPDATE registration_tokens SET used = 1 WHERE id = ?", tokenID)
-	return err
-}
-
-// GenerateRunnerSecret generates a unique secret for a runner
-func (s *Secrets) GenerateRunnerSecret() (string, error) {
-	return generateSecret(32)
+// APIKeyInfo represents information about an API key
+type APIKeyInfo struct {
+	ID          int64     `json:"id"`
+	Key         string    `json:"key"`
+	Name        string    `json:"name"`
+	Description *string   `json:"description,omitempty"`
+	Scope       string    `json:"scope"` // 'manager' or 'user'
+	IsActive    bool      `json:"is_active"`
+	CreatedAt   time.Time `json:"created_at"`
+	CreatedBy   int64     `json:"created_by"`
+}
+
+// GenerateRunnerAPIKey generates a new API key for runners
+func (s *Secrets) GenerateRunnerAPIKey(createdBy int64, name, description string, scope string) (*APIKeyInfo, error) {
+	// Generate API key in format: jk_r1_abc123def456...
+	key, err := s.generateAPIKey()
+	if err != nil {
+		return nil, fmt.Errorf("failed to generate API key: %w", err)
+	}
+
+	// Extract prefix (first 5 chars after "jk_") and hash the full key
+	parts := strings.Split(key, "_")
+	if len(parts) < 3 {
+		return nil, fmt.Errorf("invalid API key format generated")
+	}
+	keyPrefix := fmt.Sprintf("%s_%s", parts[0], parts[1])
+
+	keyHash := sha256.Sum256([]byte(key))
+	keyHashStr := hex.EncodeToString(keyHash[:])
+
+	var keyInfo APIKeyInfo
+	err = s.db.With(func(conn *sql.DB) error {
+		result, err := conn.Exec(
+			`INSERT INTO runner_api_keys (key_prefix, key_hash, name, description, scope, is_active, created_by)
+			VALUES (?, ?, ?, ?, ?, ?, ?)`,
+			keyPrefix, keyHashStr, name, description, scope, true, createdBy,
+		)
+		if err != nil {
+			return fmt.Errorf("failed to store API key: %w", err)
+		}
+
+		keyID, err := result.LastInsertId()
+		if err != nil {
+			return fmt.Errorf("failed to get inserted key ID: %w", err)
+		}
+
+		// Get the inserted key info
+		err = conn.QueryRow(
+			`SELECT id, name, description, scope, is_active, created_at, created_by
+			FROM runner_api_keys WHERE id = ?`,
+			keyID,
+		).Scan(&keyInfo.ID, &keyInfo.Name, &keyInfo.Description, &keyInfo.Scope, &keyInfo.IsActive, &keyInfo.CreatedAt, &keyInfo.CreatedBy)
+		return err
+	})
+	if err != nil {
+		return nil, fmt.Errorf("failed to create API key: %w", err)
+	}
+
+	keyInfo.Key = key
+	return &keyInfo, nil
+}
+
+// generateAPIKey generates a new API key in format jk_r1_abc123def456...
+func (s *Secrets) generateAPIKey() (string, error) {
+	// Generate random suffix
+	randomBytes := make([]byte, 16)
+	if _, err := rand.Read(randomBytes); err != nil {
+		return "", fmt.Errorf("failed to generate random bytes: %w", err)
+	}
+	randomStr := hex.EncodeToString(randomBytes)
+
+	// Generate a unique prefix (jk_r followed by 1 random digit)
+	prefixDigit := make([]byte, 1)
+	if _, err := rand.Read(prefixDigit); err != nil {
+		return "", fmt.Errorf("failed to generate prefix digit: %w", err)
+	}
+
+	prefix := fmt.Sprintf("jk_r%d", prefixDigit[0]%10)
+	key := fmt.Sprintf("%s_%s", prefix, randomStr)
+
+	// Validate generated key format
+	if !strings.HasPrefix(key, "jk_r") {
+		return "", fmt.Errorf("generated invalid API key format: %s", key)
+	}
+
+	return key, nil
+}
+
+// ValidateRunnerAPIKey validates an API key and returns the key ID and scope if valid
+func (s *Secrets) ValidateRunnerAPIKey(apiKey string) (int64, string, error) {
+	if apiKey == "" {
+		return 0, "", fmt.Errorf("API key is required")
+	}
+
+	// Check fixed API key first (from database config)
+	fixedKey := s.cfg.FixedAPIKey()
+	if fixedKey != "" && apiKey == fixedKey {
+		// Return a special ID for fixed API key (doesn't exist in database)
+		return -1, "manager", nil
+	}
+
+	// Parse API key format: jk_rX_...
+	if !strings.HasPrefix(apiKey, "jk_r") {
+		return 0, "", fmt.Errorf("invalid API key format: expected format 'jk_rX_...' where X is a number (e.g., 'jk_r1_abc123...')")
+	}
+
+	parts := strings.Split(apiKey, "_")
+	if len(parts) < 3 {
+		return 0, "", fmt.Errorf("invalid API key format: expected format 'jk_rX_...' with at least 3 parts separated by underscores")
+	}
+
+	keyPrefix := fmt.Sprintf("%s_%s", parts[0], parts[1])
+
+	// Hash the full key for comparison
+	keyHash := sha256.Sum256([]byte(apiKey))
+	keyHashStr := hex.EncodeToString(keyHash[:])
+
+	var keyID int64
+	var scope string
+	var isActive bool
+
+	err := s.db.With(func(conn *sql.DB) error {
+		err := conn.QueryRow(
+			`SELECT id, scope, is_active FROM runner_api_keys
+			WHERE key_prefix = ? AND key_hash = ?`,
+			keyPrefix, keyHashStr,
+		).Scan(&keyID, &scope, &isActive)
+
+		if err == sql.ErrNoRows {
+			return fmt.Errorf("API key not found or invalid - please check that the key is correct and active")
+		}
+		if err != nil {
+			return fmt.Errorf("failed to validate API key: %w", err)
+		}
+
+		if !isActive {
+			return fmt.Errorf("API key is inactive")
+		}
+
+		// Update last_used_at (don't fail if this update fails)
+		conn.Exec(`UPDATE runner_api_keys SET last_used_at = ? WHERE id = ?`, time.Now(), keyID)
+		return nil
+	})
+	if err != nil {
+		return 0, "", err
+	}
+
+	return keyID, scope, nil
+}
+
+// ListRunnerAPIKeys lists all runner API keys
+func (s *Secrets) ListRunnerAPIKeys() ([]APIKeyInfo, error) {
+	var keys []APIKeyInfo
+	err := s.db.With(func(conn *sql.DB) error {
+		rows, err := conn.Query(
+			`SELECT id, key_prefix, name, description, scope, is_active, created_at, created_by
+			FROM runner_api_keys
+			ORDER BY created_at DESC`,
+		)
+		if err != nil {
+			return fmt.Errorf("failed to query API keys: %w", err)
+		}
+		defer rows.Close()
+
+		for rows.Next() {
+			var key APIKeyInfo
+			var description sql.NullString
+
+			err := rows.Scan(&key.ID, &key.Key, &key.Name, &description, &key.Scope, &key.IsActive, &key.CreatedAt, &key.CreatedBy)
+			if err != nil {
+				continue
+			}
+			if description.Valid {
+				key.Description = &description.String
+			}
+
+			keys = append(keys, key)
+		}
+		return nil
+	})
+	if err != nil {
+		return nil, err
+	}
+	return keys, nil
+}
+
+// RevokeRunnerAPIKey revokes (deactivates) a runner API key
+func (s *Secrets) RevokeRunnerAPIKey(keyID int64) error {
+	return s.db.With(func(conn *sql.DB) error {
+		_, err := conn.Exec("UPDATE runner_api_keys SET is_active = false WHERE id = ?", keyID)
+		return err
+	})
+}
+
+// DeleteRunnerAPIKey deletes a runner API key
+func (s *Secrets) DeleteRunnerAPIKey(keyID int64) error {
+	return s.db.With(func(conn *sql.DB) error {
+		_, err := conn.Exec("DELETE FROM runner_api_keys WHERE id = ?", keyID)
+		return err
+	})
 }

 // SignRequest signs a request with the given secret
 func SignRequest(method, path, body, secret string, timestamp time.Time) string {
 	message := fmt.Sprintf("%s\n%s\n%s\n%d", method, path, body, timestamp.Unix())
 	h := hmac.New(sha256.New, []byte(secret))
 	h.Write([]byte(message))
 	return hex.EncodeToString(h.Sum(nil))
 }

 // VerifyRequest verifies a signed request
 func VerifyRequest(r *http.Request, secret string, maxAge time.Duration) (bool, error) {
 	signature := r.Header.Get("X-Runner-Signature")
 	if signature == "" {
 		return false, fmt.Errorf("missing signature")
 	}

 	timestampStr := r.Header.Get("X-Runner-Timestamp")
 	if timestampStr == "" {
 		return false, fmt.Errorf("missing timestamp")
 	}

 	var timestampUnix int64
 	_, err := fmt.Sscanf(timestampStr, "%d", &timestampUnix)
 	if err != nil {
 		return false, fmt.Errorf("invalid timestamp: %w", err)
 	}
 	timestamp := time.Unix(timestampUnix, 0)

 	// Check timestamp is not too old
 	if time.Since(timestamp) > maxAge {
 		return false, fmt.Errorf("request too old")
 	}

 	// Check timestamp is not in the future (allow 1 minute clock skew)
 	if timestamp.After(time.Now().Add(1 * time.Minute)) {
 		return false, fmt.Errorf("timestamp in future")
 	}

 	// Read body
 	bodyBytes, err := io.ReadAll(r.Body)
 	if err != nil {
 		return false, fmt.Errorf("failed to read body: %w", err)
 	}
 	// Restore body for handler
 	r.Body = io.NopCloser(strings.NewReader(string(bodyBytes)))

 	// Verify signature - use path without query parameters (query params are not part of the signature).
 	// The runner signs with the path including query params, but we verify with just the path.
 	// This is intentional - query params are for identification, not part of the signature.
 	path := r.URL.Path
 	expectedSig := SignRequest(r.Method, path, string(bodyBytes), secret, timestamp)

 	return hmac.Equal([]byte(signature), []byte(expectedSig)), nil
 }

 // GetRunnerSecret retrieves the runner secret for a runner ID
 func (s *Secrets) GetRunnerSecret(runnerID int64) (string, error) {
 	var secret string
 	err := s.db.QueryRow("SELECT runner_secret FROM runners WHERE id = ?", runnerID).Scan(&secret)
 	if err == sql.ErrNoRows {
 		return "", fmt.Errorf("runner not found")
 	}
 	if err != nil {
 		return "", fmt.Errorf("failed to get runner secret: %w", err)
 	}
 	if secret == "" {
 		return "", fmt.Errorf("runner not verified")
 	}
 	return secret, nil
 }

 // generateSecret generates a random secret of the given length
 func generateSecret(length int) (string, error) {

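The storage scheme above - keep only the key's prefix plus a SHA-256 hash of the full key - means a leaked runner_api_keys table contains nothing directly usable. A standalone sketch of the two values that get stored (the key below is illustrative, not a real key):

package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"strings"
)

func main() {
	key := "jk_r1_0123456789abcdef0123456789abcdef" // illustrative
	parts := strings.Split(key, "_")
	prefix := parts[0] + "_" + parts[1] // "jk_r1" - stored for fast lookup
	sum := sha256.Sum256([]byte(key))   // the full key is never stored in clear
	fmt.Println(prefix, hex.EncodeToString(sum[:]))
}
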
303
internal/config/config.go
Normal file
@@ -0,0 +1,303 @@
|
||||
package config

import (
	"database/sql"
	"fmt"
	"jiggablend/internal/database"
	"log"
	"os"
	"strconv"
)

// Config keys stored in database
const (
	KeyGoogleClientID      = "google_client_id"
	KeyGoogleClientSecret  = "google_client_secret"
	KeyGoogleRedirectURL   = "google_redirect_url"
	KeyDiscordClientID     = "discord_client_id"
	KeyDiscordClientSecret = "discord_client_secret"
	KeyDiscordRedirectURL  = "discord_redirect_url"
	KeyEnableLocalAuth     = "enable_local_auth"
	KeyFixedAPIKey         = "fixed_api_key"
	KeyRegistrationEnabled = "registration_enabled"
	KeyProductionMode      = "production_mode"
	KeyAllowedOrigins      = "allowed_origins"
)

// Config manages application configuration stored in the database
type Config struct {
	db *database.DB
}

// NewConfig creates a new config manager
func NewConfig(db *database.DB) *Config {
	return &Config{db: db}
}

// InitializeFromEnv loads configuration from environment variables on first run.
// Environment variables are applied only if the config key doesn't exist in the database yet.
// This allows first-run setup via env vars, then subsequent runs use database values.
func (c *Config) InitializeFromEnv() error {
	envMappings := []struct {
		envKey    string
		configKey string
		sensitive bool
	}{
		{"GOOGLE_CLIENT_ID", KeyGoogleClientID, false},
		{"GOOGLE_CLIENT_SECRET", KeyGoogleClientSecret, true},
		{"GOOGLE_REDIRECT_URL", KeyGoogleRedirectURL, false},
		{"DISCORD_CLIENT_ID", KeyDiscordClientID, false},
		{"DISCORD_CLIENT_SECRET", KeyDiscordClientSecret, true},
		{"DISCORD_REDIRECT_URL", KeyDiscordRedirectURL, false},
		{"ENABLE_LOCAL_AUTH", KeyEnableLocalAuth, false},
		{"FIXED_API_KEY", KeyFixedAPIKey, true},
		{"PRODUCTION", KeyProductionMode, false},
		{"ALLOWED_ORIGINS", KeyAllowedOrigins, false},
	}

	for _, mapping := range envMappings {
		envValue := os.Getenv(mapping.envKey)
		if envValue == "" {
			continue
		}

		// Check if config already exists in database
		exists, err := c.Exists(mapping.configKey)
		if err != nil {
			return fmt.Errorf("failed to check config %s: %w", mapping.configKey, err)
		}

		if !exists {
			// Store env value in database
			if err := c.Set(mapping.configKey, envValue); err != nil {
				return fmt.Errorf("failed to store config %s: %w", mapping.configKey, err)
			}
			if mapping.sensitive {
				log.Printf("Stored config from env: %s = [REDACTED]", mapping.configKey)
			} else {
				log.Printf("Stored config from env: %s = %s", mapping.configKey, envValue)
			}
		}
	}

	return nil
}
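
For example (illustrative only; the env var value is a placeholder), first-run initialization might look like:

// Sketch: env vars seed the settings table on first run;
// on later runs the database values win and the env vars are ignored.
os.Setenv("PRODUCTION", "true") // hypothetical first-run environment
cfg := config.NewConfig(db)     // db: a previously opened *database.DB
if err := cfg.InitializeFromEnv(); err != nil {
	log.Fatal(err)
}
log.Println(cfg.IsProductionMode()) // true, now served from the database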

// Get retrieves a config value from the database
func (c *Config) Get(key string) (string, error) {
	var value string
	err := c.db.With(func(conn *sql.DB) error {
		return conn.QueryRow("SELECT value FROM settings WHERE key = ?", key).Scan(&value)
	})
	if err == sql.ErrNoRows {
		return "", nil
	}
	if err != nil {
		return "", fmt.Errorf("failed to get config %s: %w", key, err)
	}
	return value, nil
}

// GetWithDefault retrieves a config value or returns a default if not set
func (c *Config) GetWithDefault(key, defaultValue string) string {
	value, err := c.Get(key)
	if err != nil || value == "" {
		return defaultValue
	}
	return value
}

// GetBool retrieves a boolean config value
func (c *Config) GetBool(key string) (bool, error) {
	value, err := c.Get(key)
	if err != nil {
		return false, err
	}
	return value == "true" || value == "1", nil
}

// GetBoolWithDefault retrieves a boolean config value or returns a default
func (c *Config) GetBoolWithDefault(key string, defaultValue bool) bool {
	value, err := c.GetBool(key)
	if err != nil {
		return defaultValue
	}
	// If the key doesn't exist, Get returns an empty string, which parses as false.
	// Check whether the key exists to distinguish "false" from "not set".
	exists, _ := c.Exists(key)
	if !exists {
		return defaultValue
	}
	return value
}

// GetInt retrieves an integer config value
func (c *Config) GetInt(key string) (int, error) {
	value, err := c.Get(key)
	if err != nil {
		return 0, err
	}
	if value == "" {
		return 0, nil
	}
	return strconv.Atoi(value)
}

// GetIntWithDefault retrieves an integer config value or returns a default
func (c *Config) GetIntWithDefault(key string, defaultValue int) int {
	value, err := c.GetInt(key)
	if err != nil {
		return defaultValue
	}
	exists, _ := c.Exists(key)
	if !exists {
		return defaultValue
	}
	return value
}
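
A quick illustration of the lookup semantics (the "max_retries" key is a placeholder, not one the package defines):

// Sketch of the getter semantics.
origins := cfg.GetWithDefault(config.KeyAllowedOrigins, "*") // missing or "" -> "*"
prod := cfg.GetBoolWithDefault(config.KeyProductionMode, false)
// GetIntWithDefault falls back both when the key is missing and when
// strconv.Atoi fails on a malformed stored value.
retries := cfg.GetIntWithDefault("max_retries", 3)
_, _, _ = origins, prod, retries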

// Set stores a config value in the database
func (c *Config) Set(key, value string) error {
	// Use upsert pattern
	exists, err := c.Exists(key)
	if err != nil {
		return err
	}

	err = c.db.With(func(conn *sql.DB) error {
		if exists {
			_, err = conn.Exec(
				"UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?",
				value, key,
			)
		} else {
			_, err = conn.Exec(
				"INSERT INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)",
				key, value,
			)
		}
		return err
	})

	if err != nil {
		return fmt.Errorf("failed to set config %s: %w", key, err)
	}
	return nil
}
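
Since settings.key is the primary key (see the initial-schema migration below), the check-then-write pair could likely be collapsed into a single native SQLite upsert; a sketch of the alternative, not what this diff does:

// Alternative sketch: one-statement upsert instead of Exists + INSERT/UPDATE.
// Relies on settings.key being the PRIMARY KEY (SQLite 3.24+ syntax).
_, err := conn.Exec(
	`INSERT INTO settings (key, value, updated_at)
	 VALUES (?, ?, CURRENT_TIMESTAMP)
	 ON CONFLICT(key) DO UPDATE SET value = excluded.value,
	                                updated_at = CURRENT_TIMESTAMP`,
	key, value,
)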

// SetBool stores a boolean config value
func (c *Config) SetBool(key string, value bool) error {
	strValue := "false"
	if value {
		strValue = "true"
	}
	return c.Set(key, strValue)
}

// SetInt stores an integer config value
func (c *Config) SetInt(key string, value int) error {
	return c.Set(key, strconv.Itoa(value))
}

// Delete removes a config value from the database
func (c *Config) Delete(key string) error {
	err := c.db.With(func(conn *sql.DB) error {
		_, err := conn.Exec("DELETE FROM settings WHERE key = ?", key)
		return err
	})
	if err != nil {
		return fmt.Errorf("failed to delete config %s: %w", key, err)
	}
	return nil
}

// Exists checks if a config key exists in the database
func (c *Config) Exists(key string) (bool, error) {
	var exists bool
	err := c.db.With(func(conn *sql.DB) error {
		return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM settings WHERE key = ?)", key).Scan(&exists)
	})
	if err != nil {
		return false, fmt.Errorf("failed to check config existence %s: %w", key, err)
	}
	return exists, nil
}

// GetAll returns all config values (for debugging/admin purposes)
func (c *Config) GetAll() (map[string]string, error) {
	var result map[string]string
	err := c.db.With(func(conn *sql.DB) error {
		rows, err := conn.Query("SELECT key, value FROM settings")
		if err != nil {
			return fmt.Errorf("failed to get all config: %w", err)
		}
		defer rows.Close()

		result = make(map[string]string)
		for rows.Next() {
			var key, value string
			if err := rows.Scan(&key, &value); err != nil {
				return fmt.Errorf("failed to scan config row: %w", err)
			}
			result[key] = value
		}
		return nil
	})
	if err != nil {
		return nil, err
	}
	return result, nil
}

// --- Convenience methods for specific config values ---

// GoogleClientID returns the Google OAuth client ID
func (c *Config) GoogleClientID() string {
	return c.GetWithDefault(KeyGoogleClientID, "")
}

// GoogleClientSecret returns the Google OAuth client secret
func (c *Config) GoogleClientSecret() string {
	return c.GetWithDefault(KeyGoogleClientSecret, "")
}

// GoogleRedirectURL returns the Google OAuth redirect URL
func (c *Config) GoogleRedirectURL() string {
	return c.GetWithDefault(KeyGoogleRedirectURL, "")
}

// DiscordClientID returns the Discord OAuth client ID
func (c *Config) DiscordClientID() string {
	return c.GetWithDefault(KeyDiscordClientID, "")
}

// DiscordClientSecret returns the Discord OAuth client secret
func (c *Config) DiscordClientSecret() string {
	return c.GetWithDefault(KeyDiscordClientSecret, "")
}

// DiscordRedirectURL returns the Discord OAuth redirect URL
func (c *Config) DiscordRedirectURL() string {
	return c.GetWithDefault(KeyDiscordRedirectURL, "")
}

// IsLocalAuthEnabled returns whether local authentication is enabled
func (c *Config) IsLocalAuthEnabled() bool {
	return c.GetBoolWithDefault(KeyEnableLocalAuth, false)
}

// FixedAPIKey returns the fixed API key for testing
func (c *Config) FixedAPIKey() string {
	return c.GetWithDefault(KeyFixedAPIKey, "")
}

// IsProductionMode returns whether production mode is enabled
func (c *Config) IsProductionMode() bool {
	return c.GetBoolWithDefault(KeyProductionMode, false)
}

// AllowedOrigins returns the allowed CORS origins
func (c *Config) AllowedOrigins() string {
	return c.GetWithDefault(KeyAllowedOrigins, "")
}

36
internal/database/migrations/000001_initial_schema.down.sql
Normal file
@@ -0,0 +1,36 @@

-- Drop indexes
DROP INDEX IF EXISTS idx_sessions_expires_at;
DROP INDEX IF EXISTS idx_sessions_user_id;
DROP INDEX IF EXISTS idx_sessions_session_id;
DROP INDEX IF EXISTS idx_runners_last_heartbeat;
DROP INDEX IF EXISTS idx_task_steps_task_id;
DROP INDEX IF EXISTS idx_task_logs_runner_id;
DROP INDEX IF EXISTS idx_task_logs_task_id_id;
DROP INDEX IF EXISTS idx_task_logs_task_id_created_at;
DROP INDEX IF EXISTS idx_runners_api_key_id;
DROP INDEX IF EXISTS idx_runner_api_keys_created_by;
DROP INDEX IF EXISTS idx_runner_api_keys_active;
DROP INDEX IF EXISTS idx_runner_api_keys_prefix;
DROP INDEX IF EXISTS idx_job_files_job_id;
DROP INDEX IF EXISTS idx_tasks_started_at;
DROP INDEX IF EXISTS idx_tasks_job_status;
DROP INDEX IF EXISTS idx_tasks_status;
DROP INDEX IF EXISTS idx_tasks_runner_id;
DROP INDEX IF EXISTS idx_tasks_job_id;
DROP INDEX IF EXISTS idx_jobs_user_status_created;
DROP INDEX IF EXISTS idx_jobs_status;
DROP INDEX IF EXISTS idx_jobs_user_id;

-- Drop tables (order matters due to foreign keys)
DROP TABLE IF EXISTS sessions;
DROP TABLE IF EXISTS settings;
DROP TABLE IF EXISTS task_steps;
DROP TABLE IF EXISTS task_logs;
DROP TABLE IF EXISTS manager_secrets;
DROP TABLE IF EXISTS job_files;
DROP TABLE IF EXISTS tasks;
DROP TABLE IF EXISTS runners;
DROP TABLE IF EXISTS jobs;
DROP TABLE IF EXISTS runner_api_keys;
DROP TABLE IF EXISTS users;

184
internal/database/migrations/000001_initial_schema.up.sql
Normal file
@@ -0,0 +1,184 @@

-- Enable foreign keys for SQLite
PRAGMA foreign_keys = ON;

-- Users table
CREATE TABLE users (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    email TEXT UNIQUE NOT NULL,
    name TEXT NOT NULL,
    oauth_provider TEXT NOT NULL,
    oauth_id TEXT NOT NULL,
    password_hash TEXT,
    is_admin INTEGER NOT NULL DEFAULT 0,
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    UNIQUE(oauth_provider, oauth_id)
);

-- Runner API keys table
CREATE TABLE runner_api_keys (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    key_prefix TEXT NOT NULL,
    key_hash TEXT NOT NULL,
    name TEXT NOT NULL,
    description TEXT,
    scope TEXT NOT NULL DEFAULT 'user',
    is_active INTEGER NOT NULL DEFAULT 1,
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    created_by INTEGER,
    FOREIGN KEY (created_by) REFERENCES users(id),
    UNIQUE(key_prefix)
);

-- Jobs table
CREATE TABLE jobs (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    user_id INTEGER NOT NULL,
    job_type TEXT NOT NULL DEFAULT 'render',
    name TEXT NOT NULL,
    status TEXT NOT NULL DEFAULT 'pending',
    progress REAL NOT NULL DEFAULT 0.0,
    frame_start INTEGER,
    frame_end INTEGER,
    output_format TEXT,
    blend_metadata TEXT,
    retry_count INTEGER NOT NULL DEFAULT 0,
    max_retries INTEGER NOT NULL DEFAULT 3,
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    started_at TIMESTAMP,
    completed_at TIMESTAMP,
    error_message TEXT,
    assigned_runner_id INTEGER,
    FOREIGN KEY (user_id) REFERENCES users(id)
);

-- Runners table
CREATE TABLE runners (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    name TEXT NOT NULL,
    hostname TEXT NOT NULL,
    ip_address TEXT NOT NULL,
    status TEXT NOT NULL DEFAULT 'offline',
    last_heartbeat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    capabilities TEXT,
    api_key_id INTEGER,
    api_key_scope TEXT NOT NULL DEFAULT 'user',
    priority INTEGER NOT NULL DEFAULT 100,
    fingerprint TEXT,
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    FOREIGN KEY (api_key_id) REFERENCES runner_api_keys(id)
);

-- Tasks table
CREATE TABLE tasks (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    job_id INTEGER NOT NULL,
    runner_id INTEGER,
    frame INTEGER NOT NULL,
    status TEXT NOT NULL DEFAULT 'pending',
    output_path TEXT,
    task_type TEXT NOT NULL DEFAULT 'render',
    current_step TEXT,
    retry_count INTEGER NOT NULL DEFAULT 0,
    max_retries INTEGER NOT NULL DEFAULT 3,
    runner_failure_count INTEGER NOT NULL DEFAULT 0,
    timeout_seconds INTEGER,
    condition TEXT,
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    started_at TIMESTAMP,
    completed_at TIMESTAMP,
    error_message TEXT,
    FOREIGN KEY (job_id) REFERENCES jobs(id),
    FOREIGN KEY (runner_id) REFERENCES runners(id)
);

-- Job files table
CREATE TABLE job_files (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    job_id INTEGER NOT NULL,
    file_type TEXT NOT NULL,
    file_path TEXT NOT NULL,
    file_name TEXT NOT NULL,
    file_size INTEGER NOT NULL,
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    FOREIGN KEY (job_id) REFERENCES jobs(id)
);

-- Manager secrets table
CREATE TABLE manager_secrets (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    secret TEXT UNIQUE NOT NULL,
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);

-- Task logs table
CREATE TABLE task_logs (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    task_id INTEGER NOT NULL,
    runner_id INTEGER,
    log_level TEXT NOT NULL,
    message TEXT NOT NULL,
    step_name TEXT,
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    FOREIGN KEY (task_id) REFERENCES tasks(id),
    FOREIGN KEY (runner_id) REFERENCES runners(id)
);

-- Task steps table
CREATE TABLE task_steps (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    task_id INTEGER NOT NULL,
    step_name TEXT NOT NULL,
    status TEXT NOT NULL DEFAULT 'pending',
    started_at TIMESTAMP,
    completed_at TIMESTAMP,
    duration_ms INTEGER,
    error_message TEXT,
    FOREIGN KEY (task_id) REFERENCES tasks(id)
);

-- Settings table
CREATE TABLE settings (
    key TEXT PRIMARY KEY,
    value TEXT NOT NULL,
    updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);

-- Sessions table
CREATE TABLE sessions (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    session_id TEXT UNIQUE NOT NULL,
    user_id INTEGER NOT NULL,
    email TEXT NOT NULL,
    name TEXT NOT NULL,
    is_admin INTEGER NOT NULL DEFAULT 0,
    expires_at TIMESTAMP NOT NULL,
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    FOREIGN KEY (user_id) REFERENCES users(id)
);

-- Indexes
CREATE INDEX idx_jobs_user_id ON jobs(user_id);
CREATE INDEX idx_jobs_status ON jobs(status);
CREATE INDEX idx_jobs_user_status_created ON jobs(user_id, status, created_at DESC);
CREATE INDEX idx_tasks_job_id ON tasks(job_id);
CREATE INDEX idx_tasks_runner_id ON tasks(runner_id);
CREATE INDEX idx_tasks_status ON tasks(status);
CREATE INDEX idx_tasks_job_status ON tasks(job_id, status);
CREATE INDEX idx_tasks_started_at ON tasks(started_at);
CREATE INDEX idx_job_files_job_id ON job_files(job_id);
CREATE INDEX idx_runner_api_keys_prefix ON runner_api_keys(key_prefix);
CREATE INDEX idx_runner_api_keys_active ON runner_api_keys(is_active);
CREATE INDEX idx_runner_api_keys_created_by ON runner_api_keys(created_by);
CREATE INDEX idx_runners_api_key_id ON runners(api_key_id);
CREATE INDEX idx_task_logs_task_id_created_at ON task_logs(task_id, created_at);
CREATE INDEX idx_task_logs_task_id_id ON task_logs(task_id, id DESC);
CREATE INDEX idx_task_logs_runner_id ON task_logs(runner_id);
CREATE INDEX idx_task_steps_task_id ON task_steps(task_id);
CREATE INDEX idx_runners_last_heartbeat ON runners(last_heartbeat);
CREATE INDEX idx_sessions_session_id ON sessions(session_id);
CREATE INDEX idx_sessions_user_id ON sessions(user_id);
CREATE INDEX idx_sessions_expires_at ON sessions(expires_at);

-- Initialize registration_enabled setting
INSERT INTO settings (key, value, updated_at) VALUES ('registration_enabled', 'true', CURRENT_TIMESTAMP);

@@ -2,252 +2,158 @@ package database

import (
	"database/sql"
	"embed"
	"fmt"
	"io/fs"
	"log"

	_ "github.com/marcboeker/go-duckdb/v2"
	"github.com/golang-migrate/migrate/v4"
	"github.com/golang-migrate/migrate/v4/database/sqlite3"
	"github.com/golang-migrate/migrate/v4/source/iofs"
	_ "github.com/mattn/go-sqlite3"
)

//go:embed migrations/*.sql
var migrationsFS embed.FS

// DB wraps the database connection
// Note: No mutex needed - we only have one connection per process and SQLite with WAL mode
// handles concurrent access safely
type DB struct {
	*sql.DB
	db *sql.DB
}

// NewDB creates a new database connection
func NewDB(dbPath string) (*DB, error) {
	db, err := sql.Open("duckdb", dbPath)
	// Use WAL mode for better concurrency (allows readers and writers simultaneously)
	// Add timeout and busy handler for better concurrent access
	db, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_busy_timeout=5000")
	if err != nil {
		return nil, fmt.Errorf("failed to open database: %w", err)
	}

	// Configure connection pool for better concurrency
	// SQLite with WAL mode supports multiple concurrent readers and one writer
	// Increasing pool size allows multiple HTTP requests to query the database simultaneously
	// This prevents blocking when multiple requests come in (e.g., on page refresh)
	db.SetMaxOpenConns(10)   // Allow up to 10 concurrent connections
	db.SetMaxIdleConns(5)    // Keep 5 idle connections ready
	db.SetConnMaxLifetime(0) // Connections don't expire

	if err := db.Ping(); err != nil {
		return nil, fmt.Errorf("failed to ping database: %w", err)
	}

	database := &DB{DB: db}
	// Enable foreign keys for SQLite
	if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
		return nil, fmt.Errorf("failed to enable foreign keys: %w", err)
	}

	// Enable WAL mode explicitly (in case the connection string didn't work)
	if _, err := db.Exec("PRAGMA journal_mode = WAL"); err != nil {
		log.Printf("Warning: Failed to enable WAL mode: %v", err)
	}

	database := &DB{db: db}
	if err := database.migrate(); err != nil {
		return nil, fmt.Errorf("failed to migrate database: %w", err)
	}

	// Verify connection is still open after migration
	if err := db.Ping(); err != nil {
		return nil, fmt.Errorf("database connection closed after migration: %w", err)
	}

	return database, nil
}

// migrate runs database migrations
// With executes a function with access to the database
// The function receives the underlying *sql.DB connection
// No mutex needed - single connection + WAL mode handles concurrency
func (db *DB) With(fn func(*sql.DB) error) error {
	return fn(db.db)
}

// WithTx executes a function within a transaction
// The function receives a *sql.Tx transaction
// If the function returns an error, the transaction is rolled back
// If the function returns nil, the transaction is committed
// No mutex needed - single connection + WAL mode handles concurrency
func (db *DB) WithTx(fn func(*sql.Tx) error) error {
	tx, err := db.db.Begin()
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}

	if err := fn(tx); err != nil {
		if rbErr := tx.Rollback(); rbErr != nil {
			return fmt.Errorf("transaction error: %w, rollback error: %v", err, rbErr)
		}
		return err
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}

	return nil
}
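
Illustrative only, and not taken from this diff: how a caller might use WithTx; the statement and variables are placeholders consistent with the tasks schema above:

// Sketch: atomically claim a pending task for a runner. Returning an
// error from the closure rolls the transaction back.
err := db.WithTx(func(tx *sql.Tx) error {
	res, err := tx.Exec(
		"UPDATE tasks SET runner_id = ?, status = 'running' WHERE id = ? AND status = 'pending'",
		runnerID, taskID, // hypothetical variables
	)
	if err != nil {
		return err
	}
	if n, _ := res.RowsAffected(); n == 0 {
		return fmt.Errorf("task already claimed")
	}
	return nil
})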

// migrate runs database migrations using golang-migrate
func (db *DB) migrate() error {
	// Create sequences for auto-incrementing primary keys
	sequences := []string{
		`CREATE SEQUENCE IF NOT EXISTS seq_users_id START 1`,
		`CREATE SEQUENCE IF NOT EXISTS seq_jobs_id START 1`,
		`CREATE SEQUENCE IF NOT EXISTS seq_runners_id START 1`,
		`CREATE SEQUENCE IF NOT EXISTS seq_tasks_id START 1`,
		`CREATE SEQUENCE IF NOT EXISTS seq_job_files_id START 1`,
		`CREATE SEQUENCE IF NOT EXISTS seq_manager_secrets_id START 1`,
		`CREATE SEQUENCE IF NOT EXISTS seq_registration_tokens_id START 1`,
		`CREATE SEQUENCE IF NOT EXISTS seq_task_logs_id START 1`,
		`CREATE SEQUENCE IF NOT EXISTS seq_task_steps_id START 1`,
	// Create SQLite driver instance
	// Note: We use db.db directly since we're in the same package and this is called during initialization
	driver, err := sqlite3.WithInstance(db.db, &sqlite3.Config{})
	if err != nil {
		return fmt.Errorf("failed to create sqlite3 driver: %w", err)
	}

	for _, seq := range sequences {
		if _, err := db.Exec(seq); err != nil {
			return fmt.Errorf("failed to create sequence: %w", err)
	// Create embedded filesystem source
	migrationFS, err := fs.Sub(migrationsFS, "migrations")
	if err != nil {
		return fmt.Errorf("failed to create migration filesystem: %w", err)
	}

	sourceDriver, err := iofs.New(migrationFS, ".")
	if err != nil {
		return fmt.Errorf("failed to create iofs source driver: %w", err)
	}

	// Create migrate instance
	m, err := migrate.NewWithInstance("iofs", sourceDriver, "sqlite3", driver)
	if err != nil {
		return fmt.Errorf("failed to create migrate instance: %w", err)
	}

	// Run migrations
	if err := m.Up(); err != nil {
		// If the error is "no change", that's fine - database is already up to date
		if err == migrate.ErrNoChange {
			log.Printf("Database is already up to date")
			// Don't close migrate instance - it may close the database connection
			// The migrate instance will be garbage collected
			return nil
		}
		// Don't close migrate instance on error either - it may close the DB
		return fmt.Errorf("failed to run migrations: %w", err)
	}

	schema := `
	CREATE TABLE IF NOT EXISTS users (
		id BIGINT PRIMARY KEY DEFAULT nextval('seq_users_id'),
		email TEXT UNIQUE NOT NULL,
		name TEXT NOT NULL,
		oauth_provider TEXT NOT NULL,
		oauth_id TEXT NOT NULL,
		password_hash TEXT,
		is_admin BOOLEAN NOT NULL DEFAULT false,
		created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
		UNIQUE(oauth_provider, oauth_id)
	);

	CREATE TABLE IF NOT EXISTS jobs (
		id BIGINT PRIMARY KEY DEFAULT nextval('seq_jobs_id'),
		user_id BIGINT NOT NULL,
		job_type TEXT NOT NULL DEFAULT 'render',
		name TEXT NOT NULL,
		status TEXT NOT NULL DEFAULT 'pending',
		progress REAL NOT NULL DEFAULT 0.0,
		frame_start INTEGER,
		frame_end INTEGER,
		output_format TEXT,
		allow_parallel_runners BOOLEAN,
		timeout_seconds INTEGER DEFAULT 86400,
		blend_metadata TEXT,
		created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
		started_at TIMESTAMP,
		completed_at TIMESTAMP,
		error_message TEXT,
		FOREIGN KEY (user_id) REFERENCES users(id)
	);

	CREATE TABLE IF NOT EXISTS runners (
		id BIGINT PRIMARY KEY DEFAULT nextval('seq_runners_id'),
		name TEXT NOT NULL,
		hostname TEXT NOT NULL,
		ip_address TEXT NOT NULL,
		status TEXT NOT NULL DEFAULT 'offline',
		last_heartbeat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
		capabilities TEXT,
		registration_token TEXT,
		runner_secret TEXT,
		manager_secret TEXT,
		verified BOOLEAN NOT NULL DEFAULT false,
		priority INTEGER NOT NULL DEFAULT 100,
		created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
	);

	CREATE TABLE IF NOT EXISTS tasks (
		id BIGINT PRIMARY KEY DEFAULT nextval('seq_tasks_id'),
		job_id BIGINT NOT NULL,
		runner_id BIGINT,
		frame_start INTEGER NOT NULL,
		frame_end INTEGER NOT NULL,
		status TEXT NOT NULL DEFAULT 'pending',
		output_path TEXT,
		created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
		started_at TIMESTAMP,
		completed_at TIMESTAMP,
		error_message TEXT
	);

	CREATE TABLE IF NOT EXISTS job_files (
		id BIGINT PRIMARY KEY DEFAULT nextval('seq_job_files_id'),
		job_id BIGINT NOT NULL,
		file_type TEXT NOT NULL,
		file_path TEXT NOT NULL,
		file_name TEXT NOT NULL,
		file_size INTEGER NOT NULL,
		created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
	);

	CREATE TABLE IF NOT EXISTS manager_secrets (
		id BIGINT PRIMARY KEY DEFAULT nextval('seq_manager_secrets_id'),
		secret TEXT UNIQUE NOT NULL,
		created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
	);

	CREATE TABLE IF NOT EXISTS registration_tokens (
		id BIGINT PRIMARY KEY DEFAULT nextval('seq_registration_tokens_id'),
		token TEXT UNIQUE NOT NULL,
		expires_at TIMESTAMP NOT NULL,
		used BOOLEAN NOT NULL DEFAULT false,
		created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
		created_by BIGINT,
		FOREIGN KEY (created_by) REFERENCES users(id)
	);

	CREATE TABLE IF NOT EXISTS task_logs (
		id BIGINT PRIMARY KEY DEFAULT nextval('seq_task_logs_id'),
		task_id BIGINT NOT NULL,
		runner_id BIGINT,
		log_level TEXT NOT NULL,
		message TEXT NOT NULL,
		step_name TEXT,
		created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
	);

	CREATE TABLE IF NOT EXISTS task_steps (
		id BIGINT PRIMARY KEY DEFAULT nextval('seq_task_steps_id'),
		task_id BIGINT NOT NULL,
		step_name TEXT NOT NULL,
		status TEXT NOT NULL DEFAULT 'pending',
		started_at TIMESTAMP,
		completed_at TIMESTAMP,
		duration_ms INTEGER,
		error_message TEXT
	);

	CREATE INDEX IF NOT EXISTS idx_jobs_user_id ON jobs(user_id);
	CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status);
	CREATE INDEX IF NOT EXISTS idx_tasks_job_id ON tasks(job_id);
	CREATE INDEX IF NOT EXISTS idx_tasks_runner_id ON tasks(runner_id);
	CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status);
	CREATE INDEX IF NOT EXISTS idx_tasks_started_at ON tasks(started_at);
	CREATE INDEX IF NOT EXISTS idx_job_files_job_id ON job_files(job_id);
	CREATE INDEX IF NOT EXISTS idx_registration_tokens_token ON registration_tokens(token);
	CREATE INDEX IF NOT EXISTS idx_registration_tokens_expires_at ON registration_tokens(expires_at);
	CREATE INDEX IF NOT EXISTS idx_task_logs_task_id_created_at ON task_logs(task_id, created_at);
	CREATE INDEX IF NOT EXISTS idx_task_logs_runner_id ON task_logs(runner_id);
	CREATE INDEX IF NOT EXISTS idx_task_steps_task_id ON task_steps(task_id);
	CREATE INDEX IF NOT EXISTS idx_runners_last_heartbeat ON runners(last_heartbeat);

	CREATE TABLE IF NOT EXISTS settings (
		key TEXT PRIMARY KEY,
		value TEXT NOT NULL,
		updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
	);
	`

	if _, err := db.Exec(schema); err != nil {
		return fmt.Errorf("failed to create schema: %w", err)
	}

	// Migrate existing tables to add new columns
	migrations := []string{
		// Add is_admin to users if it doesn't exist
		`ALTER TABLE users ADD COLUMN IF NOT EXISTS is_admin BOOLEAN NOT NULL DEFAULT false`,
		// Add new columns to runners if they don't exist
		`ALTER TABLE runners ADD COLUMN IF NOT EXISTS registration_token TEXT`,
		`ALTER TABLE runners ADD COLUMN IF NOT EXISTS runner_secret TEXT`,
		`ALTER TABLE runners ADD COLUMN IF NOT EXISTS manager_secret TEXT`,
		`ALTER TABLE runners ADD COLUMN IF NOT EXISTS verified BOOLEAN NOT NULL DEFAULT false`,
		`ALTER TABLE runners ADD COLUMN IF NOT EXISTS priority INTEGER NOT NULL DEFAULT 100`,
		// Add allow_parallel_runners to jobs if it doesn't exist
		`ALTER TABLE jobs ADD COLUMN IF NOT EXISTS allow_parallel_runners BOOLEAN NOT NULL DEFAULT true`,
		// Add timeout_seconds to jobs if it doesn't exist
		`ALTER TABLE jobs ADD COLUMN IF NOT EXISTS timeout_seconds INTEGER DEFAULT 86400`,
		// Add blend_metadata to jobs if it doesn't exist
		`ALTER TABLE jobs ADD COLUMN IF NOT EXISTS blend_metadata TEXT`,
		// Add job_type to jobs if it doesn't exist
		`ALTER TABLE jobs ADD COLUMN IF NOT EXISTS job_type TEXT DEFAULT 'render'`,
		// Add task_type to tasks if it doesn't exist
		`ALTER TABLE tasks ADD COLUMN IF NOT EXISTS task_type TEXT DEFAULT 'render'`,
		// Add new columns to tasks if they don't exist
		`ALTER TABLE tasks ADD COLUMN IF NOT EXISTS current_step TEXT`,
		`ALTER TABLE tasks ADD COLUMN IF NOT EXISTS retry_count INTEGER DEFAULT 0`,
		`ALTER TABLE tasks ADD COLUMN IF NOT EXISTS max_retries INTEGER DEFAULT 3`,
		`ALTER TABLE tasks ADD COLUMN IF NOT EXISTS timeout_seconds INTEGER`,
	}

	for _, migration := range migrations {
		// DuckDB supports IF NOT EXISTS for ALTER TABLE, so we can safely execute
		if _, err := db.Exec(migration); err != nil {
			// Log but don't fail - column might already exist or table might not exist yet
			// This is fine for migrations that run after schema creation
		}
	}

	// Initialize registration_enabled setting (default: true) if it doesn't exist
	var settingCount int
	err := db.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "registration_enabled").Scan(&settingCount)
	if err == nil && settingCount == 0 {
		_, err = db.Exec("INSERT INTO settings (key, value) VALUES (?, ?)", "registration_enabled", "true")
		if err != nil {
			// Log but don't fail - setting might have been created by another process
			log.Printf("Note: Could not initialize registration_enabled setting: %v", err)
		}
	}
	// Don't close the migrate instance - with sqlite3.WithInstance, closing it
	// may close the underlying database connection. The migrate instance will
	// be garbage collected when it goes out of scope.
	// If we need to close it later, we can store it in the DB struct and close
	// it when DB.Close() is called, but for now we'll let it be GC'd.

	log.Printf("Database migrations completed successfully")
	return nil
}
// Ping checks the database connection
func (db *DB) Ping() error {
	return db.db.Ping()
}

// Close closes the database connection
func (db *DB) Close() error {
	return db.DB.Close()
	return db.db.Close()
}

223
internal/logger/logger.go
Normal file
@@ -0,0 +1,223 @@

package logger

import (
	"fmt"
	"io"
	"log"
	"os"
	"path/filepath"
	"sync"
)

// Level represents log severity
type Level int

const (
	LevelDebug Level = iota
	LevelInfo
	LevelWarn
	LevelError
)

var levelNames = map[Level]string{
	LevelDebug: "DEBUG",
	LevelInfo:  "INFO",
	LevelWarn:  "WARN",
	LevelError: "ERROR",
}

// ParseLevel parses a level string into a Level
func ParseLevel(s string) Level {
	switch s {
	case "debug", "DEBUG":
		return LevelDebug
	case "info", "INFO":
		return LevelInfo
	case "warn", "WARN", "warning", "WARNING":
		return LevelWarn
	case "error", "ERROR":
		return LevelError
	default:
		return LevelInfo
	}
}

var (
	defaultLogger *Logger
	once          sync.Once
	currentLevel  Level = LevelInfo
)

// Logger wraps the standard log.Logger with optional file output and levels
type Logger struct {
	*log.Logger
	fileWriter io.WriteCloser
}

// SetLevel sets the global log level
func SetLevel(level Level) {
	currentLevel = level
}

// GetLevel returns the current log level
func GetLevel() Level {
	return currentLevel
}

// InitStdout initializes the logger to only write to stdout
func InitStdout() {
	once.Do(func() {
		log.SetOutput(os.Stdout)
		log.SetFlags(log.LstdFlags | log.Lshortfile)
		defaultLogger = &Logger{
			Logger:     log.Default(),
			fileWriter: nil,
		}
	})
}

// InitWithFile initializes the logger with both file and stdout output
// The file is truncated on each start
func InitWithFile(logPath string) error {
	var err error
	once.Do(func() {
		defaultLogger, err = NewWithFile(logPath)
		if err != nil {
			return
		}
		// Replace standard log output with the multi-writer
		multiWriter := io.MultiWriter(os.Stdout, defaultLogger.fileWriter)
		log.SetOutput(multiWriter)
		log.SetFlags(log.LstdFlags | log.Lshortfile)
	})
	return err
}

// NewWithFile creates a new logger that writes to both stdout and a log file
// The file is truncated on each start
func NewWithFile(logPath string) (*Logger, error) {
	// Ensure log directory exists
	logDir := filepath.Dir(logPath)
	if err := os.MkdirAll(logDir, 0755); err != nil {
		return nil, err
	}

	// Create/truncate the log file
	fileWriter, err := os.Create(logPath)
	if err != nil {
		return nil, err
	}

	// Create multi-writer that writes to both stdout and file
	multiWriter := io.MultiWriter(os.Stdout, fileWriter)

	// Create logger with standard flags
	logger := log.New(multiWriter, "", log.LstdFlags|log.Lshortfile)

	return &Logger{
		Logger:     logger,
		fileWriter: fileWriter,
	}, nil
}

// Close closes the file writer
func (l *Logger) Close() error {
	if l.fileWriter != nil {
		return l.fileWriter.Close()
	}
	return nil
}

// GetDefault returns the default logger instance
func GetDefault() *Logger {
	return defaultLogger
}

// logf logs a formatted message at the given level
func logf(level Level, format string, v ...interface{}) {
	if level < currentLevel {
		return
	}
	prefix := fmt.Sprintf("[%s] ", levelNames[level])
	msg := fmt.Sprintf(format, v...)
	log.Print(prefix + msg)
}

// logln logs a message at the given level
func logln(level Level, v ...interface{}) {
	if level < currentLevel {
		return
	}
	prefix := fmt.Sprintf("[%s] ", levelNames[level])
	msg := fmt.Sprint(v...)
	log.Print(prefix + msg)
}

// Debug logs a debug message
func Debug(v ...interface{}) {
	logln(LevelDebug, v...)
}

// Debugf logs a formatted debug message
func Debugf(format string, v ...interface{}) {
	logf(LevelDebug, format, v...)
}

// Info logs an info message
func Info(v ...interface{}) {
	logln(LevelInfo, v...)
}

// Infof logs a formatted info message
func Infof(format string, v ...interface{}) {
	logf(LevelInfo, format, v...)
}

// Warn logs a warning message
func Warn(v ...interface{}) {
	logln(LevelWarn, v...)
}

// Warnf logs a formatted warning message
func Warnf(format string, v ...interface{}) {
	logf(LevelWarn, format, v...)
}

// Error logs an error message
func Error(v ...interface{}) {
	logln(LevelError, v...)
}

// Errorf logs a formatted error message
func Errorf(format string, v ...interface{}) {
	logf(LevelError, format, v...)
}

// Fatal logs an error message and exits
func Fatal(v ...interface{}) {
	logln(LevelError, v...)
	os.Exit(1)
}

// Fatalf logs a formatted error message and exits
func Fatalf(format string, v ...interface{}) {
	logf(LevelError, format, v...)
	os.Exit(1)
}

// --- Backwards compatibility (maps to Info level) ---

// Printf logs a formatted message at Info level
func Printf(format string, v ...interface{}) {
	logf(LevelInfo, format, v...)
}

// Print logs a message at Info level
func Print(v ...interface{}) {
	logln(LevelInfo, v...)
}

// Println logs a message at Info level
func Println(v ...interface{}) {
	logln(LevelInfo, v...)
}
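
A minimal usage sketch (the log path and env var are placeholders, not defined by this package):

// Sketch: initialize once at startup, then log through the package funcs.
if err := logger.InitWithFile("logs/manager.log"); err != nil { // hypothetical path
	log.Fatal(err)
}
logger.SetLevel(logger.ParseLevel(os.Getenv("LOG_LEVEL")))
logger.Infof("manager starting, level=%s", os.Getenv("LOG_LEVEL"))
logger.Debug("only printed when the level is debug")
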
@@ -10,68 +10,119 @@ import (
	"jiggablend/pkg/types"
)

// handleGenerateRegistrationToken generates a new registration token
func (s *Server) handleGenerateRegistrationToken(w http.ResponseWriter, r *http.Request) {
// handleGenerateRunnerAPIKey generates a new runner API key
func (s *Manager) handleGenerateRunnerAPIKey(w http.ResponseWriter, r *http.Request) {
	userID, err := getUserID(r)
	if err != nil {
		s.respondError(w, http.StatusUnauthorized, err.Error())
		return
	}

	// Default expiration: 24 hours
	expiresIn := 24 * time.Hour

	var req struct {
		ExpiresInHours int    `json:"expires_in_hours,omitempty"`
		Name           string `json:"name"`
		Description    string `json:"description,omitempty"`
		Scope          string `json:"scope,omitempty"` // 'manager' or 'user'
	}
	if r.Body != nil && r.ContentLength > 0 {
		if err := json.NewDecoder(r.Body).Decode(&req); err == nil && req.ExpiresInHours > 0 {
			expiresIn = time.Duration(req.ExpiresInHours) * time.Hour
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
		return
	}

	if req.Name == "" {
		s.respondError(w, http.StatusBadRequest, "API key name is required")
		return
	}

	// Default scope to 'user' if not specified
	scope := req.Scope
	if scope == "" {
		scope = "user"
	}
	if scope != "manager" && scope != "user" {
		s.respondError(w, http.StatusBadRequest, "Scope must be 'manager' or 'user'")
		return
	}

	keyInfo, err := s.secrets.GenerateRunnerAPIKey(userID, req.Name, req.Description, scope)
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to generate API key: %v", err))
		return
	}

	response := map[string]interface{}{
		"id":          keyInfo.ID,
		"key":         keyInfo.Key,
		"name":        keyInfo.Name,
		"description": keyInfo.Description,
		"is_active":   keyInfo.IsActive,
		"created_at":  keyInfo.CreatedAt,
	}

	s.respondJSON(w, http.StatusCreated, response)
}

// handleListRunnerAPIKeys lists all runner API keys
func (s *Manager) handleListRunnerAPIKeys(w http.ResponseWriter, r *http.Request) {
	keys, err := s.secrets.ListRunnerAPIKeys()
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to list API keys: %v", err))
		return
	}

	// Convert to response format (hide sensitive hash data)
	var response []map[string]interface{}
	for _, key := range keys {
		item := map[string]interface{}{
			"id":         key.ID,
			"key_prefix": key.Key, // Only show prefix, not full key
			"name":       key.Name,
			"is_active":  key.IsActive,
			"created_at": key.CreatedAt,
			"created_by": key.CreatedBy,
		}
		if key.Description != nil {
			item["description"] = *key.Description
		}
		response = append(response, item)
	}

	token, err := s.secrets.GenerateRegistrationToken(userID, expiresIn)
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to generate token: %v", err))
		return
	}

	s.respondJSON(w, http.StatusCreated, map[string]interface{}{
		"token":      token,
		"expires_in": expiresIn.String(),
		"expires_at": time.Now().Add(expiresIn),
	})
	s.respondJSON(w, http.StatusOK, response)
}

// handleListRegistrationTokens lists all registration tokens
func (s *Server) handleListRegistrationTokens(w http.ResponseWriter, r *http.Request) {
	tokens, err := s.secrets.ListRegistrationTokens()
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to list tokens: %v", err))
		return
	}

	s.respondJSON(w, http.StatusOK, tokens)
}

// handleRevokeRegistrationToken revokes a registration token
func (s *Server) handleRevokeRegistrationToken(w http.ResponseWriter, r *http.Request) {
	tokenID, err := parseID(r, "id")
// handleRevokeRunnerAPIKey revokes a runner API key
func (s *Manager) handleRevokeRunnerAPIKey(w http.ResponseWriter, r *http.Request) {
	keyID, err := parseID(r, "id")
	if err != nil {
		s.respondError(w, http.StatusBadRequest, err.Error())
		return
	}

	if err := s.secrets.RevokeRegistrationToken(tokenID); err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to revoke token: %v", err))
	if err := s.secrets.RevokeRunnerAPIKey(keyID); err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to revoke API key: %v", err))
		return
	}

	s.respondJSON(w, http.StatusOK, map[string]string{"message": "Token revoked"})
	s.respondJSON(w, http.StatusOK, map[string]string{"message": "API key revoked"})
}

// handleDeleteRunnerAPIKey deletes a runner API key
func (s *Manager) handleDeleteRunnerAPIKey(w http.ResponseWriter, r *http.Request) {
	keyID, err := parseID(r, "id")
	if err != nil {
		s.respondError(w, http.StatusBadRequest, err.Error())
		return
	}

	if err := s.secrets.DeleteRunnerAPIKey(keyID); err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to delete API key: %v", err))
		return
	}

	s.respondJSON(w, http.StatusOK, map[string]string{"message": "API key deleted"})
}

// handleVerifyRunner manually verifies a runner
func (s *Server) handleVerifyRunner(w http.ResponseWriter, r *http.Request) {
func (s *Manager) handleVerifyRunner(w http.ResponseWriter, r *http.Request) {
	runnerID, err := parseID(r, "id")
	if err != nil {
		s.respondError(w, http.StatusBadRequest, err.Error())
@@ -80,14 +131,19 @@ func (s *Server) handleVerifyRunner(w http.ResponseWriter, r *http.Request) {

	// Check if runner exists
	var exists bool
	err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM runners WHERE id = ?)", runnerID).Scan(&exists)
	err = s.db.With(func(conn *sql.DB) error {
		return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM runners WHERE id = ?)", runnerID).Scan(&exists)
	})
	if err != nil || !exists {
		s.respondError(w, http.StatusNotFound, "Runner not found")
		return
	}

	// Mark runner as verified
	_, err = s.db.Exec("UPDATE runners SET verified = 1 WHERE id = ?", runnerID)
	err = s.db.With(func(conn *sql.DB) error {
		_, err := conn.Exec("UPDATE runners SET verified = 1 WHERE id = ?", runnerID)
		return err
	})
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify runner: %v", err))
		return
@@ -97,7 +153,7 @@ func (s *Server) handleVerifyRunner(w http.ResponseWriter, r *http.Request) {
}

// handleDeleteRunner removes a runner
func (s *Server) handleDeleteRunner(w http.ResponseWriter, r *http.Request) {
func (s *Manager) handleDeleteRunner(w http.ResponseWriter, r *http.Request) {
	runnerID, err := parseID(r, "id")
	if err != nil {
		s.respondError(w, http.StatusBadRequest, err.Error())
@@ -106,14 +162,19 @@ func (s *Server) handleDeleteRunner(w http.ResponseWriter, r *http.Request) {

	// Check if runner exists
	var exists bool
	err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM runners WHERE id = ?)", runnerID).Scan(&exists)
	err = s.db.With(func(conn *sql.DB) error {
		return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM runners WHERE id = ?)", runnerID).Scan(&exists)
	})
	if err != nil || !exists {
		s.respondError(w, http.StatusNotFound, "Runner not found")
		return
	}

	// Delete runner
	_, err = s.db.Exec("DELETE FROM runners WHERE id = ?", runnerID)
	err = s.db.With(func(conn *sql.DB) error {
		_, err := conn.Exec("DELETE FROM runners WHERE id = ?", runnerID)
		return err
	})
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to delete runner: %v", err))
		return
@@ -123,12 +184,17 @@ func (s *Server) handleDeleteRunner(w http.ResponseWriter, r *http.Request) {
}

// handleListRunnersAdmin lists all runners with admin details
func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request) {
	rows, err := s.db.Query(
		`SELECT id, name, hostname, ip_address, status, last_heartbeat, capabilities,
		registration_token, verified, priority, created_at
func (s *Manager) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request) {
	var rows *sql.Rows
	err := s.db.With(func(conn *sql.DB) error {
		var err error
		rows, err = conn.Query(
			`SELECT id, name, hostname, status, last_heartbeat, capabilities,
			api_key_id, api_key_scope, priority, created_at
		FROM runners ORDER BY created_at DESC`,
	)
		)
		return err
	})
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query runners: %v", err))
		return
@@ -138,31 +204,32 @@ func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request)
	runners := []map[string]interface{}{}
	for rows.Next() {
		var runner types.Runner
		var registrationToken sql.NullString
		var verified bool
		var apiKeyID sql.NullInt64
		var apiKeyScope string

		err := rows.Scan(
			&runner.ID, &runner.Name, &runner.Hostname, &runner.IPAddress,
			&runner.ID, &runner.Name, &runner.Hostname,
			&runner.Status, &runner.LastHeartbeat, &runner.Capabilities,
			&registrationToken, &verified, &runner.Priority, &runner.CreatedAt,
			&apiKeyID, &apiKeyScope, &runner.Priority, &runner.CreatedAt,
		)
		if err != nil {
			s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan runner: %v", err))
			return
		}

		// In polling model, database status is the source of truth
		// Runners update their status when they poll for jobs
		runners = append(runners, map[string]interface{}{
			"id":                 runner.ID,
			"name":               runner.Name,
			"hostname":           runner.Hostname,
			"ip_address":         runner.IPAddress,
			"status":             runner.Status,
			"last_heartbeat":     runner.LastHeartbeat,
			"capabilities":       runner.Capabilities,
			"registration_token": registrationToken.String,
			"verified":           verified,
			"priority":           runner.Priority,
			"created_at":         runner.CreatedAt,
			"id":             runner.ID,
			"name":           runner.Name,
			"hostname":       runner.Hostname,
			"status":         runner.Status,
			"last_heartbeat": runner.LastHeartbeat,
			"capabilities":   runner.Capabilities,
			"api_key_id":     apiKeyID.Int64,
			"api_key_scope":  apiKeyScope,
			"priority":       runner.Priority,
			"created_at":     runner.CreatedAt,
		})
	}

@@ -170,7 +237,7 @@ func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request)
}

// handleListUsers lists all users
func (s *Server) handleListUsers(w http.ResponseWriter, r *http.Request) {
func (s *Manager) handleListUsers(w http.ResponseWriter, r *http.Request) {
	// Get first user ID to mark it in the response
	firstUserID, err := s.auth.GetFirstUserID()
	if err != nil {
@@ -178,10 +245,15 @@ func (s *Server) handleListUsers(w http.ResponseWriter, r *http.Request) {
		firstUserID = 0
	}

	rows, err := s.db.Query(
		`SELECT id, email, name, oauth_provider, is_admin, created_at
	var rows *sql.Rows
	err = s.db.With(func(conn *sql.DB) error {
		var err error
		rows, err = conn.Query(
			`SELECT id, email, name, oauth_provider, is_admin, created_at
		FROM users ORDER BY created_at DESC`,
	)
		)
		return err
	})
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query users: %v", err))
		return
@@ -203,7 +275,9 @@ func (s *Server) handleListUsers(w http.ResponseWriter, r *http.Request) {

		// Get job count for this user
		var jobCount int
		err = s.db.QueryRow("SELECT COUNT(*) FROM jobs WHERE user_id = ?", userID).Scan(&jobCount)
		err = s.db.With(func(conn *sql.DB) error {
			return conn.QueryRow("SELECT COUNT(*) FROM jobs WHERE user_id = ?", userID).Scan(&jobCount)
		})
		if err != nil {
			jobCount = 0 // Default to 0 if query fails
		}
@@ -224,7 +298,7 @@ func (s *Server) handleListUsers(w http.ResponseWriter, r *http.Request) {
}

// handleGetUserJobs gets all jobs for a specific user
func (s *Server) handleGetUserJobs(w http.ResponseWriter, r *http.Request) {
func (s *Manager) handleGetUserJobs(w http.ResponseWriter, r *http.Request) {
	userID, err := parseID(r, "id")
	if err != nil {
		s.respondError(w, http.StatusBadRequest, err.Error())
@@ -233,18 +307,25 @@ func (s *Server) handleGetUserJobs(w http.ResponseWriter, r *http.Request) {

	// Verify user exists
	var exists bool
	err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", userID).Scan(&exists)
	err = s.db.With(func(conn *sql.DB) error {
		return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", userID).Scan(&exists)
	})
	if err != nil || !exists {
		s.respondError(w, http.StatusNotFound, "User not found")
		return
	}

	rows, err := s.db.Query(
		`SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format,
		allow_parallel_runners, timeout_seconds, blend_metadata, created_at, started_at, completed_at, error_message
	var rows *sql.Rows
	err = s.db.With(func(conn *sql.DB) error {
		var err error
		rows, err = conn.Query(
			`SELECT id, user_id, job_type, name, status, progress, frame_start, frame_end, output_format,
			blend_metadata, created_at, started_at, completed_at, error_message
		FROM jobs WHERE user_id = ? ORDER BY created_at DESC`,
		userID,
	)
			userID,
		)
		return err
	})
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query jobs: %v", err))
		return
@@ -260,11 +341,9 @@ func (s *Server) handleGetUserJobs(w http.ResponseWriter, r *http.Request) {
		var errorMessage sql.NullString
		var frameStart, frameEnd sql.NullInt64
		var outputFormat sql.NullString
		var allowParallelRunners sql.NullBool

		err := rows.Scan(
			&job.ID, &job.UserID, &jobType, &job.Name, &job.Status, &job.Progress,
			&frameStart, &frameEnd, &outputFormat, &allowParallelRunners, &job.TimeoutSeconds,
			&frameStart, &frameEnd, &outputFormat,
			&blendMetadataJSON, &job.CreatedAt, &startedAt, &completedAt, &errorMessage,
		)
		if err != nil {
@@ -284,9 +363,6 @@ func (s *Server) handleGetUserJobs(w http.ResponseWriter, r *http.Request) {
		if outputFormat.Valid {
			job.OutputFormat = &outputFormat.String
		}
		if allowParallelRunners.Valid {
			job.AllowParallelRunners = &allowParallelRunners.Bool
		}
		if startedAt.Valid {
			job.StartedAt = &startedAt.Time
		}
@@ -310,7 +386,7 @@ func (s *Server) handleGetUserJobs(w http.ResponseWriter, r *http.Request) {
}

// handleGetRegistrationEnabled gets the registration enabled setting
func (s *Server) handleGetRegistrationEnabled(w http.ResponseWriter, r *http.Request) {
func (s *Manager) handleGetRegistrationEnabled(w http.ResponseWriter, r *http.Request) {
	enabled, err := s.auth.IsRegistrationEnabled()
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to get registration setting: %v", err))
@@ -320,12 +396,12 @@ func (s *Server) handleGetRegistrationEnabled(w http.ResponseWriter, r *http.Req
}

// handleSetRegistrationEnabled sets the registration enabled setting
func (s *Server) handleSetRegistrationEnabled(w http.ResponseWriter, r *http.Request) {
func (s *Manager) handleSetRegistrationEnabled(w http.ResponseWriter, r *http.Request) {
	var req struct {
		Enabled bool `json:"enabled"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		s.respondError(w, http.StatusBadRequest, "Invalid request body")
		s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
		return
	}

@@ -338,7 +414,7 @@ func (s *Server) handleSetRegistrationEnabled(w http.ResponseWriter, r *http.Req
}

// handleSetUserAdminStatus sets a user's admin status (admin only)
func (s *Server) handleSetUserAdminStatus(w http.ResponseWriter, r *http.Request) {
func (s *Manager) handleSetUserAdminStatus(w http.ResponseWriter, r *http.Request) {
	targetUserID, err := parseID(r, "id")
	if err != nil {
		s.respondError(w, http.StatusBadRequest, err.Error())
@@ -349,7 +425,7 @@ func (s *Server) handleSetUserAdminStatus(w http.ResponseWriter, r *http.Request
	var req struct {
		IsAdmin bool `json:"is_admin"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		s.respondError(w, http.StatusBadRequest, "Invalid request body")
		s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
		return
	}

831
internal/manager/blender.go
Normal file
@@ -0,0 +1,831 @@
package api

import (
	"archive/tar"
	"compress/bzip2"
	"compress/gzip"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"regexp"
	"sort"
	"strings"
	"sync"
	"time"
)

const (
	BlenderDownloadBaseURL = "https://download.blender.org/release/"
	BlenderVersionCacheTTL = 1 * time.Hour
)

// BlenderVersion represents a parsed Blender version
type BlenderVersion struct {
	Major    int    `json:"major"`
	Minor    int    `json:"minor"`
	Patch    int    `json:"patch"`
	Full     string `json:"full"`     // e.g., "4.2.3"
	DirName  string `json:"dir_name"` // e.g., "Blender4.2"
	Filename string `json:"filename"` // e.g., "blender-4.2.3-linux-x64.tar.xz"
	URL      string `json:"url"`      // Full download URL
}

// BlenderVersionCache caches available Blender versions
type BlenderVersionCache struct {
	versions  []BlenderVersion
	fetchedAt time.Time
	mu        sync.RWMutex
}

var blenderVersionCache = &BlenderVersionCache{}

// FetchBlenderVersions fetches available Blender versions from download.blender.org
// Returns versions sorted by version number (newest first)
func (s *Manager) FetchBlenderVersions() ([]BlenderVersion, error) {
	// Check cache first
	blenderVersionCache.mu.RLock()
	if time.Since(blenderVersionCache.fetchedAt) < BlenderVersionCacheTTL && len(blenderVersionCache.versions) > 0 {
		versions := make([]BlenderVersion, len(blenderVersionCache.versions))
		copy(versions, blenderVersionCache.versions)
		blenderVersionCache.mu.RUnlock()
		return versions, nil
	}
	blenderVersionCache.mu.RUnlock()

	// Fetch from website with timeout
	client := &http.Client{
		Timeout: WSWriteDeadline,
	}
	resp, err := client.Get(BlenderDownloadBaseURL)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch blender releases: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("failed to fetch blender releases: status %d", resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}

	// Parse directory listing for Blender version folders
	// Looking for patterns like href="Blender4.2/" or href="Blender3.6/"
	dirPattern := regexp.MustCompile(`href="Blender(\d+)\.(\d+)/"`)
	log.Printf("Fetching Blender versions from %s", BlenderDownloadBaseURL)
	matches := dirPattern.FindAllStringSubmatch(string(body), -1)

	// Fetch sub-versions concurrently to speed up the process
	type versionResult struct {
		versions []BlenderVersion
		err      error
	}
	results := make(chan versionResult, len(matches))
	var wg sync.WaitGroup

	for _, match := range matches {
		if len(match) < 3 {
			continue
		}

		major := 0
		minor := 0
		fmt.Sscanf(match[1], "%d", &major)
		fmt.Sscanf(match[2], "%d", &minor)

		// Skip very old versions (pre-2.80)
		if major < 2 || (major == 2 && minor < 80) {
			continue
		}

		dirName := fmt.Sprintf("Blender%d.%d", major, minor)

		// Fetch the specific version directory concurrently
		wg.Add(1)
		go func(dn string, maj, min int) {
			defer wg.Done()
			subVersions, err := fetchSubVersions(dn, maj, min)
			results <- versionResult{versions: subVersions, err: err}
		}(dirName, major, minor)
	}

	// Close results channel when all goroutines complete
	go func() {
		wg.Wait()
		close(results)
	}()

	var versions []BlenderVersion
	for result := range results {
		if result.err != nil {
			log.Printf("Warning: failed to fetch sub-versions: %v", result.err)
			continue
		}
		versions = append(versions, result.versions...)
	}

	// Sort by version (newest first)
	sort.Slice(versions, func(i, j int) bool {
		if versions[i].Major != versions[j].Major {
			return versions[i].Major > versions[j].Major
		}
		if versions[i].Minor != versions[j].Minor {
			return versions[i].Minor > versions[j].Minor
		}
		return versions[i].Patch > versions[j].Patch
	})

	// Update cache
	blenderVersionCache.mu.Lock()
	blenderVersionCache.versions = versions
	blenderVersionCache.fetchedAt = time.Now()
	blenderVersionCache.mu.Unlock()

	return versions, nil
}
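Aside, not part of the diff: the function above combines a TTL read-through cache with a concurrent fan-in over per-directory goroutines. A minimal self-contained sketch of the cache half of that pattern (all names here are illustrative, not from the repo):

package main

import (
	"fmt"
	"sync"
	"time"
)

type cache struct {
	mu        sync.RWMutex
	values    []string
	fetchedAt time.Time
}

// get returns the cached slice while it is fresh, otherwise refetches.
func (c *cache) get(ttl time.Duration, fetch func() []string) []string {
	c.mu.RLock()
	if time.Since(c.fetchedAt) < ttl && len(c.values) > 0 {
		v := append([]string(nil), c.values...) // copy so callers cannot mutate the cache
		c.mu.RUnlock()
		return v
	}
	c.mu.RUnlock()

	v := fetch()
	c.mu.Lock()
	c.values, c.fetchedAt = v, time.Now()
	c.mu.Unlock()
	return v
}

func main() {
	c := &cache{}
	fmt.Println(c.get(time.Hour, func() []string { return []string{"4.2.3"} })) // fetches
	fmt.Println(c.get(time.Hour, func() []string { return nil }))               // served from cache
}

As in the original, readers take the RLock twice rather than upgrading it, which means two concurrent callers can both miss and refetch; that is harmless here since the fetch is idempotent.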

// fetchSubVersions fetches specific version files from a Blender release directory
func fetchSubVersions(dirName string, major, minor int) ([]BlenderVersion, error) {
	url := BlenderDownloadBaseURL + dirName + "/"
	client := &http.Client{
		Timeout: WSWriteDeadline,
	}
	resp, err := client.Get(url)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("status %d", resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	// Look for linux 64-bit tar.xz/bz2 files
	// Various naming conventions across versions:
	// - Modern (2.93+): blender-4.2.3-linux-x64.tar.xz
	// - 2.83 early: blender-2.83.0-linux64.tar.xz
	// - 2.80-2.82: blender-2.80-linux-glibc217-x86_64.tar.bz2
	// Skip: rc versions, alpha/beta, i686 (32-bit)
	filePatterns := []*regexp.Regexp{
		// Modern format: blender-X.Y.Z-linux-x64.tar.xz
		regexp.MustCompile(`blender-(\d+)\.(\d+)\.(\d+)-linux-x64\.tar\.(xz|bz2)`),
		// Older format: blender-X.Y.Z-linux64.tar.xz
		regexp.MustCompile(`blender-(\d+)\.(\d+)\.(\d+)-linux64\.tar\.(xz|bz2)`),
		// glibc format: blender-X.Y.Z-linux-glibc217-x86_64.tar.bz2 (prefer glibc217 for compatibility)
		regexp.MustCompile(`blender-(\d+)\.(\d+)\.(\d+)-linux-glibc217-x86_64\.tar\.(xz|bz2)`),
	}

	var versions []BlenderVersion
	seen := make(map[string]bool)

	for _, filePattern := range filePatterns {
		matches := filePattern.FindAllStringSubmatch(string(body), -1)

		for _, match := range matches {
			if len(match) < 5 {
				continue
			}

			patch := 0
			fmt.Sscanf(match[3], "%d", &patch)

			full := fmt.Sprintf("%d.%d.%d", major, minor, patch)
			if seen[full] {
				continue
			}
			seen[full] = true

			filename := match[0]
			versions = append(versions, BlenderVersion{
				Major:    major,
				Minor:    minor,
				Patch:    patch,
				Full:     full,
				DirName:  dirName,
				Filename: filename,
				URL:      url + filename,
			})
		}
	}

	return versions, nil
}

// GetLatestBlenderForMajorMinor returns the latest patch version for a given major.minor
// If exact match not found, uses fuzzy matching to find the closest available version
func (s *Manager) GetLatestBlenderForMajorMinor(major, minor int) (*BlenderVersion, error) {
	versions, err := s.FetchBlenderVersions()
	if err != nil {
		return nil, err
	}

	if len(versions) == 0 {
		return nil, fmt.Errorf("no blender versions available")
	}

	// Try exact match first - find the highest patch for this major.minor
	var exactMatch *BlenderVersion
	for i := range versions {
		v := &versions[i]
		if v.Major == major && v.Minor == minor {
			if exactMatch == nil || v.Patch > exactMatch.Patch {
				exactMatch = v
			}
		}
	}
	if exactMatch != nil {
		log.Printf("Found Blender %d.%d.%d for requested %d.%d", exactMatch.Major, exactMatch.Minor, exactMatch.Patch, major, minor)
		return exactMatch, nil
	}

	// Fuzzy matching: find closest version
	// Priority: same major with closest minor > closest major
	log.Printf("No exact match for Blender %d.%d, using fuzzy matching", major, minor)

	var bestMatch *BlenderVersion
	bestScore := -1000000 // Large negative number

	for i := range versions {
		v := &versions[i]
		score := 0

		if v.Major == major {
			// Same major version - prefer this
			score = 10000

			// Prefer lower minor versions (more stable/compatible)
			// but not too far back
			minorDiff := minor - v.Minor
			if minorDiff >= 0 {
				// v.Minor <= minor (older or same) - prefer closer
				score += 1000 - minorDiff*10
			} else {
				// v.Minor > minor (newer) - less preferred but acceptable
				score += 500 + minorDiff*10
			}

			// Higher patch is better
			score += v.Patch
		} else {
			// Different major - less preferred
			majorDiff := major - v.Major
			if majorDiff > 0 {
				// v.Major < major (older major) - acceptable fallback
				score = 5000 - majorDiff*1000 + v.Minor*10 + v.Patch
			} else {
				// v.Major > major (newer major) - avoid if possible
				score = -majorDiff * 1000
			}
		}

		if score > bestScore {
			bestScore = score
			bestMatch = v
		}
	}

	if bestMatch != nil {
		log.Printf("Fuzzy match: requested %d.%d, using %d.%d.%d", major, minor, bestMatch.Major, bestMatch.Minor, bestMatch.Patch)
		return bestMatch, nil
	}

	return nil, fmt.Errorf("no blender version found for %d.%d", major, minor)
}
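Aside, not part of the diff: a worked example of the scoring above. Suppose a file requests 4.5 but only 4.2.3 and 3.6.2 are published. For 4.2.3 the major matches, so score = 10000; minorDiff = 5 - 2 = 3 >= 0, so score += 1000 - 3*10 = 970; adding the patch gives 10973. For 3.6.2 the major differs by 1 (older), so score = 5000 - 1*1000 + 6*10 + 2 = 4062. The same-major candidate 4.2.3 wins, which matches the stated priority of "same major with closest minor" over "closest major".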

// GetBlenderArchivePath returns the path to the cached blender archive for a specific version
// Downloads from blender.org and decompresses to .tar if not already cached
// The manager caches as uncompressed .tar to save decompression time on runners
func (s *Manager) GetBlenderArchivePath(version *BlenderVersion) (string, error) {
	// Base directory for blender archives
	blenderDir := filepath.Join(s.storage.BasePath(), "blender-versions")
	if err := os.MkdirAll(blenderDir, 0755); err != nil {
		return "", fmt.Errorf("failed to create blender directory: %w", err)
	}

	// Cache as uncompressed .tar for faster runner downloads
	// Convert filename like "blender-4.2.3-linux-x64.tar.xz" to "blender-4.2.3-linux-x64.tar"
	tarFilename := version.Filename
	tarFilename = strings.TrimSuffix(tarFilename, ".xz")
	tarFilename = strings.TrimSuffix(tarFilename, ".bz2")
	archivePath := filepath.Join(blenderDir, tarFilename)

	// Check if already cached as .tar
	if _, err := os.Stat(archivePath); err == nil {
		log.Printf("Using cached Blender %s at %s", version.Full, archivePath)
		// Clean up any extracted folders that might exist
		s.cleanupExtractedBlenderFolders(blenderDir, version)
		return archivePath, nil
	}

	// Need to download and decompress
	log.Printf("Downloading Blender %s from %s", version.Full, version.URL)

	client := &http.Client{
		Timeout: 0, // No timeout for large downloads
	}
	resp, err := client.Get(version.URL)
	if err != nil {
		return "", fmt.Errorf("failed to download blender: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("failed to download blender: status %d", resp.StatusCode)
	}

	// Download to temp file first
	compressedPath := filepath.Join(blenderDir, "download-"+version.Filename)
	compressedFile, err := os.Create(compressedPath)
	if err != nil {
		return "", fmt.Errorf("failed to create temp file: %w", err)
	}

	if _, err := io.Copy(compressedFile, resp.Body); err != nil {
		compressedFile.Close()
		os.Remove(compressedPath)
		return "", fmt.Errorf("failed to download blender: %w", err)
	}
	compressedFile.Close()

	log.Printf("Downloaded Blender %s, decompressing to .tar...", version.Full)

	// Decompress to .tar
	if err := decompressToTar(compressedPath, archivePath); err != nil {
		os.Remove(compressedPath)
		os.Remove(archivePath)
		return "", fmt.Errorf("failed to decompress blender archive: %w", err)
	}

	// Remove compressed file
	os.Remove(compressedPath)

	// Clean up any extracted folders for this version (if they exist)
	s.cleanupExtractedBlenderFolders(blenderDir, version)

	log.Printf("Blender %s cached at %s", version.Full, archivePath)
	return archivePath, nil
}

// decompressToTar decompresses a .tar.xz or .tar.bz2 file to a plain .tar file
func decompressToTar(compressedPath, tarPath string) error {
	if strings.HasSuffix(compressedPath, ".tar.xz") {
		// Use xz command for decompression
		cmd := exec.Command("xz", "-d", "-k", "-c", compressedPath)
		outFile, err := os.Create(tarPath)
		if err != nil {
			return err
		}
		defer outFile.Close()

		cmd.Stdout = outFile
		if err := cmd.Run(); err != nil {
			return fmt.Errorf("xz decompression failed: %w", err)
		}
		return nil
	} else if strings.HasSuffix(compressedPath, ".tar.bz2") {
		// Use bzip2 for decompression
		inFile, err := os.Open(compressedPath)
		if err != nil {
			return err
		}
		defer inFile.Close()

		bzReader := bzip2.NewReader(inFile)
		outFile, err := os.Create(tarPath)
		if err != nil {
			return err
		}
		defer outFile.Close()

		if _, err := io.Copy(outFile, bzReader); err != nil {
			return fmt.Errorf("bzip2 decompression failed: %w", err)
		}
		return nil
	}

	return fmt.Errorf("unsupported compression format: %s", compressedPath)
}

// cleanupExtractedBlenderFolders removes any extracted Blender folders for the given version
// This ensures we only keep the .tar file and not extracted folders
func (s *Manager) cleanupExtractedBlenderFolders(blenderDir string, version *BlenderVersion) {
	// Look for folders matching the version (e.g., "4.2.3", "2.83.20")
	versionDirs := []string{
		filepath.Join(blenderDir, version.Full),                                      // e.g., "4.2.3"
		filepath.Join(blenderDir, fmt.Sprintf("%d.%d", version.Major, version.Minor)), // e.g., "4.2"
	}

	for _, dir := range versionDirs {
		if info, err := os.Stat(dir); err == nil && info.IsDir() {
			log.Printf("Removing extracted Blender folder: %s", dir)
			if err := os.RemoveAll(dir); err != nil {
				log.Printf("Warning: failed to remove extracted folder %s: %v", dir, err)
			} else {
				log.Printf("Removed extracted Blender folder: %s", dir)
			}
		}
	}
}

// ParseBlenderVersionFromFile parses the Blender version that a .blend file was saved with
// This reads the file header to determine the version
func ParseBlenderVersionFromFile(blendPath string) (major, minor int, err error) {
	file, err := os.Open(blendPath)
	if err != nil {
		return 0, 0, fmt.Errorf("failed to open blend file: %w", err)
	}
	defer file.Close()

	return ParseBlenderVersionFromReader(file)
}

// ParseBlenderVersionFromReader parses the Blender version from a reader
// Useful for reading from uploaded files without saving to disk first
func ParseBlenderVersionFromReader(r io.ReadSeeker) (major, minor int, err error) {
	// Read the first 12 bytes of the blend file header
	// Format: BLENDER-v<major><minor><patch> or BLENDER_v<major><minor><patch>
	// The header is: "BLENDER" (7 bytes) + pointer size (1 byte: '-' for 64-bit, '_' for 32-bit)
	// + endianness (1 byte: 'v' for little-endian, 'V' for big-endian)
	// + version (3 bytes: e.g., "402" for 4.02)
	header := make([]byte, 12)
	n, err := r.Read(header)
	if err != nil || n < 12 {
		return 0, 0, fmt.Errorf("failed to read blend file header: %w", err)
	}

	// Check for BLENDER magic
	if string(header[:7]) != "BLENDER" {
		// Might be compressed - try to decompress
		r.Seek(0, 0)
		return parseCompressedBlendVersion(r)
	}

	// Parse version from bytes 9-11 (3 digits)
	versionStr := string(header[9:12])
	var vMajor, vMinor int

	// Version format changed in Blender 3.0
	// Pre-3.0: "279" = 2.79, "280" = 2.80
	// 3.0+: "300" = 3.0, "402" = 4.02, "410" = 4.10
	if len(versionStr) == 3 {
		// First digit is major version
		fmt.Sscanf(string(versionStr[0]), "%d", &vMajor)
		// Next two digits are minor version
		fmt.Sscanf(versionStr[1:3], "%d", &vMinor)
	}

	return vMajor, vMinor, nil
}
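Aside, not part of the diff: feeding the parser a synthetic 12-byte header makes the layout described above concrete. A minimal test-style sketch, assuming it sits in the same package as the function (the test name is illustrative):

package api

import (
	"bytes"
	"testing"
)

func TestParseBlenderVersionFromReader(t *testing.T) {
	// "BLENDER" + '-' (64-bit pointers) + 'v' (little-endian) + "402" -> Blender 4.2
	major, minor, err := ParseBlenderVersionFromReader(bytes.NewReader([]byte("BLENDER-v402")))
	if err != nil || major != 4 || minor != 2 {
		t.Fatalf("got %d.%d, err=%v; want 4.2", major, minor, err)
	}
}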

// parseCompressedBlendVersion handles gzip and zstd compressed blend files
func parseCompressedBlendVersion(r io.ReadSeeker) (major, minor int, err error) {
	// Check for compression magic bytes
	magic := make([]byte, 4)
	if _, err := r.Read(magic); err != nil {
		return 0, 0, err
	}
	r.Seek(0, 0)

	if magic[0] == 0x1f && magic[1] == 0x8b {
		// gzip compressed
		gzReader, err := gzip.NewReader(r)
		if err != nil {
			return 0, 0, fmt.Errorf("failed to create gzip reader: %w", err)
		}
		defer gzReader.Close()

		header := make([]byte, 12)
		n, err := gzReader.Read(header)
		if err != nil || n < 12 {
			return 0, 0, fmt.Errorf("failed to read compressed blend header: %w", err)
		}

		if string(header[:7]) != "BLENDER" {
			return 0, 0, fmt.Errorf("invalid blend file format")
		}

		versionStr := string(header[9:12])
		var vMajor, vMinor int
		if len(versionStr) == 3 {
			fmt.Sscanf(string(versionStr[0]), "%d", &vMajor)
			fmt.Sscanf(versionStr[1:3], "%d", &vMinor)
		}

		return vMajor, vMinor, nil
	}

	// Check for zstd magic (Blender 3.0+): 0x28 0xB5 0x2F 0xFD
	if magic[0] == 0x28 && magic[1] == 0xb5 && magic[2] == 0x2f && magic[3] == 0xfd {
		return parseZstdBlendVersion(r)
	}

	return 0, 0, fmt.Errorf("unknown blend file format")
}

// parseZstdBlendVersion handles zstd-compressed blend files (Blender 3.0+)
// Uses zstd command line tool since Go doesn't have native zstd support
func parseZstdBlendVersion(r io.ReadSeeker) (major, minor int, err error) {
	r.Seek(0, 0)

	// We need to decompress just enough to read the header
	// Use zstd command to decompress from stdin
	cmd := exec.Command("zstd", "-d", "-c")
	cmd.Stdin = r

	stdout, err := cmd.StdoutPipe()
	if err != nil {
		return 0, 0, fmt.Errorf("failed to create zstd stdout pipe: %w", err)
	}

	if err := cmd.Start(); err != nil {
		return 0, 0, fmt.Errorf("failed to start zstd decompression: %w", err)
	}

	// Read just the header (12 bytes)
	header := make([]byte, 12)
	n, readErr := io.ReadFull(stdout, header)

	// Kill the process early - we only need the header
	cmd.Process.Kill()
	cmd.Wait()

	if readErr != nil || n < 12 {
		return 0, 0, fmt.Errorf("failed to read zstd compressed blend header: %v", readErr)
	}

	if string(header[:7]) != "BLENDER" {
		return 0, 0, fmt.Errorf("invalid blend file format in zstd archive")
	}

	versionStr := string(header[9:12])
	var vMajor, vMinor int
	if len(versionStr) == 3 {
		fmt.Sscanf(string(versionStr[0]), "%d", &vMajor)
		fmt.Sscanf(versionStr[1:3], "%d", &vMinor)
	}

	return vMajor, vMinor, nil
}
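Aside, not part of the diff: the comment above is about the standard library only, hence the shell-out to the zstd CLI. A pure-Go alternative is possible; the sketch below is an assumption on my part, using the third-party github.com/klauspost/compress/zstd package (not currently imported by this repo), and would remove the external-binary dependency:

// Hypothetical pure-Go variant, not in the diff.
func parseZstdBlendVersionPureGo(r io.ReadSeeker) (major, minor int, err error) {
	if _, err := r.Seek(0, io.SeekStart); err != nil {
		return 0, 0, err
	}
	dec, err := zstd.NewReader(r) // import "github.com/klauspost/compress/zstd"
	if err != nil {
		return 0, 0, err
	}
	defer dec.Close()

	// Decode only the 12-byte header, then discard the rest of the stream.
	header := make([]byte, 12)
	if _, err := io.ReadFull(dec, header); err != nil {
		return 0, 0, fmt.Errorf("failed to read zstd compressed blend header: %w", err)
	}
	if string(header[:7]) != "BLENDER" {
		return 0, 0, fmt.Errorf("invalid blend file format in zstd archive")
	}
	fmt.Sscanf(string(header[9]), "%d", &major)
	fmt.Sscanf(string(header[10:12]), "%d", &minor)
	return major, minor, nil
}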

// handleGetBlenderVersions returns available Blender versions
func (s *Manager) handleGetBlenderVersions(w http.ResponseWriter, r *http.Request) {
	versions, err := s.FetchBlenderVersions()
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to fetch blender versions: %v", err))
		return
	}

	// Group by major.minor for easier frontend display
	type VersionGroup struct {
		MajorMinor string           `json:"major_minor"`
		Latest     BlenderVersion   `json:"latest"`
		All        []BlenderVersion `json:"all"`
	}

	groups := make(map[string]*VersionGroup)
	for _, v := range versions {
		key := fmt.Sprintf("%d.%d", v.Major, v.Minor)
		if groups[key] == nil {
			groups[key] = &VersionGroup{
				MajorMinor: key,
				Latest:     v, // First one is latest due to sorting
				All:        []BlenderVersion{v},
			}
		} else {
			groups[key].All = append(groups[key].All, v)
		}
	}

	// Convert to slice and sort by version
	var groupedResult []VersionGroup
	for _, g := range groups {
		groupedResult = append(groupedResult, *g)
	}
	sort.Slice(groupedResult, func(i, j int) bool {
		// Parse major.minor for comparison
		var iMaj, iMin, jMaj, jMin int
		fmt.Sscanf(groupedResult[i].MajorMinor, "%d.%d", &iMaj, &iMin)
		fmt.Sscanf(groupedResult[j].MajorMinor, "%d.%d", &jMaj, &jMin)
		if iMaj != jMaj {
			return iMaj > jMaj
		}
		return iMin > jMin
	})

	// Return both flat list and grouped for flexibility
	response := map[string]interface{}{
		"versions": versions,      // Flat list of all versions (newest first)
		"grouped":  groupedResult, // Grouped by major.minor
	}

	s.respondJSON(w, http.StatusOK, response)
}

// handleDownloadBlender serves a cached Blender archive to runners
func (s *Manager) handleDownloadBlender(w http.ResponseWriter, r *http.Request) {
	version := r.URL.Query().Get("version")
	if version == "" {
		s.respondError(w, http.StatusBadRequest, "version parameter required")
		return
	}

	// Parse version string (e.g., "4.2.3" or "4.2")
	var major, minor, patch int
	parts := strings.Split(version, ".")
	if len(parts) < 2 {
		s.respondError(w, http.StatusBadRequest, "invalid version format, expected major.minor or major.minor.patch")
		return
	}

	fmt.Sscanf(parts[0], "%d", &major)
	fmt.Sscanf(parts[1], "%d", &minor)
	if len(parts) >= 3 {
		fmt.Sscanf(parts[2], "%d", &patch)
	}

	// Find the version
	var blenderVersion *BlenderVersion
	if len(parts) >= 3 {
		// Exact patch version requested - find it
		versions, err := s.FetchBlenderVersions()
		if err != nil {
			s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to fetch versions: %v", err))
			return
		}

		for _, v := range versions {
			if v.Major == major && v.Minor == minor && v.Patch == patch {
				blenderVersion = &v
				break
			}
		}

		if blenderVersion == nil {
			s.respondError(w, http.StatusNotFound, fmt.Sprintf("blender version %s not found", version))
			return
		}
	} else {
		// Major.minor only - use helper to get latest patch version
		var err error
		blenderVersion, err = s.GetLatestBlenderForMajorMinor(major, minor)
		if err != nil {
			s.respondError(w, http.StatusNotFound, fmt.Sprintf("blender version %s not found: %v", version, err))
			return
		}
	}

	// Get or download the archive
	archivePath, err := s.GetBlenderArchivePath(blenderVersion)
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to get blender archive: %v", err))
		return
	}

	// Serve the file
	file, err := os.Open(archivePath)
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to open archive: %v", err))
		return
	}
	defer file.Close()

	stat, err := file.Stat()
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to stat archive: %v", err))
		return
	}

	// Filename is now .tar (decompressed)
	tarFilename := blenderVersion.Filename
	tarFilename = strings.TrimSuffix(tarFilename, ".xz")
	tarFilename = strings.TrimSuffix(tarFilename, ".bz2")

	w.Header().Set("Content-Type", "application/x-tar")
	w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", tarFilename))
	w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
	w.Header().Set("X-Blender-Version", blenderVersion.Full)

	io.Copy(w, file)
}

// Unused functions from extraction - keeping for reference but not needed on manager
var _ = extractBlenderArchive
var _ = extractTarXz
var _ = extractTar

// extractBlenderArchive extracts a blender archive (already decompressed to .tar by GetBlenderArchivePath)
func extractBlenderArchive(archivePath string, version *BlenderVersion, destDir string) error {
	file, err := os.Open(archivePath)
	if err != nil {
		return err
	}
	defer file.Close()

	// The archive is already decompressed to .tar by GetBlenderArchivePath
	// Just extract it directly
	if strings.HasSuffix(archivePath, ".tar") {
		tarReader := tar.NewReader(file)
		return extractTar(tarReader, version, destDir)
	}

	// Fallback for any other format (shouldn't happen with current flow)
	if strings.HasSuffix(archivePath, ".tar.xz") {
		return extractTarXz(archivePath, version, destDir)
	} else if strings.HasSuffix(archivePath, ".tar.bz2") {
		bzReader := bzip2.NewReader(file)
		tarReader := tar.NewReader(bzReader)
		return extractTar(tarReader, version, destDir)
	}

	return fmt.Errorf("unsupported archive format: %s", archivePath)
}

// extractTarXz extracts a tar.xz archive using the xz command
func extractTarXz(archivePath string, version *BlenderVersion, destDir string) error {
	versionDir := filepath.Join(destDir, version.Full)
	if err := os.MkdirAll(versionDir, 0755); err != nil {
		return err
	}

	cmd := exec.Command("tar", "-xJf", archivePath, "-C", versionDir, "--strip-components=1")
	output, err := cmd.CombinedOutput()
	if err != nil {
		return fmt.Errorf("tar extraction failed: %v, output: %s", err, string(output))
	}

	return nil
}

// extractTar extracts files from a tar reader
func extractTar(tarReader *tar.Reader, version *BlenderVersion, destDir string) error {
	versionDir := filepath.Join(destDir, version.Full)
	if err := os.MkdirAll(versionDir, 0755); err != nil {
		return err
	}

	stripPrefix := ""

	for {
		header, err := tarReader.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			return err
		}

		if stripPrefix == "" {
			parts := strings.SplitN(header.Name, "/", 2)
			if len(parts) > 0 {
				stripPrefix = parts[0] + "/"
			}
		}

		name := strings.TrimPrefix(header.Name, stripPrefix)
		if name == "" {
			continue
		}

		targetPath := filepath.Join(versionDir, name)

		switch header.Typeflag {
		case tar.TypeDir:
			if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil {
				return err
			}
		case tar.TypeReg:
			if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
				return err
			}
			outFile, err := os.OpenFile(targetPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode))
			if err != nil {
				return err
			}
			if _, err := io.Copy(outFile, tarReader); err != nil {
				outFile.Close()
				return err
			}
			outFile.Close()
		case tar.TypeSymlink:
			if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
				return err
			}
			if err := os.Symlink(header.Linkname, targetPath); err != nil {
				return err
			}
		}
	}

	return nil
}
5213
internal/manager/jobs.go
Normal file
File diff suppressed because it is too large
1307
internal/manager/manager.go
Normal file
File diff suppressed because it is too large
265
internal/manager/metadata.go
Normal file
@@ -0,0 +1,265 @@
package api

import (
	"archive/tar"
	"database/sql"
	_ "embed"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"path/filepath"
	"strings"

	"jiggablend/pkg/executils"
	"jiggablend/pkg/scripts"
	"jiggablend/pkg/types"
)

// handleGetJobMetadata retrieves metadata for a job
func (s *Manager) handleGetJobMetadata(w http.ResponseWriter, r *http.Request) {
	userID, err := getUserID(r)
	if err != nil {
		s.respondError(w, http.StatusUnauthorized, err.Error())
		return
	}

	jobID, err := parseID(r, "id")
	if err != nil {
		s.respondError(w, http.StatusBadRequest, err.Error())
		return
	}

	// Verify job belongs to user
	var jobUserID int64
	var blendMetadataJSON sql.NullString
	err = s.db.With(func(conn *sql.DB) error {
		return conn.QueryRow(
			`SELECT user_id, blend_metadata FROM jobs WHERE id = ?`,
			jobID,
		).Scan(&jobUserID, &blendMetadataJSON)
	})
	if err == sql.ErrNoRows {
		s.respondError(w, http.StatusNotFound, "Job not found")
		return
	}
	if err != nil {
		s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query job: %v", err))
		return
	}
	if jobUserID != userID {
		s.respondError(w, http.StatusForbidden, "Access denied")
		return
	}

	if !blendMetadataJSON.Valid || blendMetadataJSON.String == "" {
		s.respondJSON(w, http.StatusOK, nil)
		return
	}

	var metadata types.BlendMetadata
	if err := json.Unmarshal([]byte(blendMetadataJSON.String), &metadata); err != nil {
		s.respondError(w, http.StatusInternalServerError, "Failed to parse metadata")
		return
	}

	s.respondJSON(w, http.StatusOK, metadata)
}

// extractMetadataFromContext extracts metadata from the blend file in a context archive
// Returns the extracted metadata or an error
func (s *Manager) extractMetadataFromContext(jobID int64) (*types.BlendMetadata, error) {
	contextPath := filepath.Join(s.storage.JobPath(jobID), "context.tar")

	// Check if context exists
	if _, err := os.Stat(contextPath); err != nil {
		return nil, fmt.Errorf("context archive not found: %w", err)
	}

	// Create temporary directory for extraction under storage base path
	tmpDir, err := s.storage.TempDir(fmt.Sprintf("jiggablend-metadata-%d-*", jobID))
	if err != nil {
		return nil, fmt.Errorf("failed to create temporary directory: %w", err)
	}
	defer func() {
		if err := os.RemoveAll(tmpDir); err != nil {
			log.Printf("Warning: Failed to clean up temp directory %s: %v", tmpDir, err)
		}
	}()

	// Extract context archive
	if err := s.extractTar(contextPath, tmpDir); err != nil {
		return nil, fmt.Errorf("failed to extract context: %w", err)
	}

	// Find .blend file in extracted contents
	blendFile := ""
	err = filepath.Walk(tmpDir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".blend") {
			// Check it's not a Blender save file (.blend1, .blend2, etc.)
			lower := strings.ToLower(info.Name())
			idx := strings.LastIndex(lower, ".blend")
			if idx != -1 {
				suffix := lower[idx+len(".blend"):]
				// If there are digits after .blend, it's a save file
				isSaveFile := false
				if len(suffix) > 0 {
					isSaveFile = true
					for _, r := range suffix {
						if r < '0' || r > '9' {
							isSaveFile = false
							break
						}
					}
				}
				if !isSaveFile {
					blendFile = path
					return filepath.SkipAll // Stop walking once we find a blend file
				}
			}
		}
		return nil
	})

	if err != nil {
		return nil, fmt.Errorf("failed to find blend file: %w", err)
	}

	if blendFile == "" {
		return nil, fmt.Errorf("no .blend file found in context - the uploaded context archive must contain at least one .blend file for metadata extraction")
	}

	// Use embedded Python script
	scriptPath := filepath.Join(tmpDir, "extract_metadata.py")
	if err := os.WriteFile(scriptPath, []byte(scripts.ExtractMetadata), 0644); err != nil {
		return nil, fmt.Errorf("failed to create extraction script: %w", err)
	}

	// Make blend file path relative to tmpDir to avoid path resolution issues
	blendFileRel, err := filepath.Rel(tmpDir, blendFile)
	if err != nil {
		return nil, fmt.Errorf("failed to get relative path for blend file: %w", err)
	}

	// Execute Blender with Python script using executils
	result, err := executils.RunCommand(
		"blender",
		[]string{"-b", blendFileRel, "--python", "extract_metadata.py"},
		tmpDir,
		nil, // inherit environment
		jobID,
		nil, // no process tracker needed for metadata extraction
	)

	if err != nil {
		stderrOutput := ""
		stdoutOutput := ""
		if result != nil {
			stderrOutput = strings.TrimSpace(result.Stderr)
			stdoutOutput = strings.TrimSpace(result.Stdout)
		}
		log.Printf("Blender metadata extraction failed for job %d:", jobID)
		if stderrOutput != "" {
			log.Printf("Blender stderr: %s", stderrOutput)
		}
		if stdoutOutput != "" {
			log.Printf("Blender stdout (last 500 chars): %s", truncateString(stdoutOutput, 500))
		}
		if stderrOutput != "" {
			return nil, fmt.Errorf("blender metadata extraction failed: %w (stderr: %s)", err, truncateString(stderrOutput, 200))
		}
		return nil, fmt.Errorf("blender metadata extraction failed: %w", err)
	}

	// Parse output (metadata is printed to stdout)
	metadataJSON := strings.TrimSpace(result.Stdout)
	// Extract JSON from output (Blender may print other stuff)
	jsonStart := strings.Index(metadataJSON, "{")
	jsonEnd := strings.LastIndex(metadataJSON, "}")
	if jsonStart == -1 || jsonEnd == -1 || jsonEnd <= jsonStart {
		return nil, errors.New("failed to extract JSON from Blender output")
	}
	metadataJSON = metadataJSON[jsonStart : jsonEnd+1]

	var metadata types.BlendMetadata
	if err := json.Unmarshal([]byte(metadataJSON), &metadata); err != nil {
		return nil, fmt.Errorf("failed to parse metadata JSON: %w", err)
	}

	log.Printf("Metadata extracted for job %d: frame_start=%d, frame_end=%d", jobID, metadata.FrameStart, metadata.FrameEnd)
	return &metadata, nil
}
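Aside, not part of the diff: the first-'{' / last-'}' slice above is how the function pulls JSON out of noisy Blender console output. A self-contained sketch of the same trick on sample output (the sample text is made up); note it is deliberately simple and would misfire if the surrounding log lines themselves contained braces:

package main

import (
	"fmt"
	"log"
	"strings"
)

func main() {
	out := "Blender 4.2.3\nRead blend: /tmp/x.blend\n{\"frame_start\": 1, \"frame_end\": 250}\nBlender quit"
	start, end := strings.Index(out, "{"), strings.LastIndex(out, "}")
	if start == -1 || end <= start {
		log.Fatal("no JSON found")
	}
	fmt.Println(out[start : end+1]) // {"frame_start": 1, "frame_end": 250}
}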

// extractTar extracts a tar archive to a destination directory
func (s *Manager) extractTar(tarPath, destDir string) error {
	log.Printf("Extracting tar archive: %s -> %s", tarPath, destDir)

	// Ensure destination directory exists
	if err := os.MkdirAll(destDir, 0755); err != nil {
		return fmt.Errorf("failed to create destination directory: %w", err)
	}

	file, err := os.Open(tarPath)
	if err != nil {
		return fmt.Errorf("failed to open archive: %w", err)
	}
	defer file.Close()

	tr := tar.NewReader(file)

	fileCount := 0
	dirCount := 0

	for {
		header, err := tr.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			return fmt.Errorf("failed to read tar header: %w", err)
		}

		// Sanitize path to prevent directory traversal
		target := filepath.Join(destDir, header.Name)

		// Ensure target is within destDir
		cleanTarget := filepath.Clean(target)
		cleanDestDir := filepath.Clean(destDir)
		if !strings.HasPrefix(cleanTarget, cleanDestDir+string(os.PathSeparator)) && cleanTarget != cleanDestDir {
			log.Printf("ERROR: Invalid file path in TAR - target: %s, destDir: %s", cleanTarget, cleanDestDir)
			return fmt.Errorf("invalid file path in archive: %s (target: %s, destDir: %s)", header.Name, cleanTarget, cleanDestDir)
		}

		// Create parent directories
		if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil {
			return fmt.Errorf("failed to create directory: %w", err)
		}

		// Write file
		switch header.Typeflag {
		case tar.TypeReg:
			outFile, err := os.Create(target)
			if err != nil {
				return fmt.Errorf("failed to create file: %w", err)
			}
			_, err = io.Copy(outFile, tr)
			if err != nil {
				outFile.Close()
				return fmt.Errorf("failed to write file: %w", err)
			}
			outFile.Close()
			fileCount++
		case tar.TypeDir:
			dirCount++
		}
	}

	log.Printf("Extraction complete: %d files, %d directories extracted to %s", fileCount, dirCount, destDir)
	return nil
}
2529
internal/manager/runners.go
Normal file
File diff suppressed because it is too large
333
internal/runner/api/jobconn.go
Normal file
@@ -0,0 +1,333 @@
package api

import (
	"fmt"
	"log"
	"strings"
	"sync"
	"time"

	"jiggablend/pkg/types"

	"github.com/gorilla/websocket"
)

// JobConnection wraps a WebSocket connection for job communication.
type JobConnection struct {
	conn          *websocket.Conn
	writeMu       sync.Mutex
	stopPing      chan struct{}
	stopHeartbeat chan struct{}
	isConnected   bool
	connMu        sync.RWMutex
}

// NewJobConnection creates a new job connection wrapper.
func NewJobConnection() *JobConnection {
	return &JobConnection{}
}

// Connect establishes a WebSocket connection for a job (no runnerID needed).
func (j *JobConnection) Connect(managerURL, jobPath, jobToken string) error {
	wsPath := jobPath + "/ws"
	wsURL := strings.Replace(managerURL, "http://", "ws://", 1)
	wsURL = strings.Replace(wsURL, "https://", "wss://", 1)
	wsURL += wsPath

	log.Printf("Connecting to job WebSocket: %s", wsPath)

	dialer := websocket.Dialer{
		HandshakeTimeout: 10 * time.Second,
	}
	conn, _, err := dialer.Dial(wsURL, nil)
	if err != nil {
		return fmt.Errorf("failed to connect job WebSocket: %w", err)
	}

	j.conn = conn

	// Send auth message
	authMsg := map[string]interface{}{
		"type":      "auth",
		"job_token": jobToken,
	}
	if err := conn.WriteJSON(authMsg); err != nil {
		conn.Close()
		return fmt.Errorf("failed to send auth: %w", err)
	}

	// Wait for auth_ok
	conn.SetReadDeadline(time.Now().Add(30 * time.Second))
	var authResp map[string]string
	if err := conn.ReadJSON(&authResp); err != nil {
		conn.Close()
		return fmt.Errorf("failed to read auth response: %w", err)
	}
	if authResp["type"] == "error" {
		conn.Close()
		return fmt.Errorf("auth failed: %s", authResp["message"])
	}
	if authResp["type"] != "auth_ok" {
		conn.Close()
		return fmt.Errorf("unexpected auth response: %s", authResp["type"])
	}

	// Clear read deadline after auth
	conn.SetReadDeadline(time.Time{})

	// Set up ping/pong handler for keepalive
	conn.SetPongHandler(func(string) error {
		conn.SetReadDeadline(time.Now().Add(90 * time.Second))
		return nil
	})

	// Start ping goroutine
	j.stopPing = make(chan struct{})
	j.connMu.Lock()
	j.isConnected = true
	j.connMu.Unlock()
	go j.pingLoop()

	// Start WebSocket heartbeat goroutine
	j.stopHeartbeat = make(chan struct{})
	go j.heartbeatLoop()

	return nil
}
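Aside, not part of the diff: a typical call sequence for this type, as suggested by the API above. The URL, path, token, and task ID are placeholders, and the log-level constant is an assumption about the types package:

// Illustrative only; error handling trimmed.
jc := NewJobConnection()
if err := jc.Connect("https://manager.example.com", "/api/runner/jobs/42", "job-token"); err != nil {
	log.Fatalf("connect: %v", err)
}
defer jc.Close()

jc.Log(1001, types.LogLevelInfo, "starting render") // assumes a LogLevel constant like this exists
jc.Progress(1001, 0.5)
jc.Complete(1001, true, nil)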

// pingLoop sends periodic pings to keep the WebSocket connection alive.
func (j *JobConnection) pingLoop() {
	defer func() {
		if rec := recover(); rec != nil {
			log.Printf("Ping loop panicked: %v", rec)
		}
	}()

	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-j.stopPing:
			return
		case <-ticker.C:
			j.writeMu.Lock()
			if j.conn != nil {
				deadline := time.Now().Add(10 * time.Second)
				if err := j.conn.WriteControl(websocket.PingMessage, []byte{}, deadline); err != nil {
					log.Printf("Failed to send ping, closing connection: %v", err)
					j.connMu.Lock()
					j.isConnected = false
					if j.conn != nil {
						j.conn.Close()
						j.conn = nil
					}
					j.connMu.Unlock()
				}
			}
			j.writeMu.Unlock()
		}
	}
}

// Heartbeat sends a heartbeat message over WebSocket to keep runner online.
func (j *JobConnection) Heartbeat() {
	if j.conn == nil {
		return
	}

	j.writeMu.Lock()
	defer j.writeMu.Unlock()

	msg := map[string]interface{}{
		"type":      "runner_heartbeat",
		"timestamp": time.Now().Unix(),
	}

	if err := j.conn.WriteJSON(msg); err != nil {
		log.Printf("Failed to send WebSocket heartbeat: %v", err)
		// Handle connection failure
		j.connMu.Lock()
		j.isConnected = false
		if j.conn != nil {
			j.conn.Close()
			j.conn = nil
		}
		j.connMu.Unlock()
	}
}

// heartbeatLoop sends periodic heartbeat messages over WebSocket.
func (j *JobConnection) heartbeatLoop() {
	defer func() {
		if rec := recover(); rec != nil {
			log.Printf("WebSocket heartbeat loop panicked: %v", rec)
		}
	}()

	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-j.stopHeartbeat:
			return
		case <-ticker.C:
			j.Heartbeat()
		}
	}
}

// Close closes the WebSocket connection.
func (j *JobConnection) Close() {
	j.connMu.Lock()
	j.isConnected = false
	j.connMu.Unlock()

	// Stop heartbeat goroutine
	if j.stopHeartbeat != nil {
		close(j.stopHeartbeat)
		j.stopHeartbeat = nil
	}

	// Stop ping goroutine
	if j.stopPing != nil {
		close(j.stopPing)
		j.stopPing = nil
	}

	if j.conn != nil {
		j.conn.Close()
		j.conn = nil
	}
}

// IsConnected returns true if the connection is established.
func (j *JobConnection) IsConnected() bool {
	j.connMu.RLock()
	defer j.connMu.RUnlock()
	return j.isConnected && j.conn != nil
}

// Log sends a log entry to the manager.
func (j *JobConnection) Log(taskID int64, level types.LogLevel, message string) {
	if j.conn == nil {
		return
	}

	j.writeMu.Lock()
	defer j.writeMu.Unlock()

	msg := map[string]interface{}{
		"type": "log_entry",
		"data": map[string]interface{}{
			"task_id":   taskID,
			"log_level": string(level),
			"message":   message,
		},
		"timestamp": time.Now().Unix(),
	}
	if err := j.conn.WriteJSON(msg); err != nil {
		log.Printf("Failed to send job log, connection may be broken: %v", err)
		// Close the connection on write error
		j.connMu.Lock()
		j.isConnected = false
		if j.conn != nil {
			j.conn.Close()
			j.conn = nil
		}
		j.connMu.Unlock()
	}
}

// Progress sends a progress update to the manager.
func (j *JobConnection) Progress(taskID int64, progress float64) {
	if j.conn == nil {
		return
	}

	j.writeMu.Lock()
	defer j.writeMu.Unlock()

	msg := map[string]interface{}{
		"type": "progress",
		"data": map[string]interface{}{
			"task_id":  taskID,
			"progress": progress,
		},
		"timestamp": time.Now().Unix(),
	}
	if err := j.conn.WriteJSON(msg); err != nil {
		log.Printf("Failed to send job progress, connection may be broken: %v", err)
		// Close the connection on write error
		j.connMu.Lock()
		j.isConnected = false
		if j.conn != nil {
			j.conn.Close()
			j.conn = nil
		}
		j.connMu.Unlock()
	}
}

// OutputUploaded notifies that an output file was uploaded.
func (j *JobConnection) OutputUploaded(taskID int64, fileName string) {
	if j.conn == nil {
		return
	}

	j.writeMu.Lock()
	defer j.writeMu.Unlock()

	msg := map[string]interface{}{
		"type": "output_uploaded",
		"data": map[string]interface{}{
			"task_id":   taskID,
			"file_name": fileName,
		},
		"timestamp": time.Now().Unix(),
	}
	if err := j.conn.WriteJSON(msg); err != nil {
		log.Printf("Failed to send output uploaded, connection may be broken: %v", err)
		// Close the connection on write error
		j.connMu.Lock()
		j.isConnected = false
		if j.conn != nil {
			j.conn.Close()
			j.conn = nil
		}
		j.connMu.Unlock()
	}
}

// Complete sends task completion to the manager.
func (j *JobConnection) Complete(taskID int64, success bool, errorMsg error) {
	if j.conn == nil {
		log.Printf("Cannot send task complete: WebSocket connection is nil")
		return
	}

	j.writeMu.Lock()
	defer j.writeMu.Unlock()

	msg := map[string]interface{}{
		"type": "task_complete",
		"data": map[string]interface{}{
			"task_id": taskID,
			"success": success,
			"error":   errorMsg,
		},
		"timestamp": time.Now().Unix(),
	}
	if err := j.conn.WriteJSON(msg); err != nil {
		log.Printf("Failed to send task complete, connection may be broken: %v", err)
		// Close the connection on write error
		j.connMu.Lock()
		j.isConnected = false
		if j.conn != nil {
			j.conn.Close()
			j.conn = nil
		}
		j.connMu.Unlock()
	}
}
421
internal/runner/api/manager.go
Normal file
@@ -0,0 +1,421 @@
// Package api provides HTTP and WebSocket communication with the manager server.
package api

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"time"

	"jiggablend/pkg/types"
)

// ManagerClient handles all HTTP communication with the manager server.
type ManagerClient struct {
	baseURL    string
	apiKey     string
	runnerID   int64
	httpClient *http.Client // Standard timeout for quick requests
	longClient *http.Client // No timeout for large file transfers
}

// NewManagerClient creates a new manager client.
func NewManagerClient(baseURL string) *ManagerClient {
	return &ManagerClient{
		baseURL:    strings.TrimSuffix(baseURL, "/"),
		httpClient: &http.Client{Timeout: 30 * time.Second},
		longClient: &http.Client{Timeout: 0}, // No timeout for large transfers
	}
}

// SetCredentials sets the API key and runner ID after registration.
func (m *ManagerClient) SetCredentials(runnerID int64, apiKey string) {
	m.runnerID = runnerID
	m.apiKey = apiKey
}

// GetRunnerID returns the registered runner ID.
func (m *ManagerClient) GetRunnerID() int64 {
	return m.runnerID
}

// GetAPIKey returns the API key.
func (m *ManagerClient) GetAPIKey() string {
	return m.apiKey
}

// GetBaseURL returns the base URL.
func (m *ManagerClient) GetBaseURL() string {
	return m.baseURL
}

// Request performs an authenticated HTTP request with standard timeout.
func (m *ManagerClient) Request(method, path string, body []byte) (*http.Response, error) {
	return m.doRequest(method, path, body, m.httpClient)
}

// RequestLong performs an authenticated HTTP request with no timeout.
// Use for large file uploads/downloads.
func (m *ManagerClient) RequestLong(method, path string, body []byte) (*http.Response, error) {
	return m.doRequest(method, path, body, m.longClient)
}

func (m *ManagerClient) doRequest(method, path string, body []byte, client *http.Client) (*http.Response, error) {
	if m.apiKey == "" {
		return nil, fmt.Errorf("not authenticated")
	}

	fullURL := m.baseURL + path
	req, err := http.NewRequest(method, fullURL, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}

	req.Header.Set("Authorization", "Bearer "+m.apiKey)
	if len(body) > 0 {
		req.Header.Set("Content-Type", "application/json")
	}

	return client.Do(req)
}

// RequestWithToken performs an authenticated HTTP request using a specific token.
func (m *ManagerClient) RequestWithToken(method, path, token string, body []byte) (*http.Response, error) {
	return m.doRequestWithToken(method, path, token, body, m.httpClient)
}

// RequestLongWithToken performs a long-running request with a specific token.
func (m *ManagerClient) RequestLongWithToken(method, path, token string, body []byte) (*http.Response, error) {
	return m.doRequestWithToken(method, path, token, body, m.longClient)
}

func (m *ManagerClient) doRequestWithToken(method, path, token string, body []byte, client *http.Client) (*http.Response, error) {
	fullURL := m.baseURL + path
	req, err := http.NewRequest(method, fullURL, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}

	req.Header.Set("Authorization", "Bearer "+token)
	if len(body) > 0 {
		req.Header.Set("Content-Type", "application/json")
	}

	return client.Do(req)
}

// RegisterRequest is the request body for runner registration.
type RegisterRequest struct {
	Name         string `json:"name"`
	Hostname     string `json:"hostname"`
	Capabilities string `json:"capabilities"`
	APIKey       string `json:"api_key"`
	Fingerprint  string `json:"fingerprint,omitempty"`
}

// RegisterResponse is the response from runner registration.
type RegisterResponse struct {
	ID int64 `json:"id"`
}

// Register registers the runner with the manager.
func (m *ManagerClient) Register(name, hostname string, capabilities map[string]interface{}, registrationToken, fingerprint string) (int64, error) {
	capsJSON, err := json.Marshal(capabilities)
	if err != nil {
		return 0, fmt.Errorf("failed to marshal capabilities: %w", err)
	}

	reqBody := RegisterRequest{
		Name:         name,
		Hostname:     hostname,
		Capabilities: string(capsJSON),
		APIKey:       registrationToken,
	}

	// Only send fingerprint for non-fixed API keys
	if !strings.HasPrefix(registrationToken, "jk_r0_") {
		reqBody.Fingerprint = fingerprint
	}

	body, _ := json.Marshal(reqBody)
	resp, err := m.httpClient.Post(
		m.baseURL+"/api/runner/register",
		"application/json",
		bytes.NewReader(body),
	)
	if err != nil {
		return 0, fmt.Errorf("connection error: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusCreated {
		bodyBytes, _ := io.ReadAll(resp.Body)
		errorBody := string(bodyBytes)

		// Check for token-related errors (should not retry)
		if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusBadRequest {
			errorLower := strings.ToLower(errorBody)
			if strings.Contains(errorLower, "invalid") ||
				strings.Contains(errorLower, "expired") ||
				strings.Contains(errorLower, "already used") ||
				strings.Contains(errorLower, "token") {
				return 0, fmt.Errorf("token error: %s", errorBody)
			}
		}

		return 0, fmt.Errorf("registration failed (status %d): %s", resp.StatusCode, errorBody)
	}

	var result RegisterResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return 0, fmt.Errorf("failed to decode response: %w", err)
	}

	m.runnerID = result.ID
	m.apiKey = registrationToken

	return result.ID, nil
}

// NextJobResponse represents the response from the next-job endpoint.
type NextJobResponse struct {
	JobToken string          `json:"job_token"`
	JobPath  string          `json:"job_path"`
	Task     NextJobTaskInfo `json:"task"`
}

// NextJobTaskInfo contains task information from the next-job response.
type NextJobTaskInfo struct {
	TaskID   int64                `json:"task_id"`
	JobID    int64                `json:"job_id"`
	JobName  string               `json:"job_name"`
	Frame    int                  `json:"frame"`
	TaskType string               `json:"task_type"`
	Metadata *types.BlendMetadata `json:"metadata,omitempty"`
}

// PollNextJob polls the manager for the next available job.
// Returns nil, nil if no job is available.
func (m *ManagerClient) PollNextJob() (*NextJobResponse, error) {
	if m.runnerID == 0 || m.apiKey == "" {
		return nil, fmt.Errorf("runner not authenticated")
	}

	path := fmt.Sprintf("/api/runner/workers/%d/next-job", m.runnerID)
	resp, err := m.Request("GET", path, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to poll for job: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusNoContent {
		return nil, nil // No job available
	}

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("unexpected status %d: %s", resp.StatusCode, string(body))
	}

	var job NextJobResponse
	if err := json.NewDecoder(resp.Body).Decode(&job); err != nil {
		return nil, fmt.Errorf("failed to decode job response: %w", err)
	}

	return &job, nil
}
|
||||
// DownloadContext downloads the job context tar file.
|
||||
func (m *ManagerClient) DownloadContext(contextPath, jobToken string) (io.ReadCloser, error) {
|
||||
resp, err := m.RequestLongWithToken("GET", contextPath, jobToken, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to download context: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
return nil, fmt.Errorf("context download failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return resp.Body, nil
|
||||
}
|
||||
|
||||
// UploadFile uploads a file to the manager.
|
||||
func (m *ManagerClient) UploadFile(uploadPath, jobToken, filePath string) error {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Create multipart form
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
part, err := writer.CreateFormFile("file", filepath.Base(filePath))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create form file: %w", err)
|
||||
}
|
||||
if _, err := io.Copy(part, file); err != nil {
|
||||
return fmt.Errorf("failed to copy file to form: %w", err)
|
||||
}
|
||||
writer.Close()
|
||||
|
||||
fullURL := m.baseURL + uploadPath
|
||||
req, err := http.NewRequest("POST", fullURL, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+jobToken)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
|
||||
resp, err := m.longClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upload file: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("upload failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetJobMetadata retrieves job metadata from the manager.
|
||||
func (m *ManagerClient) GetJobMetadata(jobID int64) (*types.BlendMetadata, error) {
|
||||
path := fmt.Sprintf("/api/runner/jobs/%d/metadata?runner_id=%d", jobID, m.runnerID)
|
||||
resp, err := m.Request("GET", path, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return nil, nil // No metadata found
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("failed to get job metadata: %s", string(body))
|
||||
}
|
||||
|
||||
var metadata types.BlendMetadata
|
||||
if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &metadata, nil
|
||||
}
|
||||
|
||||
// JobFile represents a file associated with a job.
|
||||
type JobFile struct {
|
||||
ID int64 `json:"id"`
|
||||
JobID int64 `json:"job_id"`
|
||||
FileType string `json:"file_type"`
|
||||
FilePath string `json:"file_path"`
|
||||
FileName string `json:"file_name"`
|
||||
FileSize int64 `json:"file_size"`
|
||||
}
|
||||
|
||||
// GetJobFiles retrieves the list of files for a job.
|
||||
func (m *ManagerClient) GetJobFiles(jobID int64) ([]JobFile, error) {
|
||||
path := fmt.Sprintf("/api/runner/jobs/%d/files?runner_id=%d", jobID, m.runnerID)
|
||||
resp, err := m.Request("GET", path, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("failed to get job files: %s", string(body))
|
||||
}
|
||||
|
||||
var files []JobFile
|
||||
if err := json.NewDecoder(resp.Body).Decode(&files); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return files, nil
|
||||
}
|
||||
|
||||
// DownloadFrame downloads a frame file from the manager.
|
||||
func (m *ManagerClient) DownloadFrame(jobID int64, fileName, destPath string) error {
|
||||
encodedFileName := url.PathEscape(fileName)
|
||||
path := fmt.Sprintf("/api/runner/files/%d/%s?runner_id=%d", jobID, encodedFileName, m.runnerID)
|
||||
resp, err := m.RequestLong("GET", path, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("download failed: %s", string(body))
|
||||
}
|
||||
|
||||
file, err := os.Create(destPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
_, err = io.Copy(file, resp.Body)
|
||||
return err
|
||||
}
|
||||
|
||||
// SubmitMetadata submits extracted metadata to the manager.
|
||||
func (m *ManagerClient) SubmitMetadata(jobID int64, metadata types.BlendMetadata) error {
|
||||
metadataJSON, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal metadata: %w", err)
|
||||
}
|
||||
|
||||
path := fmt.Sprintf("/api/runner/jobs/%d/metadata?runner_id=%d", jobID, m.runnerID)
|
||||
fullURL := m.baseURL + path
|
||||
req, err := http.NewRequest("POST", fullURL, bytes.NewReader(metadataJSON))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+m.apiKey)
|
||||
|
||||
resp, err := m.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to submit metadata: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("metadata submission failed: %s", string(body))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DownloadBlender downloads a Blender version from the manager.
|
||||
func (m *ManagerClient) DownloadBlender(version string) (io.ReadCloser, error) {
|
||||
path := fmt.Sprintf("/api/runner/blender/download?version=%s&runner_id=%d", version, m.runnerID)
|
||||
resp, err := m.RequestLong("GET", path, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to download blender from manager: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
return nil, fmt.Errorf("failed to download blender: status %d, body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return resp.Body, nil
|
||||
}
|
||||
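// --- Editor's illustrative sketch (not part of this commit) ---
// A minimal runner loop over the client above. The constructor name
// NewManagerClient is an assumption (it sits in a portion of the diff not
// shown here); Register, PollNextJob, and DownloadContext are the real
// methods defined above.
package main

import (
	"log"
	"time"

	"jiggablend/internal/runner/api"
)

func main() {
	client := api.NewManagerClient("http://manager:8080") // hypothetical constructor
	if _, err := client.Register("runner-1", "host-1", map[string]interface{}{}, "registration-token", ""); err != nil {
		log.Fatalf("register: %v", err)
	}
	for {
		job, err := client.PollNextJob()
		if err != nil {
			log.Printf("poll: %v", err)
		}
		if job == nil {
			time.Sleep(5 * time.Second) // a 204 No Content response maps to job == nil
			continue
		}
		ctx, err := client.DownloadContext(job.JobPath, job.JobToken)
		if err != nil {
			log.Printf("context: %v", err)
			continue
		}
		ctx.Close() // a real runner would extract the tar and render job.Task.Frame
	}
}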
87
internal/runner/blender/binary.go
Normal file
@@ -0,0 +1,87 @@
// Package blender handles Blender binary management and execution.
package blender

import (
	"fmt"
	"log"
	"os"
	"path/filepath"

	"jiggablend/internal/runner/api"
	"jiggablend/internal/runner/workspace"
)

// Manager handles Blender binary downloads and management.
type Manager struct {
	manager      *api.ManagerClient
	workspaceDir string
}

// NewManager creates a new Blender manager.
func NewManager(managerClient *api.ManagerClient, workspaceDir string) *Manager {
	return &Manager{
		manager:      managerClient,
		workspaceDir: workspaceDir,
	}
}

// GetBinaryPath returns the path to the Blender binary for a specific version.
// Downloads from manager and extracts if not already present.
func (m *Manager) GetBinaryPath(version string) (string, error) {
	blenderDir := filepath.Join(m.workspaceDir, "blender-versions")
	if err := os.MkdirAll(blenderDir, 0755); err != nil {
		return "", fmt.Errorf("failed to create blender directory: %w", err)
	}

	// Check if already installed - look for version folder first
	versionDir := filepath.Join(blenderDir, version)
	binaryPath := filepath.Join(versionDir, "blender")

	// Check if version folder exists and contains the binary
	if versionInfo, err := os.Stat(versionDir); err == nil && versionInfo.IsDir() {
		// Version folder exists, check if binary is present
		if binaryInfo, err := os.Stat(binaryPath); err == nil {
			// Verify it's actually a file (not a directory)
			if !binaryInfo.IsDir() {
				log.Printf("Found existing Blender %s installation at %s", version, binaryPath)
				return binaryPath, nil
			}
		}
		// Version folder exists but binary is missing - might be incomplete installation
		log.Printf("Version folder %s exists but binary not found, will re-download", versionDir)
	}

	// Download from manager
	log.Printf("Downloading Blender %s from manager", version)

	reader, err := m.manager.DownloadBlender(version)
	if err != nil {
		return "", err
	}
	defer reader.Close()

	// Manager serves pre-decompressed .tar files - extract directly
	log.Printf("Extracting Blender %s...", version)
	if err := workspace.ExtractTarStripPrefix(reader, versionDir); err != nil {
		return "", fmt.Errorf("failed to extract blender: %w", err)
	}

	// Verify binary exists
	if _, err := os.Stat(binaryPath); err != nil {
		return "", fmt.Errorf("blender binary not found after extraction")
	}

	log.Printf("Blender %s installed at %s", version, binaryPath)
	return binaryPath, nil
}

// GetBinaryForJob returns the Blender binary path for a job.
// Uses the version from metadata or falls back to system blender.
func (m *Manager) GetBinaryForJob(version string) (string, error) {
	if version == "" {
		return "blender", nil // System blender
	}

	return m.GetBinaryPath(version)
}
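// --- Editor's illustrative sketch (not part of this commit) ---
// Resolving the binary for a job: an empty version string falls back to
// whatever "blender" is on PATH; otherwise the manager's copy is fetched and
// cached under <workspace>/blender-versions/<version>/blender. The version
// literal here is just an example value.
//
//	bin, err := mgr.GetBinaryForJob("4.2")
//	if err != nil {
//		log.Printf("blender unavailable: %v", err)
//	} else {
//		log.Printf("using blender at %s", bin)
//	}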
100
internal/runner/blender/logfilter.go
Normal file
@@ -0,0 +1,100 @@
package blender

import (
	"regexp"
	"strings"

	"jiggablend/pkg/types"
)

// FilterLog checks if a Blender log line should be filtered or downgraded.
// Returns (shouldFilter, logLevel) - if shouldFilter is true, the log should be skipped.
func FilterLog(line string) (shouldFilter bool, logLevel types.LogLevel) {
	trimmed := strings.TrimSpace(line)

	// Filter out empty lines
	if trimmed == "" {
		return true, types.LogLevelInfo
	}

	// Filter out separator lines
	if trimmed == "--------------------------------------------------------------------" ||
		(strings.HasPrefix(trimmed, "-----") && strings.Contains(trimmed, "----")) {
		return true, types.LogLevelInfo
	}

	// Filter out trace headers
	upperLine := strings.ToUpper(trimmed)
	upperOriginal := strings.ToUpper(line)

	if trimmed == "Trace:" ||
		trimmed == "Depth Type Name" ||
		trimmed == "----- ---- ----" ||
		line == "Depth Type Name" ||
		line == "----- ---- ----" ||
		(strings.Contains(upperLine, "DEPTH") && strings.Contains(upperLine, "TYPE") && strings.Contains(upperLine, "NAME")) ||
		(strings.Contains(upperOriginal, "DEPTH") && strings.Contains(upperOriginal, "TYPE") && strings.Contains(upperOriginal, "NAME")) ||
		strings.Contains(line, "Depth Type Name") ||
		strings.Contains(line, "----- ---- ----") ||
		strings.HasPrefix(trimmed, "-----") ||
		regexp.MustCompile(`^[-]+\s+[-]+\s+[-]+$`).MatchString(trimmed) {
		return true, types.LogLevelInfo
	}

	// Completely filter out dependency graph messages (they're just noise)
	dependencyGraphPatterns := []string{
		"Failed to add relation",
		"Could not find op_from",
		"OperationKey",
		"find_node_operation: Failed for",
		"BONE_DONE",
		"component name:",
		"operation code:",
		"rope_ctrl_rot_",
	}

	for _, pattern := range dependencyGraphPatterns {
		if strings.Contains(line, pattern) {
			return true, types.LogLevelInfo
		}
	}

	// Filter out animation system warnings (invalid drivers are common and harmless)
	animationSystemPatterns := []string{
		"BKE_animsys_eval_driver: invalid driver",
		"bke.anim_sys",
		"rotation_quaternion[",
		"constraints[",
		".influence[0]",
		"pose.bones[",
	}

	for _, pattern := range animationSystemPatterns {
		if strings.Contains(line, pattern) {
			return true, types.LogLevelInfo
		}
	}

	// Filter out modifier warnings (common when vertices change)
	modifierPatterns := []string{
		"BKE_modifier_set_error",
		"bke.modifier",
		"Vertices changed from",
		"Modifier:",
	}

	for _, pattern := range modifierPatterns {
		if strings.Contains(line, pattern) {
			return true, types.LogLevelInfo
		}
	}

	// Filter out lines that are just numbers or trace depth indicators
	// Pattern: number, word, word (e.g., "1 Object timer_box_franck")
	if matched, _ := regexp.MatchString(`^\d+\s+\w+\s+\w+`, trimmed); matched {
		return true, types.LogLevelInfo
	}

	return false, types.LogLevelInfo
}
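// --- Editor's illustrative sketch (not part of this commit) ---
// Feeding Blender's output through FilterLog line by line; flagged lines are
// dropped before they reach the job log. The emit callback and the bufio/io
// imports are assumptions; FilterLog and types.LogLevel are real.
package blender

import (
	"bufio"
	"io"

	"jiggablend/pkg/types"
)

func streamLogs(stdout io.Reader, emit func(level types.LogLevel, line string)) {
	scanner := bufio.NewScanner(stdout)
	for scanner.Scan() {
		line := scanner.Text()
		if skip, level := FilterLog(line); !skip {
			emit(level, line)
		}
	}
}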
143
internal/runner/blender/version.go
Normal file
@@ -0,0 +1,143 @@
package blender

import (
	"compress/gzip"
	"fmt"
	"io"
	"os"
	"os/exec"
)

// ParseVersionFromFile parses the Blender version that a .blend file was saved with.
// Returns major and minor version numbers.
func ParseVersionFromFile(blendPath string) (major, minor int, err error) {
	file, err := os.Open(blendPath)
	if err != nil {
		return 0, 0, fmt.Errorf("failed to open blend file: %w", err)
	}
	defer file.Close()

	// Read the first 12 bytes of the blend file header
	// Format: BLENDER-v<major><minor><patch> or BLENDER_v<major><minor><patch>
	// The header is: "BLENDER" (7 bytes) + pointer size (1 byte: '-' for 64-bit, '_' for 32-bit)
	// + endianness (1 byte: 'v' for little-endian, 'V' for big-endian)
	// + version (3 bytes: e.g., "402" for 4.02)
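	// Worked example (editor's note): a file saved by Blender 4.2 begins with
	// the ASCII bytes "BLENDER-v402"; bytes 9-11 hold "402", which the code
	// below parses as major=4, minor=2 (formatted by VersionString as "4.2").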
	header := make([]byte, 12)
	n, err := file.Read(header)
	if err != nil || n < 12 {
		return 0, 0, fmt.Errorf("failed to read blend file header: %w", err)
	}

	// Check for BLENDER magic
	if string(header[:7]) != "BLENDER" {
		// Might be compressed - try to decompress
		file.Seek(0, 0)
		return parseCompressedVersion(file)
	}

	// Parse version from bytes 9-11 (3 digits)
	versionStr := string(header[9:12])

	// Version format changed in Blender 3.0
	// Pre-3.0: "279" = 2.79, "280" = 2.80
	// 3.0+: "300" = 3.0, "402" = 4.02, "410" = 4.10
	if len(versionStr) == 3 {
		// First digit is major version
		fmt.Sscanf(string(versionStr[0]), "%d", &major)
		// Next two digits are minor version
		fmt.Sscanf(versionStr[1:3], "%d", &minor)
	}

	return major, minor, nil
}

// parseCompressedVersion handles gzip and zstd compressed blend files.
func parseCompressedVersion(file *os.File) (major, minor int, err error) {
	magic := make([]byte, 4)
	if _, err := file.Read(magic); err != nil {
		return 0, 0, err
	}
	file.Seek(0, 0)

	if magic[0] == 0x1f && magic[1] == 0x8b {
		// gzip compressed
		gzReader, err := gzip.NewReader(file)
		if err != nil {
			return 0, 0, fmt.Errorf("failed to create gzip reader: %w", err)
		}
		defer gzReader.Close()

		header := make([]byte, 12)
		n, err := gzReader.Read(header)
		if err != nil || n < 12 {
			return 0, 0, fmt.Errorf("failed to read compressed blend header: %w", err)
		}

		if string(header[:7]) != "BLENDER" {
			return 0, 0, fmt.Errorf("invalid blend file format")
		}

		versionStr := string(header[9:12])
		if len(versionStr) == 3 {
			fmt.Sscanf(string(versionStr[0]), "%d", &major)
			fmt.Sscanf(versionStr[1:3], "%d", &minor)
		}

		return major, minor, nil
	}

	// Check for zstd magic (Blender 3.0+): 0x28 0xB5 0x2F 0xFD
	if magic[0] == 0x28 && magic[1] == 0xb5 && magic[2] == 0x2f && magic[3] == 0xfd {
		return parseZstdVersion(file)
	}

	return 0, 0, fmt.Errorf("unknown blend file format")
}

// parseZstdVersion handles zstd-compressed blend files (Blender 3.0+).
// Uses zstd command line tool since Go doesn't have native zstd support.
func parseZstdVersion(file *os.File) (major, minor int, err error) {
	file.Seek(0, 0)

	cmd := exec.Command("zstd", "-d", "-c")
	cmd.Stdin = file

	stdout, err := cmd.StdoutPipe()
	if err != nil {
		return 0, 0, fmt.Errorf("failed to create zstd stdout pipe: %w", err)
	}

	if err := cmd.Start(); err != nil {
		return 0, 0, fmt.Errorf("failed to start zstd decompression: %w", err)
	}

	// Read just the header (12 bytes)
	header := make([]byte, 12)
	n, readErr := io.ReadFull(stdout, header)

	// Kill the process early - we only need the header
	cmd.Process.Kill()
	cmd.Wait()

	if readErr != nil || n < 12 {
		return 0, 0, fmt.Errorf("failed to read zstd compressed blend header: %v", readErr)
	}

	if string(header[:7]) != "BLENDER" {
		return 0, 0, fmt.Errorf("invalid blend file format in zstd archive")
	}

	versionStr := string(header[9:12])
	if len(versionStr) == 3 {
		fmt.Sscanf(string(versionStr[0]), "%d", &major)
		fmt.Sscanf(versionStr[1:3], "%d", &minor)
	}

	return major, minor, nil
}

// VersionString returns a formatted version string like "4.2".
func VersionString(major, minor int) string {
	return fmt.Sprintf("%d.%d", major, minor)
}
File diff suppressed because it is too large
71
internal/runner/encoding/encoder.go
Normal file
@@ -0,0 +1,71 @@
// Package encoding handles video encoding with software encoders.
package encoding

import (
	"os/exec"
)

// Encoder represents a video encoder.
type Encoder interface {
	Name() string
	Codec() string
	Available() bool
	BuildCommand(config *EncodeConfig) *exec.Cmd
}

// EncodeConfig holds configuration for video encoding.
type EncodeConfig struct {
	InputPattern string  // Input file pattern (e.g., "frame_%04d.exr")
	OutputPath   string  // Output file path
	StartFrame   int     // Starting frame number
	FrameRate    float64 // Frame rate
	WorkDir      string  // Working directory
	UseAlpha     bool    // Whether to preserve alpha channel
	TwoPass      bool    // Whether to use 2-pass encoding
	SourceFormat string  // Source format: "exr" or "png" (defaults to "exr")
	PreserveHDR  bool    // Whether to preserve HDR range for EXR (uses HLG with bt709 primaries)
}

// Selector selects the software encoder.
type Selector struct {
	h264Encoders []Encoder
	av1Encoders  []Encoder
	vp9Encoders  []Encoder
}

// NewSelector creates a new encoder selector with software encoders.
func NewSelector() *Selector {
	s := &Selector{}
	s.detectEncoders()
	return s
}

func (s *Selector) detectEncoders() {
	// Use software encoding only - reliable and avoids hardware-specific colorspace issues
	s.h264Encoders = []Encoder{
		&SoftwareEncoder{codec: "libx264"},
	}

	s.av1Encoders = []Encoder{
		&SoftwareEncoder{codec: "libaom-av1"},
	}

	s.vp9Encoders = []Encoder{
		&SoftwareEncoder{codec: "libvpx-vp9"},
	}
}

// SelectH264 returns the software H.264 encoder.
func (s *Selector) SelectH264() Encoder {
	return &SoftwareEncoder{codec: "libx264"}
}

// SelectAV1 returns the software AV1 encoder.
func (s *Selector) SelectAV1() Encoder {
	return &SoftwareEncoder{codec: "libaom-av1"}
}

// SelectVP9 returns the software VP9 encoder.
func (s *Selector) SelectVP9() Encoder {
	return &SoftwareEncoder{codec: "libvpx-vp9"}
}
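// --- Editor's illustrative sketch (not part of this commit) ---
// Building and running a single-pass H.264 encode of an EXR sequence with the
// types above (within package encoding). Paths and frame values are example
// inputs.
func encodeExample() error {
	enc := NewSelector().SelectH264()
	cmd := enc.BuildCommand(&EncodeConfig{
		InputPattern: "frame_%04d.exr",
		OutputPath:   "out.mp4",
		StartFrame:   1,
		FrameRate:    24.0,
		WorkDir:      "/tmp/render",
		SourceFormat: "exr", // EXR input triggers the linear-to-sRGB zscale filter
	})
	return cmd.Run()
}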
270
internal/runner/encoding/encoders.go
Normal file
@@ -0,0 +1,270 @@
package encoding

import (
	"fmt"
	"log"
	"os/exec"
	"strconv"
	"strings"
)

const (
	// CRFH264 is the Constant Rate Factor for H.264 encoding (lower = higher quality, range 0-51)
	CRFH264 = 15
	// CRFAV1 is the Constant Rate Factor for AV1 encoding (lower = higher quality, range 0-63)
	CRFAV1 = 30
	// CRFVP9 is the Constant Rate Factor for VP9 encoding (lower = higher quality, range 0-63)
	CRFVP9 = 30
)

// tonemapFilter returns the zscale filter chain for EXR input when the HDR
// range is preserved. The chain runs in three steps:
//   Step 1: format=gbrpf32le ensures the input is planar linear RGB.
//   Step 2: zscale converts the transfer function from linear (8) to sRGB (13)
//           with bt709 primaries/matrix; going through sRGB first avoids the
//           red tint of a direct linear-to-HLG conversion.
//   Step 3: a second zscale converts sRGB to HLG (arib-std-b67, 18) with
//           bt2020 primaries/matrix (9), preserving the HDR range while
//           keeping the color appearance close to the PNG output.
func tonemapFilter(useAlpha bool) string {
	// zscale numeric values used below:
	//   primaries: 1=bt709 (matches PNG), 9=bt2020
	//   matrix:    1=bt709, 9=bt2020nc, 0=gbr (RGB input)
	//   transfer:  8=linear, 13=sRGB, 18=arib-std-b67 (HLG)
	filter := "format=gbrpf32le,zscale=transferin=8:transfer=13:primariesin=1:primaries=1:matrixin=0:matrix=1:rangein=full:range=full,zscale=transferin=13:transfer=18:primariesin=1:primaries=9:matrixin=1:matrix=9:rangein=full:range=full"
	if useAlpha {
		return filter + ",format=yuva420p10le"
	}
	return filter + ",format=yuv420p10le"
}

// SoftwareEncoder implements software encoding (libx264, libaom-av1, libvpx-vp9).
type SoftwareEncoder struct {
	codec string
}

func (e *SoftwareEncoder) Name() string  { return "software" }
func (e *SoftwareEncoder) Codec() string { return e.codec }

func (e *SoftwareEncoder) Available() bool {
	return true // Software encoding is always available
}

func (e *SoftwareEncoder) BuildCommand(config *EncodeConfig) *exec.Cmd {
	// Use HDR pixel formats for EXR, SDR for PNG
	var pixFmt string
	var colorPrimaries, colorTrc, colorspace string
	if config.SourceFormat == "png" {
		// PNG: SDR format
		pixFmt = "yuv420p"
		if config.UseAlpha {
			pixFmt = "yuva420p"
		}
		colorPrimaries = "bt709"
		colorTrc = "bt709"
		colorspace = "bt709"
	} else {
		// EXR: Use HDR encoding if PreserveHDR is true, otherwise SDR (like PNG)
		if config.PreserveHDR {
			// HDR: Use HLG transfer with bt709 primaries to preserve HDR range while matching PNG color
			pixFmt = "yuv420p10le" // 10-bit to preserve HDR range
			if config.UseAlpha {
				pixFmt = "yuva420p10le"
			}
			colorPrimaries = "bt709"  // bt709 primaries to match PNG color appearance
			colorTrc = "arib-std-b67" // HLG transfer function - preserves HDR range, works on SDR displays
			colorspace = "bt709"      // bt709 colorspace to match PNG
		} else {
			// SDR: Treat as SDR (like PNG) - encode as bt709
			pixFmt = "yuv420p"
			if config.UseAlpha {
				pixFmt = "yuva420p"
			}
			colorPrimaries = "bt709"
			colorTrc = "bt709"
			colorspace = "bt709"
		}
	}

	var codecArgs []string
	switch e.codec {
	case "libaom-av1":
		codecArgs = []string{"-crf", strconv.Itoa(CRFAV1), "-b:v", "0", "-tiles", "2x2", "-g", "240"}
	case "libvpx-vp9":
		// VP9 supports alpha and HDR, use good quality settings
		codecArgs = []string{"-crf", strconv.Itoa(CRFVP9), "-b:v", "0", "-row-mt", "1", "-g", "240"}
	default:
		// H.264: Use High 10 profile for HDR EXR (10-bit), High profile for SDR
		if config.SourceFormat != "png" && config.PreserveHDR {
			codecArgs = []string{"-preset", "veryslow", "-crf", strconv.Itoa(CRFH264), "-profile:v", "high10", "-level", "5.2", "-tune", "film", "-keyint_min", "24", "-g", "240", "-bf", "2", "-refs", "4"}
		} else {
			codecArgs = []string{"-preset", "veryslow", "-crf", strconv.Itoa(CRFH264), "-profile:v", "high", "-level", "5.2", "-tune", "film", "-keyint_min", "24", "-g", "240", "-bf", "2", "-refs", "4"}
		}
	}

	args := []string{
		"-y",
		"-f", "image2",
		"-start_number", fmt.Sprintf("%d", config.StartFrame),
		"-framerate", fmt.Sprintf("%.2f", config.FrameRate),
		"-i", config.InputPattern,
		"-c:v", e.codec,
		"-pix_fmt", pixFmt,
		"-r", fmt.Sprintf("%.2f", config.FrameRate),
		"-color_primaries", colorPrimaries,
		"-color_trc", colorTrc,
		"-colorspace", colorspace,
		"-color_range", "tv",
	}

	// Add video filter for EXR: convert linear RGB based on HDR setting
	// PNG doesn't need any filter as it's already in sRGB
	if config.SourceFormat != "png" {
		var vf string
		if config.PreserveHDR {
			// HDR: Convert linear RGB -> sRGB -> HLG with bt709 primaries
			// This preserves HDR range while matching PNG color appearance
			vf = "format=gbrpf32le,zscale=transferin=8:transfer=13:primariesin=1:primaries=1:matrixin=0:matrix=1:rangein=full:range=full,zscale=transferin=13:transfer=18:primariesin=1:primaries=1:matrixin=1:matrix=1:rangein=full:range=full"
			if config.UseAlpha {
				vf += ",format=yuva420p10le"
			} else {
				vf += ",format=yuv420p10le"
			}
		} else {
			// SDR: Convert linear RGB (EXR) to sRGB (bt709) - simple conversion like Krita does
			// zscale: linear (8) -> sRGB (13) with bt709 primaries/matrix
			vf = "format=gbrpf32le,zscale=transferin=8:transfer=13:primariesin=1:primaries=1:matrixin=0:matrix=1:rangein=full:range=full"
			if config.UseAlpha {
				vf += ",format=yuva420p"
			} else {
				vf += ",format=yuv420p"
			}
		}
		args = append(args, "-vf", vf)
	}
	args = append(args, codecArgs...)

	if config.TwoPass {
		// For 2-pass, this builds pass 2 command
		args = append(args, "-pass", "2")
	}

	args = append(args, config.OutputPath)

	if config.TwoPass {
		log.Printf("Build Software Pass 2 command: ffmpeg %s", strings.Join(args, " "))
	} else {
		log.Printf("Build Software command: ffmpeg %s", strings.Join(args, " "))
	}
	cmd := exec.Command("ffmpeg", args...)
	cmd.Dir = config.WorkDir
	return cmd
}

// BuildPass1Command builds the first pass command for 2-pass encoding.
func (e *SoftwareEncoder) BuildPass1Command(config *EncodeConfig) *exec.Cmd {
	// Use HDR pixel formats for EXR, SDR for PNG
	var pixFmt string
	var colorPrimaries, colorTrc, colorspace string
	if config.SourceFormat == "png" {
		// PNG: SDR format
		pixFmt = "yuv420p"
		if config.UseAlpha {
			pixFmt = "yuva420p"
		}
		colorPrimaries = "bt709"
		colorTrc = "bt709"
		colorspace = "bt709"
	} else {
		// EXR: Use HDR encoding if PreserveHDR is true, otherwise SDR (like PNG)
		if config.PreserveHDR {
			// HDR: Use HLG transfer with bt709 primaries to preserve HDR range while matching PNG color
			pixFmt = "yuv420p10le" // 10-bit to preserve HDR range
			if config.UseAlpha {
				pixFmt = "yuva420p10le"
			}
			colorPrimaries = "bt709"  // bt709 primaries to match PNG color appearance
			colorTrc = "arib-std-b67" // HLG transfer function - preserves HDR range, works on SDR displays
			colorspace = "bt709"      // bt709 colorspace to match PNG
		} else {
			// SDR: Treat as SDR (like PNG) - encode as bt709
			pixFmt = "yuv420p"
			if config.UseAlpha {
				pixFmt = "yuva420p"
			}
			colorPrimaries = "bt709"
			colorTrc = "bt709"
			colorspace = "bt709"
		}
	}

	var codecArgs []string
	switch e.codec {
	case "libaom-av1":
		codecArgs = []string{"-crf", strconv.Itoa(CRFAV1), "-b:v", "0", "-tiles", "2x2", "-g", "240"}
	case "libvpx-vp9":
		// VP9 supports alpha and HDR, use good quality settings
		codecArgs = []string{"-crf", strconv.Itoa(CRFVP9), "-b:v", "0", "-row-mt", "1", "-g", "240"}
	default:
		// H.264: Use High 10 profile for HDR EXR (10-bit), High profile for SDR
		if config.SourceFormat != "png" && config.PreserveHDR {
			codecArgs = []string{"-preset", "veryslow", "-crf", strconv.Itoa(CRFH264), "-profile:v", "high10", "-level", "5.2", "-tune", "film", "-keyint_min", "24", "-g", "240", "-bf", "2", "-refs", "4"}
		} else {
			codecArgs = []string{"-preset", "veryslow", "-crf", strconv.Itoa(CRFH264), "-profile:v", "high", "-level", "5.2", "-tune", "film", "-keyint_min", "24", "-g", "240", "-bf", "2", "-refs", "4"}
		}
	}

	args := []string{
		"-y",
		"-f", "image2",
		"-start_number", fmt.Sprintf("%d", config.StartFrame),
		"-framerate", fmt.Sprintf("%.2f", config.FrameRate),
		"-i", config.InputPattern,
		"-c:v", e.codec,
		"-pix_fmt", pixFmt,
		"-r", fmt.Sprintf("%.2f", config.FrameRate),
		"-color_primaries", colorPrimaries,
		"-color_trc", colorTrc,
		"-colorspace", colorspace,
		"-color_range", "tv",
	}

	// Add video filter for EXR: convert linear RGB based on HDR setting
	// PNG doesn't need any filter as it's already in sRGB
	if config.SourceFormat != "png" {
		var vf string
		if config.PreserveHDR {
			// HDR: Convert linear RGB -> sRGB -> HLG with bt709 primaries
			// This preserves HDR range while matching PNG color appearance
			vf = "format=gbrpf32le,zscale=transferin=8:transfer=13:primariesin=1:primaries=1:matrixin=0:matrix=1:rangein=full:range=full,zscale=transferin=13:transfer=18:primariesin=1:primaries=1:matrixin=1:matrix=1:rangein=full:range=full"
			if config.UseAlpha {
				vf += ",format=yuva420p10le"
			} else {
				vf += ",format=yuv420p10le"
			}
		} else {
			// SDR: Convert linear RGB (EXR) to sRGB (bt709) - simple conversion like Krita does
			// zscale: linear (8) -> sRGB (13) with bt709 primaries/matrix
			vf = "format=gbrpf32le,zscale=transferin=8:transfer=13:primariesin=1:primaries=1:matrixin=0:matrix=1:rangein=full:range=full"
			if config.UseAlpha {
				vf += ",format=yuva420p"
			} else {
				vf += ",format=yuv420p"
			}
		}
		args = append(args, "-vf", vf)
	}

	args = append(args, codecArgs...)
	args = append(args, "-pass", "1", "-f", "null", "/dev/null")

	log.Printf("Build Software Pass 1 command: ffmpeg %s", strings.Join(args, " "))
	cmd := exec.Command("ffmpeg", args...)
	cmd.Dir = config.WorkDir
	return cmd
}
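// --- Editor's illustrative sketch (not part of this commit) ---
// The intended two-pass flow within package encoding: BuildPass1Command
// writes stats only (its output goes to -f null /dev/null), then BuildCommand
// with TwoPass set appends "-pass 2" and writes the real output.
func encodeTwoPass(enc *SoftwareEncoder, cfg *EncodeConfig) error {
	cfg.TwoPass = true
	if err := enc.BuildPass1Command(cfg).Run(); err != nil {
		return fmt.Errorf("pass 1: %w", err)
	}
	return enc.BuildCommand(cfg).Run()
}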
980
internal/runner/encoding/encoders_test.go
Normal file
@@ -0,0 +1,980 @@
package encoding
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSoftwareEncoder_BuildCommand_H264_EXR(t *testing.T) {
|
||||
encoder := &SoftwareEncoder{codec: "libx264"}
|
||||
config := &EncodeConfig{
|
||||
InputPattern: "frame_%04d.exr",
|
||||
OutputPath: "output.mp4",
|
||||
StartFrame: 1,
|
||||
FrameRate: 24.0,
|
||||
WorkDir: "/tmp",
|
||||
UseAlpha: false,
|
||||
TwoPass: true,
|
||||
SourceFormat: "exr",
|
||||
}
|
||||
|
||||
cmd := encoder.BuildCommand(config)
|
||||
if cmd == nil {
|
||||
t.Fatal("BuildCommand returned nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(cmd.Path, "ffmpeg") {
|
||||
t.Errorf("Expected command path to contain 'ffmpeg', got '%s'", cmd.Path)
|
||||
}
|
||||
|
||||
if cmd.Dir != "/tmp" {
|
||||
t.Errorf("Expected work dir '/tmp', got '%s'", cmd.Dir)
|
||||
}
|
||||
|
||||
args := cmd.Args[1:] // Skip "ffmpeg"
|
||||
argsStr := strings.Join(args, " ")
|
||||
|
||||
// Check required arguments
|
||||
checks := []struct {
|
||||
name string
|
||||
expected string
|
||||
}{
|
||||
{"-y flag", "-y"},
|
||||
{"image2 format", "-f image2"},
|
||||
{"start number", "-start_number 1"},
|
||||
{"framerate", "-framerate 24.00"},
|
||||
{"input pattern", "-i frame_%04d.exr"},
|
||||
{"codec", "-c:v libx264"},
|
||||
{"pixel format", "-pix_fmt yuv420p"}, // EXR now treated as SDR (like PNG)
|
||||
{"frame rate", "-r 24.00"},
|
||||
{"color primaries", "-color_primaries bt709"}, // EXR now uses bt709 (SDR)
|
||||
{"color trc", "-color_trc bt709"}, // EXR now uses bt709 (SDR)
|
||||
{"colorspace", "-colorspace bt709"},
|
||||
{"color range", "-color_range tv"},
|
||||
{"video filter", "-vf"},
|
||||
{"preset", "-preset veryslow"},
|
||||
{"crf", "-crf 15"},
|
||||
{"profile", "-profile:v high"}, // EXR now uses high profile (SDR)
|
||||
{"pass 2", "-pass 2"},
|
||||
{"output path", "output.mp4"},
|
||||
}
|
||||
|
||||
for _, check := range checks {
|
||||
if !strings.Contains(argsStr, check.expected) {
|
||||
t.Errorf("Missing expected argument: %s", check.expected)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify filter is present for EXR (linear RGB to sRGB conversion, like Krita does)
|
||||
if !strings.Contains(argsStr, "format=gbrpf32le") {
|
||||
t.Error("Expected format conversion filter for EXR source, but not found")
|
||||
}
|
||||
if !strings.Contains(argsStr, "zscale=transferin=8:transfer=13") {
|
||||
t.Error("Expected linear to sRGB conversion for EXR source, but not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftwareEncoder_BuildCommand_H264_PNG(t *testing.T) {
|
||||
encoder := &SoftwareEncoder{codec: "libx264"}
|
||||
config := &EncodeConfig{
|
||||
InputPattern: "frame_%04d.png",
|
||||
OutputPath: "output.mp4",
|
||||
StartFrame: 1,
|
||||
FrameRate: 24.0,
|
||||
WorkDir: "/tmp",
|
||||
UseAlpha: false,
|
||||
TwoPass: true,
|
||||
SourceFormat: "png",
|
||||
}
|
||||
|
||||
cmd := encoder.BuildCommand(config)
|
||||
args := cmd.Args[1:]
|
||||
argsStr := strings.Join(args, " ")
|
||||
|
||||
// PNG should NOT have video filter
|
||||
if strings.Contains(argsStr, "-vf") {
|
||||
t.Error("PNG source should not have video filter, but -vf was found")
|
||||
}
|
||||
|
||||
// Should still have all other required args
|
||||
if !strings.Contains(argsStr, "-c:v libx264") {
|
||||
t.Error("Missing codec argument")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftwareEncoder_BuildCommand_AV1_WithAlpha(t *testing.T) {
|
||||
encoder := &SoftwareEncoder{codec: "libaom-av1"}
|
||||
config := &EncodeConfig{
|
||||
InputPattern: "frame_%04d.exr",
|
||||
OutputPath: "output.mp4",
|
||||
StartFrame: 100,
|
||||
FrameRate: 30.0,
|
||||
WorkDir: "/tmp",
|
||||
UseAlpha: true,
|
||||
TwoPass: true,
|
||||
SourceFormat: "exr",
|
||||
}
|
||||
|
||||
cmd := encoder.BuildCommand(config)
|
||||
args := cmd.Args[1:]
|
||||
argsStr := strings.Join(args, " ")
|
||||
|
||||
// Check alpha-specific settings
|
||||
if !strings.Contains(argsStr, "-pix_fmt yuva420p") {
|
||||
t.Error("Expected yuva420p pixel format for alpha, but not found")
|
||||
}
|
||||
|
||||
// Check AV1-specific arguments
|
||||
av1Checks := []string{
|
||||
"-c:v libaom-av1",
|
||||
"-crf 30",
|
||||
"-b:v 0",
|
||||
"-tiles 2x2",
|
||||
"-g 240",
|
||||
}
|
||||
|
||||
for _, check := range av1Checks {
|
||||
if !strings.Contains(argsStr, check) {
|
||||
t.Errorf("Missing AV1 argument: %s", check)
|
||||
}
|
||||
}
|
||||
|
||||
// Check tonemap filter includes alpha format
|
||||
if !strings.Contains(argsStr, "format=yuva420p") {
|
||||
t.Error("Expected tonemap filter to output yuva420p for alpha, but not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftwareEncoder_BuildCommand_VP9(t *testing.T) {
|
||||
encoder := &SoftwareEncoder{codec: "libvpx-vp9"}
|
||||
config := &EncodeConfig{
|
||||
InputPattern: "frame_%04d.exr",
|
||||
OutputPath: "output.webm",
|
||||
StartFrame: 1,
|
||||
FrameRate: 24.0,
|
||||
WorkDir: "/tmp",
|
||||
UseAlpha: true,
|
||||
TwoPass: true,
|
||||
SourceFormat: "exr",
|
||||
}
|
||||
|
||||
cmd := encoder.BuildCommand(config)
|
||||
args := cmd.Args[1:]
|
||||
argsStr := strings.Join(args, " ")
|
||||
|
||||
// Check VP9-specific arguments
|
||||
vp9Checks := []string{
|
||||
"-c:v libvpx-vp9",
|
||||
"-crf 30",
|
||||
"-b:v 0",
|
||||
"-row-mt 1",
|
||||
"-g 240",
|
||||
}
|
||||
|
||||
for _, check := range vp9Checks {
|
||||
if !strings.Contains(argsStr, check) {
|
||||
t.Errorf("Missing VP9 argument: %s", check)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftwareEncoder_BuildPass1Command(t *testing.T) {
|
||||
encoder := &SoftwareEncoder{codec: "libx264"}
|
||||
config := &EncodeConfig{
|
||||
InputPattern: "frame_%04d.exr",
|
||||
OutputPath: "output.mp4",
|
||||
StartFrame: 1,
|
||||
FrameRate: 24.0,
|
||||
WorkDir: "/tmp",
|
||||
UseAlpha: false,
|
||||
TwoPass: true,
|
||||
SourceFormat: "exr",
|
||||
}
|
||||
|
||||
cmd := encoder.BuildPass1Command(config)
|
||||
args := cmd.Args[1:]
|
||||
argsStr := strings.Join(args, " ")
|
||||
|
||||
// Pass 1 should have -pass 1 and output to null
|
||||
if !strings.Contains(argsStr, "-pass 1") {
|
||||
t.Error("Pass 1 command should include '-pass 1'")
|
||||
}
|
||||
|
||||
if !strings.Contains(argsStr, "-f null") {
|
||||
t.Error("Pass 1 command should include '-f null'")
|
||||
}
|
||||
|
||||
if !strings.Contains(argsStr, "/dev/null") {
|
||||
t.Error("Pass 1 command should output to /dev/null")
|
||||
}
|
||||
|
||||
// Should NOT have output path
|
||||
if strings.Contains(argsStr, "output.mp4") {
|
||||
t.Error("Pass 1 command should not include output path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftwareEncoder_BuildPass1Command_AV1(t *testing.T) {
|
||||
encoder := &SoftwareEncoder{codec: "libaom-av1"}
|
||||
config := &EncodeConfig{
|
||||
InputPattern: "frame_%04d.exr",
|
||||
OutputPath: "output.mp4",
|
||||
StartFrame: 1,
|
||||
FrameRate: 24.0,
|
||||
WorkDir: "/tmp",
|
||||
UseAlpha: false,
|
||||
TwoPass: true,
|
||||
SourceFormat: "exr",
|
||||
}
|
||||
|
||||
cmd := encoder.BuildPass1Command(config)
|
||||
args := cmd.Args[1:]
|
||||
argsStr := strings.Join(args, " ")
|
||||
|
||||
// Pass 1 should have -pass 1 and output to null
|
||||
if !strings.Contains(argsStr, "-pass 1") {
|
||||
t.Error("Pass 1 command should include '-pass 1'")
|
||||
}
|
||||
|
||||
if !strings.Contains(argsStr, "-f null") {
|
||||
t.Error("Pass 1 command should include '-f null'")
|
||||
}
|
||||
|
||||
if !strings.Contains(argsStr, "/dev/null") {
|
||||
t.Error("Pass 1 command should output to /dev/null")
|
||||
}
|
||||
|
||||
// Check AV1-specific arguments in pass 1
|
||||
av1Checks := []string{
|
||||
"-c:v libaom-av1",
|
||||
"-crf 30",
|
||||
"-b:v 0",
|
||||
"-tiles 2x2",
|
||||
"-g 240",
|
||||
}
|
||||
|
||||
for _, check := range av1Checks {
|
||||
if !strings.Contains(argsStr, check) {
|
||||
t.Errorf("Missing AV1 argument in pass 1: %s", check)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftwareEncoder_BuildPass1Command_VP9(t *testing.T) {
|
||||
encoder := &SoftwareEncoder{codec: "libvpx-vp9"}
|
||||
config := &EncodeConfig{
|
||||
InputPattern: "frame_%04d.exr",
|
||||
OutputPath: "output.webm",
|
||||
StartFrame: 1,
|
||||
FrameRate: 24.0,
|
||||
WorkDir: "/tmp",
|
||||
UseAlpha: false,
|
||||
TwoPass: true,
|
||||
SourceFormat: "exr",
|
||||
}
|
||||
|
||||
cmd := encoder.BuildPass1Command(config)
|
||||
args := cmd.Args[1:]
|
||||
argsStr := strings.Join(args, " ")
|
||||
|
||||
// Pass 1 should have -pass 1 and output to null
|
||||
if !strings.Contains(argsStr, "-pass 1") {
|
||||
t.Error("Pass 1 command should include '-pass 1'")
|
||||
}
|
||||
|
||||
if !strings.Contains(argsStr, "-f null") {
|
||||
t.Error("Pass 1 command should include '-f null'")
|
||||
}
|
||||
|
||||
if !strings.Contains(argsStr, "/dev/null") {
|
||||
t.Error("Pass 1 command should output to /dev/null")
|
||||
}
|
||||
|
||||
// Check VP9-specific arguments in pass 1
|
||||
vp9Checks := []string{
|
||||
"-c:v libvpx-vp9",
|
||||
"-crf 30",
|
||||
"-b:v 0",
|
||||
"-row-mt 1",
|
||||
"-g 240",
|
||||
}
|
||||
|
||||
for _, check := range vp9Checks {
|
||||
if !strings.Contains(argsStr, check) {
|
||||
t.Errorf("Missing VP9 argument in pass 1: %s", check)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftwareEncoder_BuildCommand_NoTwoPass(t *testing.T) {
|
||||
encoder := &SoftwareEncoder{codec: "libx264"}
|
||||
config := &EncodeConfig{
|
||||
InputPattern: "frame_%04d.exr",
|
||||
OutputPath: "output.mp4",
|
||||
StartFrame: 1,
|
||||
FrameRate: 24.0,
|
||||
WorkDir: "/tmp",
|
||||
UseAlpha: false,
|
||||
TwoPass: false,
|
||||
SourceFormat: "exr",
|
||||
}
|
||||
|
||||
cmd := encoder.BuildCommand(config)
|
||||
args := cmd.Args[1:]
|
||||
argsStr := strings.Join(args, " ")
|
||||
|
||||
// Should NOT have -pass flag when TwoPass is false
|
||||
if strings.Contains(argsStr, "-pass") {
|
||||
t.Error("Command should not include -pass flag when TwoPass is false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelector_SelectH264(t *testing.T) {
|
||||
selector := NewSelector()
|
||||
encoder := selector.SelectH264()
|
||||
|
||||
if encoder == nil {
|
||||
t.Fatal("SelectH264 returned nil")
|
||||
}
|
||||
|
||||
if encoder.Codec() != "libx264" {
|
||||
t.Errorf("Expected codec 'libx264', got '%s'", encoder.Codec())
|
||||
}
|
||||
|
||||
if encoder.Name() != "software" {
|
||||
t.Errorf("Expected name 'software', got '%s'", encoder.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelector_SelectAV1(t *testing.T) {
|
||||
selector := NewSelector()
|
||||
encoder := selector.SelectAV1()
|
||||
|
||||
if encoder == nil {
|
||||
t.Fatal("SelectAV1 returned nil")
|
||||
}
|
||||
|
||||
if encoder.Codec() != "libaom-av1" {
|
||||
t.Errorf("Expected codec 'libaom-av1', got '%s'", encoder.Codec())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelector_SelectVP9(t *testing.T) {
|
||||
selector := NewSelector()
|
||||
encoder := selector.SelectVP9()
|
||||
|
||||
if encoder == nil {
|
||||
t.Fatal("SelectVP9 returned nil")
|
||||
}
|
||||
|
||||
if encoder.Codec() != "libvpx-vp9" {
|
||||
t.Errorf("Expected codec 'libvpx-vp9', got '%s'", encoder.Codec())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTonemapFilter_WithAlpha(t *testing.T) {
|
||||
filter := tonemapFilter(true)
|
||||
|
||||
// Filter should convert from gbrpf32le to yuva420p10le with proper colorspace conversion
|
||||
if !strings.Contains(filter, "yuva420p10le") {
|
||||
t.Error("Tonemap filter with alpha should output yuva420p10le format for HDR")
|
||||
}
|
||||
|
||||
if !strings.Contains(filter, "gbrpf32le") {
|
||||
t.Error("Tonemap filter should start with gbrpf32le format")
|
||||
}
|
||||
|
||||
// Should use zscale for colorspace conversion from linear RGB to bt2020 YUV
|
||||
if !strings.Contains(filter, "zscale") {
|
||||
t.Error("Tonemap filter should use zscale for colorspace conversion")
|
||||
}
|
||||
|
||||
// Check for HLG transfer function (numeric value 18 or string arib-std-b67)
|
||||
if !strings.Contains(filter, "transfer=18") && !strings.Contains(filter, "transfer=arib-std-b67") {
|
||||
t.Error("Tonemap filter should use HLG transfer function (18 or arib-std-b67)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTonemapFilter_WithoutAlpha(t *testing.T) {
|
||||
filter := tonemapFilter(false)
|
||||
|
||||
// Filter should convert from gbrpf32le to yuv420p10le with proper colorspace conversion
|
||||
if !strings.Contains(filter, "yuv420p10le") {
|
||||
t.Error("Tonemap filter without alpha should output yuv420p10le format for HDR")
|
||||
}
|
||||
|
||||
if strings.Contains(filter, "yuva420p") {
|
||||
t.Error("Tonemap filter without alpha should not output yuva420p format")
|
||||
}
|
||||
|
||||
if !strings.Contains(filter, "gbrpf32le") {
|
||||
t.Error("Tonemap filter should start with gbrpf32le format")
|
||||
}
|
||||
|
||||
// Should use zscale for colorspace conversion from linear RGB to bt2020 YUV
|
||||
if !strings.Contains(filter, "zscale") {
|
||||
t.Error("Tonemap filter should use zscale for colorspace conversion")
|
||||
}
|
||||
|
||||
// Check for HLG transfer function (numeric value 18 or string arib-std-b67)
|
||||
if !strings.Contains(filter, "transfer=18") && !strings.Contains(filter, "transfer=arib-std-b67") {
|
||||
t.Error("Tonemap filter should use HLG transfer function (18 or arib-std-b67)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftwareEncoder_Available(t *testing.T) {
|
||||
encoder := &SoftwareEncoder{codec: "libx264"}
|
||||
if !encoder.Available() {
|
||||
t.Error("Software encoder should always be available")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeConfig_DefaultSourceFormat(t *testing.T) {
|
||||
config := &EncodeConfig{
|
||||
InputPattern: "frame_%04d.exr",
|
||||
OutputPath: "output.mp4",
|
||||
StartFrame: 1,
|
||||
FrameRate: 24.0,
|
||||
WorkDir: "/tmp",
|
||||
UseAlpha: false,
|
||||
TwoPass: false,
|
||||
// SourceFormat not set, should default to empty string (treated as exr)
|
||||
}
|
||||
|
||||
encoder := &SoftwareEncoder{codec: "libx264"}
|
||||
cmd := encoder.BuildCommand(config)
|
||||
args := strings.Join(cmd.Args[1:], " ")
|
||||
|
||||
// Should still have tonemap filter when SourceFormat is empty (defaults to exr behavior)
|
||||
if !strings.Contains(args, "-vf") {
|
||||
t.Error("Empty SourceFormat should default to EXR behavior with tonemap filter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommandOrder(t *testing.T) {
|
||||
encoder := &SoftwareEncoder{codec: "libx264"}
|
||||
config := &EncodeConfig{
|
||||
InputPattern: "frame_%04d.exr",
|
||||
OutputPath: "output.mp4",
|
||||
StartFrame: 1,
|
||||
FrameRate: 24.0,
|
||||
WorkDir: "/tmp",
|
||||
UseAlpha: false,
|
||||
TwoPass: true,
|
||||
SourceFormat: "exr",
|
||||
}
|
||||
|
||||
cmd := encoder.BuildCommand(config)
|
||||
args := cmd.Args[1:]
|
||||
|
||||
// Verify argument order: input should come before codec
|
||||
inputIdx := -1
|
||||
codecIdx := -1
|
||||
vfIdx := -1
|
||||
|
||||
for i, arg := range args {
|
||||
if arg == "-i" && i+1 < len(args) && args[i+1] == "frame_%04d.exr" {
|
||||
inputIdx = i
|
||||
}
|
||||
if arg == "-c:v" && i+1 < len(args) && args[i+1] == "libx264" {
|
||||
codecIdx = i
|
||||
}
|
||||
if arg == "-vf" {
|
||||
vfIdx = i
|
||||
}
|
||||
}
|
||||
|
||||
if inputIdx == -1 {
|
||||
t.Fatal("Input pattern not found in command")
|
||||
}
|
||||
if codecIdx == -1 {
|
||||
t.Fatal("Codec not found in command")
|
||||
}
|
||||
if vfIdx == -1 {
|
||||
t.Fatal("Video filter not found in command")
|
||||
}
|
||||
|
||||
// Input should come before codec
|
||||
if inputIdx >= codecIdx {
|
||||
t.Error("Input pattern should come before codec in command")
|
||||
}
|
||||
|
||||
// Video filter should come after input (order: input -> codec -> colorspace -> filter -> codec args)
|
||||
// In practice, the filter comes after codec and colorspace metadata but before codec-specific args
|
||||
if vfIdx <= inputIdx {
|
||||
t.Error("Video filter should come after input")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommand_ColorspaceMetadata(t *testing.T) {
|
||||
encoder := &SoftwareEncoder{codec: "libx264"}
|
||||
config := &EncodeConfig{
|
||||
InputPattern: "frame_%04d.exr",
|
||||
OutputPath: "output.mp4",
|
||||
StartFrame: 1,
|
||||
FrameRate: 24.0,
|
||||
WorkDir: "/tmp",
|
||||
UseAlpha: false,
|
||||
TwoPass: false,
|
||||
SourceFormat: "exr",
|
||||
PreserveHDR: false, // SDR encoding
|
||||
}
|
||||
|
||||
cmd := encoder.BuildCommand(config)
|
||||
args := cmd.Args[1:]
|
||||
argsStr := strings.Join(args, " ")
|
||||
|
||||
// Verify all SDR colorspace metadata is present for EXR (SDR encoding)
|
||||
colorspaceArgs := []string{
|
||||
"-color_primaries bt709", // EXR uses bt709 (SDR)
|
||||
"-color_trc bt709", // EXR uses bt709 (SDR)
|
||||
"-colorspace bt709",
|
||||
"-color_range tv",
|
||||
}
|
||||
|
||||
for _, arg := range colorspaceArgs {
|
||||
if !strings.Contains(argsStr, arg) {
|
||||
t.Errorf("Missing colorspace metadata: %s", arg)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify SDR pixel format
|
||||
if !strings.Contains(argsStr, "-pix_fmt yuv420p") {
|
||||
t.Error("SDR encoding should use yuv420p pixel format")
|
||||
}
|
||||
|
||||
// Verify H.264 high profile (not high10)
|
||||
if !strings.Contains(argsStr, "-profile:v high") {
|
||||
t.Error("SDR encoding should use high profile")
|
||||
}
|
||||
if strings.Contains(argsStr, "-profile:v high10") {
|
||||
t.Error("SDR encoding should not use high10 profile")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommand_HDR_ColorspaceMetadata(t *testing.T) {
|
||||
encoder := &SoftwareEncoder{codec: "libx264"}
|
||||
config := &EncodeConfig{
|
||||
InputPattern: "frame_%04d.exr",
|
||||
OutputPath: "output.mp4",
|
||||
StartFrame: 1,
|
||||
FrameRate: 24.0,
|
||||
WorkDir: "/tmp",
|
||||
UseAlpha: false,
|
||||
TwoPass: false,
|
||||
SourceFormat: "exr",
|
||||
PreserveHDR: true, // HDR encoding
|
||||
}
|
||||
|
||||
cmd := encoder.BuildCommand(config)
|
||||
args := cmd.Args[1:]
|
||||
argsStr := strings.Join(args, " ")
|
||||
|
||||
// Verify all HDR colorspace metadata is present for EXR (HDR encoding)
|
||||
colorspaceArgs := []string{
|
||||
"-color_primaries bt709", // bt709 primaries to match PNG color appearance
|
||||
"-color_trc arib-std-b67", // HLG transfer function for HDR/SDR compatibility
|
||||
"-colorspace bt709", // bt709 colorspace to match PNG
|
||||
"-color_range tv",
|
||||
}
|
||||
|
||||
for _, arg := range colorspaceArgs {
|
||||
if !strings.Contains(argsStr, arg) {
|
||||
t.Errorf("Missing HDR colorspace metadata: %s", arg)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify HDR pixel format (10-bit)
|
||||
if !strings.Contains(argsStr, "-pix_fmt yuv420p10le") {
|
||||
t.Error("HDR encoding should use yuv420p10le pixel format")
|
||||
}
|
||||
|
||||
// Verify H.264 high10 profile (for 10-bit)
|
||||
if !strings.Contains(argsStr, "-profile:v high10") {
|
||||
t.Error("HDR encoding should use high10 profile")
|
||||
}
|
||||
|
||||
// Verify HDR filter chain (linear -> sRGB -> HLG)
|
||||
if !strings.Contains(argsStr, "-vf") {
|
||||
t.Fatal("HDR encoding should have video filter")
|
||||
}
|
||||
vfIdx := -1
|
||||
for i, arg := range args {
|
||||
if arg == "-vf" && i+1 < len(args) {
|
||||
vfIdx = i + 1
|
||||
break
|
||||
}
|
||||
}
|
||||
if vfIdx == -1 {
|
||||
t.Fatal("Video filter not found")
|
||||
}
|
||||
filter := args[vfIdx]
|
||||
if !strings.Contains(filter, "transfer=18") {
|
||||
t.Error("HDR filter should convert to HLG (transfer=18)")
|
||||
}
|
||||
if !strings.Contains(filter, "yuv420p10le") {
|
||||
t.Error("HDR filter should output yuv420p10le format")
|
||||
}
|
||||
}

// Integration tests using example files
func TestIntegration_Encode_EXR_H264(t *testing.T) {
    if testing.Short() {
        t.Skip("Skipping integration test in short mode")
    }

    // Check if the example file exists
    exampleDir := filepath.Join("..", "..", "..", "examples")
    exrFile := filepath.Join(exampleDir, "frame_0800.exr")
    if _, err := os.Stat(exrFile); os.IsNotExist(err) {
        t.Skipf("Example file not found: %s", exrFile)
    }

    // Get absolute paths
    workspaceRoot, err := filepath.Abs(filepath.Join("..", "..", ".."))
    if err != nil {
        t.Fatalf("Failed to get workspace root: %v", err)
    }
    exampleDirAbs, err := filepath.Abs(exampleDir)
    if err != nil {
        t.Fatalf("Failed to get example directory: %v", err)
    }
    tmpDir := filepath.Join(workspaceRoot, "tmp")
    if err := os.MkdirAll(tmpDir, 0755); err != nil {
        t.Fatalf("Failed to create tmp directory: %v", err)
    }

    encoder := &SoftwareEncoder{codec: "libx264"}
    config := &EncodeConfig{
        InputPattern: filepath.Join(exampleDirAbs, "frame_%04d.exr"),
        OutputPath:   filepath.Join(tmpDir, "test_exr_h264.mp4"),
        StartFrame:   800,
        FrameRate:    24.0,
        WorkDir:      tmpDir,
        UseAlpha:     false,
        TwoPass:      false, // Single pass for faster testing
        SourceFormat: "exr",
    }

    // Build and run command
    cmd := encoder.BuildCommand(config)
    if cmd == nil {
        t.Fatal("BuildCommand returned nil")
    }

    // Capture combined output so failures are diagnosable
    output, err := cmd.CombinedOutput()
    if err != nil {
        t.Errorf("FFmpeg command failed: %v\nCommand output: %s", err, string(output))
        return
    }

    // Verify output file was created
    if _, err := os.Stat(config.OutputPath); os.IsNotExist(err) {
        t.Errorf("Output file was not created: %s\nCommand output: %s", config.OutputPath, string(output))
    } else {
        t.Logf("Successfully created output file: %s", config.OutputPath)
        // Verify file has content
        info, _ := os.Stat(config.OutputPath)
        if info.Size() == 0 {
            t.Errorf("Output file was created but is empty\nCommand output: %s", string(output))
        } else {
            t.Logf("Output file size: %d bytes", info.Size())
        }
    }
}

func TestIntegration_Encode_PNG_H264(t *testing.T) {
    if testing.Short() {
        t.Skip("Skipping integration test in short mode")
    }

    // Check if the example file exists
    exampleDir := filepath.Join("..", "..", "..", "examples")
    pngFile := filepath.Join(exampleDir, "frame_0800.png")
    if _, err := os.Stat(pngFile); os.IsNotExist(err) {
        t.Skipf("Example file not found: %s", pngFile)
    }

    // Get absolute paths
    workspaceRoot, err := filepath.Abs(filepath.Join("..", "..", ".."))
    if err != nil {
        t.Fatalf("Failed to get workspace root: %v", err)
    }
    exampleDirAbs, err := filepath.Abs(exampleDir)
    if err != nil {
        t.Fatalf("Failed to get example directory: %v", err)
    }
    tmpDir := filepath.Join(workspaceRoot, "tmp")
    if err := os.MkdirAll(tmpDir, 0755); err != nil {
        t.Fatalf("Failed to create tmp directory: %v", err)
    }

    encoder := &SoftwareEncoder{codec: "libx264"}
    config := &EncodeConfig{
        InputPattern: filepath.Join(exampleDirAbs, "frame_%04d.png"),
        OutputPath:   filepath.Join(tmpDir, "test_png_h264.mp4"),
        StartFrame:   800,
        FrameRate:    24.0,
        WorkDir:      tmpDir,
        UseAlpha:     false,
        TwoPass:      false, // Single pass for faster testing
        SourceFormat: "png",
    }

    // Build and run command
    cmd := encoder.BuildCommand(config)
    if cmd == nil {
        t.Fatal("BuildCommand returned nil")
    }

    // Verify no video filter is used for PNG
    argsStr := strings.Join(cmd.Args, " ")
    if strings.Contains(argsStr, "-vf") {
        t.Error("PNG encoding should not use video filter, but -vf was found in command")
    }

    // Run the command
    cmdOutput, err := cmd.CombinedOutput()
    if err != nil {
        t.Errorf("FFmpeg command failed: %v\nCommand output: %s", err, string(cmdOutput))
        return
    }

    // Verify output file was created
    if _, err := os.Stat(config.OutputPath); os.IsNotExist(err) {
        t.Errorf("Output file was not created: %s\nCommand output: %s", config.OutputPath, string(cmdOutput))
    } else {
        t.Logf("Successfully created output file: %s", config.OutputPath)
        info, _ := os.Stat(config.OutputPath)
        if info.Size() == 0 {
            t.Error("Output file was created but is empty")
        } else {
            t.Logf("Output file size: %d bytes", info.Size())
        }
    }
}

func TestIntegration_Encode_EXR_VP9(t *testing.T) {
    if testing.Short() {
        t.Skip("Skipping integration test in short mode")
    }

    // Check if the example file exists
    exampleDir := filepath.Join("..", "..", "..", "examples")
    exrFile := filepath.Join(exampleDir, "frame_0800.exr")
    if _, err := os.Stat(exrFile); os.IsNotExist(err) {
        t.Skipf("Example file not found: %s", exrFile)
    }

    // Check if the VP9 encoder is available
    checkCmd := exec.Command("ffmpeg", "-hide_banner", "-encoders")
    checkOutput, err := checkCmd.CombinedOutput()
    if err != nil || !strings.Contains(string(checkOutput), "libvpx-vp9") {
        t.Skip("VP9 encoder (libvpx-vp9) not available in ffmpeg")
    }

    // Get absolute paths
    workspaceRoot, err := filepath.Abs(filepath.Join("..", "..", ".."))
    if err != nil {
        t.Fatalf("Failed to get workspace root: %v", err)
    }
    exampleDirAbs, err := filepath.Abs(exampleDir)
    if err != nil {
        t.Fatalf("Failed to get example directory: %v", err)
    }
    tmpDir := filepath.Join(workspaceRoot, "tmp")
    if err := os.MkdirAll(tmpDir, 0755); err != nil {
        t.Fatalf("Failed to create tmp directory: %v", err)
    }

    encoder := &SoftwareEncoder{codec: "libvpx-vp9"}
    config := &EncodeConfig{
        InputPattern: filepath.Join(exampleDirAbs, "frame_%04d.exr"),
        OutputPath:   filepath.Join(tmpDir, "test_exr_vp9.webm"),
        StartFrame:   800,
        FrameRate:    24.0,
        WorkDir:      tmpDir,
        UseAlpha:     false,
        TwoPass:      false, // Single pass for faster testing
        SourceFormat: "exr",
    }

    // Build and run command
    cmd := encoder.BuildCommand(config)
    if cmd == nil {
        t.Fatal("BuildCommand returned nil")
    }

    // Capture combined output so failures are diagnosable
    output, err := cmd.CombinedOutput()
    if err != nil {
        t.Errorf("FFmpeg command failed: %v\nCommand output: %s", err, string(output))
        return
    }

    // Verify output file was created
    if _, err := os.Stat(config.OutputPath); os.IsNotExist(err) {
        t.Errorf("Output file was not created: %s\nCommand output: %s", config.OutputPath, string(output))
    } else {
        t.Logf("Successfully created output file: %s", config.OutputPath)
        // Verify file has content
        info, _ := os.Stat(config.OutputPath)
        if info.Size() == 0 {
            t.Errorf("Output file was created but is empty\nCommand output: %s", string(output))
        } else {
            t.Logf("Output file size: %d bytes", info.Size())
        }
    }
}

func TestIntegration_Encode_EXR_AV1(t *testing.T) {
    if testing.Short() {
        t.Skip("Skipping integration test in short mode")
    }

    // Check if the example file exists
    exampleDir := filepath.Join("..", "..", "..", "examples")
    exrFile := filepath.Join(exampleDir, "frame_0800.exr")
    if _, err := os.Stat(exrFile); os.IsNotExist(err) {
        t.Skipf("Example file not found: %s", exrFile)
    }

    // Check if the AV1 encoder is available
    checkCmd := exec.Command("ffmpeg", "-hide_banner", "-encoders")
    output, err := checkCmd.CombinedOutput()
    if err != nil || !strings.Contains(string(output), "libaom-av1") {
        t.Skip("AV1 encoder (libaom-av1) not available in ffmpeg")
    }

    // Get absolute paths
    workspaceRoot, err := filepath.Abs(filepath.Join("..", "..", ".."))
    if err != nil {
        t.Fatalf("Failed to get workspace root: %v", err)
    }
    exampleDirAbs, err := filepath.Abs(exampleDir)
    if err != nil {
        t.Fatalf("Failed to get example directory: %v", err)
    }
    tmpDir := filepath.Join(workspaceRoot, "tmp")
    if err := os.MkdirAll(tmpDir, 0755); err != nil {
        t.Fatalf("Failed to create tmp directory: %v", err)
    }

    encoder := &SoftwareEncoder{codec: "libaom-av1"}
    config := &EncodeConfig{
        InputPattern: filepath.Join(exampleDirAbs, "frame_%04d.exr"),
        OutputPath:   filepath.Join(tmpDir, "test_exr_av1.mp4"),
        StartFrame:   800,
        FrameRate:    24.0,
        WorkDir:      tmpDir,
        UseAlpha:     false,
        TwoPass:      false,
        SourceFormat: "exr",
    }

    // Build and run command
    cmd := encoder.BuildCommand(config)
    if cmd == nil {
        t.Fatal("BuildCommand returned nil")
    }
    cmdOutput, err := cmd.CombinedOutput()
    if err != nil {
        t.Errorf("FFmpeg command failed: %v\nCommand output: %s", err, string(cmdOutput))
        return
    }

    // Verify output file was created
    if _, err := os.Stat(config.OutputPath); os.IsNotExist(err) {
        t.Errorf("Output file was not created: %s\nCommand output: %s", config.OutputPath, string(cmdOutput))
    } else {
        t.Logf("Successfully created AV1 output file: %s", config.OutputPath)
        info, _ := os.Stat(config.OutputPath)
        if info.Size() == 0 {
            t.Errorf("Output file was created but is empty\nCommand output: %s", string(cmdOutput))
        } else {
            t.Logf("Output file size: %d bytes", info.Size())
        }
    }
}

func TestIntegration_Encode_EXR_VP9_WithAlpha(t *testing.T) {
    if testing.Short() {
        t.Skip("Skipping integration test in short mode")
    }

    // Check if the example file exists
    exampleDir := filepath.Join("..", "..", "..", "examples")
    exrFile := filepath.Join(exampleDir, "frame_0800.exr")
    if _, err := os.Stat(exrFile); os.IsNotExist(err) {
        t.Skipf("Example file not found: %s", exrFile)
    }

    // Check if the VP9 encoder is available
    checkCmd := exec.Command("ffmpeg", "-hide_banner", "-encoders")
    output, err := checkCmd.CombinedOutput()
    if err != nil || !strings.Contains(string(output), "libvpx-vp9") {
        t.Skip("VP9 encoder (libvpx-vp9) not available in ffmpeg")
    }

    // Get absolute paths
    workspaceRoot, err := filepath.Abs(filepath.Join("..", "..", ".."))
    if err != nil {
        t.Fatalf("Failed to get workspace root: %v", err)
    }
    exampleDirAbs, err := filepath.Abs(exampleDir)
    if err != nil {
        t.Fatalf("Failed to get example directory: %v", err)
    }
    tmpDir := filepath.Join(workspaceRoot, "tmp")
    if err := os.MkdirAll(tmpDir, 0755); err != nil {
        t.Fatalf("Failed to create tmp directory: %v", err)
    }

    encoder := &SoftwareEncoder{codec: "libvpx-vp9"}
    config := &EncodeConfig{
        InputPattern: filepath.Join(exampleDirAbs, "frame_%04d.exr"),
        OutputPath:   filepath.Join(tmpDir, "test_exr_vp9_alpha.webm"),
        StartFrame:   800,
        FrameRate:    24.0,
        WorkDir:      tmpDir,
        UseAlpha:     true,  // Test with alpha
        TwoPass:      false, // Single pass for faster testing
        SourceFormat: "exr",
    }

    // Build and run command
    cmd := encoder.BuildCommand(config)
    if cmd == nil {
        t.Fatal("BuildCommand returned nil")
    }

    // Capture combined output so failures are diagnosable
    cmdOutput, err := cmd.CombinedOutput()
    if err != nil {
        t.Errorf("FFmpeg command failed: %v\nCommand output: %s", err, string(cmdOutput))
        return
    }

    // Verify output file was created
    if _, err := os.Stat(config.OutputPath); os.IsNotExist(err) {
        t.Errorf("Output file was not created: %s\nCommand output: %s", config.OutputPath, string(cmdOutput))
    } else {
        t.Logf("Successfully created VP9 output file with alpha: %s", config.OutputPath)
        info, _ := os.Stat(config.OutputPath)
        if info.Size() == 0 {
            t.Errorf("Output file was created but is empty\nCommand output: %s", string(cmdOutput))
        } else {
            t.Logf("Output file size: %d bytes", info.Size())
        }
    }
}
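
The five integration tests above repeat the same skeleton: skip in short mode, locate the example frame, probe encoder availability, build, run, verify. A table-driven variant could fold them into one test. The sketch below is illustrative only: it reuses this package's SoftwareEncoder and EncodeConfig types, writes into t.TempDir(), and omits the per-codec availability probes for brevity.

// Sketch only: a table-driven form of the integration tests above.
func TestIntegration_Encode_Matrix(t *testing.T) {
    if testing.Short() {
        t.Skip("Skipping integration tests in short mode")
    }
    exampleDir, err := filepath.Abs(filepath.Join("..", "..", "..", "examples"))
    if err != nil {
        t.Fatalf("Failed to resolve example directory: %v", err)
    }
    if _, err := os.Stat(filepath.Join(exampleDir, "frame_0800.exr")); os.IsNotExist(err) {
        t.Skip("Example frames not found")
    }
    cases := []struct {
        name, codec, outFile string
        useAlpha             bool
    }{
        {"h264", "libx264", "out_h264.mp4", false},
        {"vp9", "libvpx-vp9", "out_vp9.webm", false},
        {"vp9_alpha", "libvpx-vp9", "out_vp9_alpha.webm", true},
        {"av1", "libaom-av1", "out_av1.mp4", false},
    }
    for _, tc := range cases {
        t.Run(tc.name, func(t *testing.T) {
            tmpDir := t.TempDir()
            enc := &SoftwareEncoder{codec: tc.codec}
            cmd := enc.BuildCommand(&EncodeConfig{
                InputPattern: filepath.Join(exampleDir, "frame_%04d.exr"),
                OutputPath:   filepath.Join(tmpDir, tc.outFile),
                StartFrame:   800,
                FrameRate:    24.0,
                WorkDir:      tmpDir,
                UseAlpha:     tc.useAlpha,
                SourceFormat: "exr",
            })
            if cmd == nil {
                t.Fatal("BuildCommand returned nil")
            }
            if out, err := cmd.CombinedOutput(); err != nil {
                t.Errorf("FFmpeg failed: %v\n%s", err, out)
            }
        })
    }
}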

// Helper function to copy files
func copyFile(src, dst string) error {
    data, err := os.ReadFile(src)
    if err != nil {
        return err
    }
    return os.WriteFile(dst, data, 0644)
}
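
If the example frames ever grow large, a streaming copy avoids holding the whole file in memory. A minimal alternative sketch, assuming "io" is added to the imports:

// copyFileStream streams src to dst with io.Copy instead of buffering
// the entire file; behavior otherwise matches copyFile above. Sketch only.
func copyFileStream(src, dst string) error {
    in, err := os.Open(src)
    if err != nil {
        return err
    }
    defer in.Close()
    out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
    if err != nil {
        return err
    }
    defer out.Close()
    _, err = io.Copy(out, in)
    return err
}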
365
internal/runner/runner.go
Normal file
@@ -0,0 +1,365 @@
// Package runner provides the Jiggablend render runner.
package runner

import (
    "crypto/sha256"
    "encoding/hex"
    "fmt"
    "log"
    "net"
    "os"
    "os/exec"
    "strings"
    "sync"
    "time"

    "jiggablend/internal/runner/api"
    "jiggablend/internal/runner/blender"
    "jiggablend/internal/runner/encoding"
    "jiggablend/internal/runner/tasks"
    "jiggablend/internal/runner/workspace"
    "jiggablend/pkg/executils"
    "jiggablend/pkg/types"
)

// Runner is the main render runner.
type Runner struct {
    id       int64
    name     string
    hostname string

    manager   *api.ManagerClient
    workspace *workspace.Manager
    blender   *blender.Manager
    encoder   *encoding.Selector
    processes *executils.ProcessTracker

    processors map[string]tasks.Processor
    stopChan   chan struct{}

    fingerprint   string
    fingerprintMu sync.RWMutex
}

// New creates a new runner.
func New(managerURL, name, hostname string) *Runner {
    manager := api.NewManagerClient(managerURL)

    r := &Runner{
        name:       name,
        hostname:   hostname,
        manager:    manager,
        processes:  executils.NewProcessTracker(),
        stopChan:   make(chan struct{}),
        processors: make(map[string]tasks.Processor),
    }

    // Generate fingerprint
    r.generateFingerprint()

    return r
}

// CheckRequiredTools verifies that required external tools are available.
func (r *Runner) CheckRequiredTools() error {
    if err := exec.Command("zstd", "--version").Run(); err != nil {
        return fmt.Errorf("zstd not found - required for compressed blend file support. Install with: apt install zstd")
    }
    log.Printf("Found zstd for compressed blend file support")

    if err := exec.Command("xvfb-run", "--help").Run(); err != nil {
        return fmt.Errorf("xvfb-run not found - required for headless Blender rendering. Install with: apt install xvfb")
    }
    log.Printf("Found xvfb-run for headless rendering")
    return nil
}

var cachedCapabilities map[string]interface{}

// ProbeCapabilities detects hardware capabilities.
func (r *Runner) ProbeCapabilities() map[string]interface{} {
    if cachedCapabilities != nil {
        return cachedCapabilities
    }

    caps := make(map[string]interface{})

    // Check for ffmpeg and probe encoding capabilities
    if err := exec.Command("ffmpeg", "-version").Run(); err == nil {
        caps["ffmpeg"] = true
    } else {
        caps["ffmpeg"] = false
    }

    cachedCapabilities = caps
    return caps
}
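
ProbeCapabilities memoizes its result in a plain package-level variable, which is fine for the current single-goroutine startup path but would race if probed concurrently. A hedged sketch of a sync.Once variant (names hypothetical, not part of the diff):

var (
    capsOnce   sync.Once
    cachedCaps map[string]interface{}
)

// probeOnce is a sketch of a race-free memoized probe.
func probeOnce() map[string]interface{} {
    capsOnce.Do(func() {
        cachedCaps = map[string]interface{}{
            // Run() returning nil means the ffmpeg binary is present.
            "ffmpeg": exec.Command("ffmpeg", "-version").Run() == nil,
        }
    })
    return cachedCaps
}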

// Register registers the runner with the manager.
func (r *Runner) Register(apiKey string) (int64, error) {
    caps := r.ProbeCapabilities()

    id, err := r.manager.Register(r.name, r.hostname, caps, apiKey, r.GetFingerprint())
    if err != nil {
        return 0, err
    }

    r.id = id

    // Initialize workspace after registration
    r.workspace = workspace.NewManager(r.name)

    // Initialize blender manager
    r.blender = blender.NewManager(r.manager, r.workspace.BaseDir())

    // Initialize encoder selector
    r.encoder = encoding.NewSelector()

    // Register task processors
    r.processors["render"] = tasks.NewRenderProcessor()
    r.processors["encode"] = tasks.NewEncodeProcessor()

    return id, nil
}

// Start starts the job polling loop.
func (r *Runner) Start(pollInterval time.Duration) {
    log.Printf("Starting job polling loop (interval: %v)", pollInterval)

    for {
        select {
        case <-r.stopChan:
            log.Printf("Stopping job polling loop")
            return
        default:
        }

        log.Printf("Polling for next job (runner ID: %d)", r.id)
        job, err := r.manager.PollNextJob()
        if err != nil {
            log.Printf("Error polling for job: %v", err)
            time.Sleep(pollInterval)
            continue
        }

        if job == nil {
            log.Printf("No job available, sleeping for %v", pollInterval)
            time.Sleep(pollInterval)
            continue
        }

        log.Printf("Received job assignment: task=%d, job=%d, type=%s",
            job.Task.TaskID, job.Task.JobID, job.Task.TaskType)

        if err := r.executeJob(job); err != nil {
            log.Printf("Error processing job: %v", err)
        }
    }
}

// Stop stops the runner.
func (r *Runner) Stop() {
    close(r.stopChan)
}

// KillAllProcesses kills all running processes.
func (r *Runner) KillAllProcesses() {
    log.Printf("Killing all running processes...")
    killedCount := r.processes.KillAll()

    // Device pool cleanup is handled internally by the encoder.

    log.Printf("Killed %d process(es)", killedCount)
}

// Cleanup removes the workspace directory.
func (r *Runner) Cleanup() {
    if r.workspace != nil {
        r.workspace.Cleanup()
    }
}

// executeJob handles a job using a per-job WebSocket connection.
func (r *Runner) executeJob(job *api.NextJobResponse) (err error) {
    // Recover from panics so a failing task cannot crash the runner process.
    // The named return lets the deferred recover surface the panic as an error.
    defer func() {
        if rec := recover(); rec != nil {
            log.Printf("Task execution panicked: %v", rec)
            err = fmt.Errorf("task execution panicked: %v", rec)
        }
    }()

    // Connect to the job WebSocket (no runner ID needed - authentication handles it)
    jobConn := api.NewJobConnection()
    if err := jobConn.Connect(r.manager.GetBaseURL(), job.JobPath, job.JobToken); err != nil {
        return fmt.Errorf("failed to connect job WebSocket: %w", err)
    }
    defer jobConn.Close()

    log.Printf("Job WebSocket authenticated for task %d", job.Task.TaskID)

    // Create task context
    workDir := r.workspace.JobDir(job.Task.JobID)
    ctx := tasks.NewContext(
        job.Task.TaskID,
        job.Task.JobID,
        job.Task.JobName,
        job.Task.Frame,
        job.Task.TaskType,
        workDir,
        job.JobToken,
        job.Task.Metadata,
        r.manager,
        jobConn,
        r.workspace,
        r.blender,
        r.encoder,
        r.processes,
    )

    ctx.Info(fmt.Sprintf("Task assignment received (job: %d, type: %s)",
        job.Task.JobID, job.Task.TaskType))

    // Get processor for task type
    processor, ok := r.processors[job.Task.TaskType]
    if !ok {
        return fmt.Errorf("unknown task type: %s", job.Task.TaskType)
    }

    // Process the task
    var processErr error
    switch job.Task.TaskType {
    case "render":
        // Render tasks need an explicit upload step: frames are not uploaded
        // by the render task itself, so we do it manually here.
        // TODO: consider making this work like the encode task.

        // Download context
        contextPath := job.JobPath + "/context.tar"
        if err := r.downloadContext(job.Task.JobID, contextPath, job.JobToken); err != nil {
            jobConn.Log(job.Task.TaskID, types.LogLevelError, fmt.Sprintf("Failed to download context: %v", err))
            jobConn.Complete(job.Task.TaskID, false, fmt.Errorf("failed to download context: %v", err))
            return fmt.Errorf("failed to download context: %w", err)
        }
        processErr = processor.Process(ctx)
        if processErr == nil {
            processErr = r.uploadOutputs(ctx, job)
        }
    case "encode":
        // Encode tasks upload the video themselves, so no upload step is needed.
        processErr = processor.Process(ctx)
    default:
        return fmt.Errorf("unknown task type: %s", job.Task.TaskType)
    }

    if processErr != nil {
        ctx.Error(fmt.Sprintf("Task failed: %v", processErr))
        ctx.Complete(false, processErr)
        return processErr
    }

    ctx.Complete(true, nil)
    return nil
}
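
The recover-to-error pattern at the top of executeJob depends on the named return value: a deferred function can only alter the caller-visible result through it. A minimal standalone illustration (hypothetical function, not part of the diff):

// do shows the same pattern in isolation: a panic inside the function
// body is converted into an ordinary error for the caller.
func do() (err error) {
    defer func() {
        if rec := recover(); rec != nil {
            err = fmt.Errorf("recovered: %v", rec)
        }
    }()
    panic("boom") // the caller sees err = "recovered: boom", not a crash
}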

func (r *Runner) downloadContext(jobID int64, contextPath, jobToken string) error {
    reader, err := r.manager.DownloadContext(contextPath, jobToken)
    if err != nil {
        return err
    }
    defer reader.Close()

    jobDir := r.workspace.JobDir(jobID)
    return workspace.ExtractTar(reader, jobDir)
}

func (r *Runner) uploadOutputs(ctx *tasks.Context, job *api.NextJobResponse) error {
    outputDir := ctx.WorkDir + "/output"
    uploadPath := fmt.Sprintf("/api/runner/jobs/%d/upload", job.Task.JobID)

    entries, err := os.ReadDir(outputDir)
    if err != nil {
        return fmt.Errorf("failed to read output directory: %w", err)
    }

    for _, entry := range entries {
        if entry.IsDir() {
            continue
        }
        filePath := outputDir + "/" + entry.Name()
        if err := r.manager.UploadFile(uploadPath, job.JobToken, filePath); err != nil {
            log.Printf("Failed to upload %s: %v", filePath, err)
        } else {
            ctx.OutputUploaded(entry.Name())
            // Delete the file after a successful upload to prevent duplicate uploads
            if err := os.Remove(filePath); err != nil {
                log.Printf("Warning: Failed to delete file %s after upload: %v", filePath, err)
            }
        }
    }

    return nil
}

// generateFingerprint creates a unique hardware fingerprint.
func (r *Runner) generateFingerprint() {
    r.fingerprintMu.Lock()
    defer r.fingerprintMu.Unlock()

    var components []string
    components = append(components, r.hostname)

    if machineID, err := os.ReadFile("/etc/machine-id"); err == nil {
        components = append(components, strings.TrimSpace(string(machineID)))
    }

    if productUUID, err := os.ReadFile("/sys/class/dmi/id/product_uuid"); err == nil {
        components = append(components, strings.TrimSpace(string(productUUID)))
    }

    if macAddr, err := r.getMACAddress(); err == nil {
        components = append(components, macAddr)
    }

    // Fall back to PID and timestamp if no stable hardware identifiers were found
    if len(components) <= 1 {
        components = append(components, fmt.Sprintf("%d", os.Getpid()))
        components = append(components, fmt.Sprintf("%d", time.Now().Unix()))
    }

    // Hash the NUL-separated components into a stable hex fingerprint
    h := sha256.New()
    for _, comp := range components {
        h.Write([]byte(comp))
        h.Write([]byte{0})
    }

    r.fingerprint = hex.EncodeToString(h.Sum(nil))
}

func (r *Runner) getMACAddress() (string, error) {
    interfaces, err := net.Interfaces()
    if err != nil {
        return "", err
    }

    // Return the MAC of the first non-loopback interface that is up
    for _, iface := range interfaces {
        if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
            continue
        }
        if len(iface.HardwareAddr) == 0 {
            continue
        }
        return iface.HardwareAddr.String(), nil
    }

    return "", fmt.Errorf("no suitable network interface found")
}

// GetFingerprint returns the runner's hardware fingerprint.
func (r *Runner) GetFingerprint() string {
    r.fingerprintMu.RLock()
    defer r.fingerprintMu.RUnlock()
    return r.fingerprint
}

// GetID returns the runner ID.
func (r *Runner) GetID() int64 {
    return r.id
}
594
internal/runner/tasks/encode.go
Normal file
@@ -0,0 +1,594 @@
package tasks

import (
    "bufio"
    "errors"
    "fmt"
    "log"
    "math"
    "os"
    "os/exec"
    "path/filepath"
    "regexp"
    "sort"
    "strings"

    "jiggablend/internal/runner/encoding"
)

// EncodeProcessor handles encode tasks.
type EncodeProcessor struct{}

// NewEncodeProcessor creates a new encode processor.
func NewEncodeProcessor() *EncodeProcessor {
    return &EncodeProcessor{}
}

// Process executes an encode task.
func (p *EncodeProcessor) Process(ctx *Context) error {
    ctx.Info(fmt.Sprintf("Starting encode task: job %d", ctx.JobID))
    log.Printf("Processing encode task %d for job %d", ctx.TaskID, ctx.JobID)

    // Create temporary work directory
    workDir, err := ctx.Workspace.CreateVideoDir(ctx.JobID)
    if err != nil {
        return fmt.Errorf("failed to create work directory: %w", err)
    }
    defer func() {
        if err := ctx.Workspace.CleanupVideoDir(ctx.JobID); err != nil {
            log.Printf("Warning: Failed to cleanup encode work directory: %v", err)
        }
    }()

    // Get output format and frame rate
    outputFormat := ctx.GetOutputFormat()
    if outputFormat == "" {
        outputFormat = "EXR_264_MP4"
    }
    frameRate := ctx.GetFrameRate()

    ctx.Info(fmt.Sprintf("Encode: detected output format '%s'", outputFormat))
    ctx.Info(fmt.Sprintf("Encode: using frame rate %.2f fps", frameRate))

    // Get job files
    files, err := ctx.Manager.GetJobFiles(ctx.JobID)
    if err != nil {
        ctx.Error(fmt.Sprintf("Failed to get job files: %v", err))
        return fmt.Errorf("failed to get job files: %w", err)
    }

    ctx.Info(fmt.Sprintf("GetJobFiles returned %d total files for job %d", len(files), ctx.JobID))

    // Log all files for debugging
    for _, file := range files {
        ctx.Info(fmt.Sprintf("File: %s (type: %s, size: %d)", file.FileName, file.FileType, file.FileSize))
    }

    // Source frames for encode tasks are currently always EXR
    sourceFormat := "exr"
    fileExt := ".exr"

    // Find and deduplicate frame files
    frameFileSet := make(map[string]bool)
    var frameFilesList []string
    for _, file := range files {
        if file.FileType == "output" && strings.HasSuffix(strings.ToLower(file.FileName), fileExt) {
            // Deduplicate by filename
            if !frameFileSet[file.FileName] {
                frameFileSet[file.FileName] = true
                frameFilesList = append(frameFilesList, file.FileName)
            }
        }
    }

    if len(frameFilesList) == 0 {
        // Log why no files matched (deduplicate for error reporting)
        outputFileSet := make(map[string]bool)
        frameFilesOtherTypeSet := make(map[string]bool)
        var outputFiles []string
        var frameFilesOtherType []string

        for _, file := range files {
            if file.FileType == "output" {
                if !outputFileSet[file.FileName] {
                    outputFileSet[file.FileName] = true
                    outputFiles = append(outputFiles, file.FileName)
                }
            }
            if strings.HasSuffix(strings.ToLower(file.FileName), fileExt) {
                key := fmt.Sprintf("%s (type: %s)", file.FileName, file.FileType)
                if !frameFilesOtherTypeSet[key] {
                    frameFilesOtherTypeSet[key] = true
                    frameFilesOtherType = append(frameFilesOtherType, key)
                }
            }
        }
        ctx.Error(fmt.Sprintf("no %s frame files found for encode: found %d total files, %d unique output files, %d unique %s files (with other types)", strings.ToUpper(fileExt[1:]), len(files), len(outputFiles), len(frameFilesOtherType), strings.ToUpper(fileExt[1:])))
        if len(outputFiles) > 0 {
            ctx.Error(fmt.Sprintf("Output files found: %v", outputFiles))
        }
        if len(frameFilesOtherType) > 0 {
            ctx.Error(fmt.Sprintf("%s files with wrong type: %v", strings.ToUpper(fileExt[1:]), frameFilesOtherType))
        }
        return fmt.Errorf("no %s frame files found for encode", strings.ToUpper(fileExt[1:]))
    }

    ctx.Info(fmt.Sprintf("Found %d %s frames for encode", len(frameFilesList), strings.ToUpper(fileExt[1:])))

    // Download frames
    ctx.Info(fmt.Sprintf("Downloading %d %s frames for encode...", len(frameFilesList), strings.ToUpper(fileExt[1:])))

    var frameFiles []string
    for i, fileName := range frameFilesList {
        ctx.Info(fmt.Sprintf("Downloading frame %d/%d: %s", i+1, len(frameFilesList), fileName))
        framePath := filepath.Join(workDir, fileName)
        if err := ctx.Manager.DownloadFrame(ctx.JobID, fileName, framePath); err != nil {
            ctx.Error(fmt.Sprintf("Failed to download %s frame %s: %v", strings.ToUpper(fileExt[1:]), fileName, err))
            log.Printf("Failed to download %s frame for encode %s: %v", strings.ToUpper(fileExt[1:]), fileName, err)
            continue
        }
        ctx.Info(fmt.Sprintf("Successfully downloaded frame %d/%d: %s", i+1, len(frameFilesList), fileName))
        frameFiles = append(frameFiles, framePath)
    }

    if len(frameFiles) == 0 {
        err := fmt.Errorf("failed to download any %s frames for encode", strings.ToUpper(fileExt[1:]))
        ctx.Error(err.Error())
        return err
    }

    sort.Strings(frameFiles)
    ctx.Info(fmt.Sprintf("Downloaded %d frames", len(frameFiles)))

    // Check whether the EXR files have an alpha channel and HDR content
    hasAlpha := false
    hasHDR := false
    if sourceFormat == "exr" {
        // Check the first frame for alpha channel and HDR using ffprobe
        firstFrame := frameFiles[0]
        hasAlpha = detectAlphaChannel(ctx, firstFrame)
        if hasAlpha {
            ctx.Info("Detected alpha channel in EXR files")
        } else {
            ctx.Info("No alpha channel detected in EXR files")
        }

        hasHDR = detectHDR(ctx, firstFrame)
        if hasHDR {
            ctx.Info("Detected HDR content in EXR files")
        } else {
            ctx.Info("No HDR content detected in EXR files (SDR range)")
        }
    }

    // Generate video.
    // Alpha is used when (1) the user explicitly enabled it or the source has
    // an alpha channel, and (2) the codec supports alpha (AV1 or VP9).
    preserveAlpha := ctx.ShouldPreserveAlpha()
    useAlpha := (preserveAlpha || hasAlpha) && (outputFormat == "EXR_AV1_MP4" || outputFormat == "EXR_VP9_WEBM")
    if (preserveAlpha || hasAlpha) && outputFormat == "EXR_264_MP4" {
        ctx.Warn("Alpha channel requested/detected but H.264 does not support alpha. Consider using EXR_AV1_MP4 or EXR_VP9_WEBM to preserve alpha.")
    }
    if preserveAlpha && !hasAlpha {
        ctx.Warn("Alpha preservation requested but no alpha channel detected in EXR files.")
    }
    if useAlpha {
        if preserveAlpha && hasAlpha {
            ctx.Info("Alpha preservation enabled: Using alpha channel encoding")
        } else if hasAlpha {
            ctx.Info("Alpha channel detected - automatically enabling alpha encoding")
        }
    }

    var outputExt string
    switch outputFormat {
    case "EXR_VP9_WEBM":
        outputExt = "webm"
        ctx.Info("Encoding WebM video with VP9 codec (with alpha channel and HDR support)...")
    case "EXR_AV1_MP4":
        outputExt = "mp4"
        ctx.Info("Encoding MP4 video with AV1 codec (with alpha channel support)...")
    default:
        outputExt = "mp4"
        ctx.Info("Encoding MP4 video with H.264 codec...")
    }

    outputVideo := filepath.Join(workDir, fmt.Sprintf("output_%d.%s", ctx.JobID, outputExt))

    // Build the ffmpeg input pattern from the first frame's filename,
    // e.g. "frame_0800.exr" becomes "frame_%04d.exr" with start number 800.
    firstFrame := frameFiles[0]
    baseName := filepath.Base(firstFrame)
    re := regexp.MustCompile(`_(\d+)\.`)
    var pattern string
    var startNumber int
    frameNumStr := re.FindStringSubmatch(baseName)
    if len(frameNumStr) > 1 {
        pattern = re.ReplaceAllString(baseName, "_%04d.")
        fmt.Sscanf(frameNumStr[1], "%d", &startNumber)
    } else {
        startNumber = extractFrameNumber(baseName)
        pattern = strings.Replace(baseName, fmt.Sprintf("%d", startNumber), "%04d", 1)
    }
    patternPath := filepath.Join(workDir, pattern)
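
To make the pattern derivation concrete, here is the same regexp logic pulled into a hypothetical helper with its expected result (sketch only, not part of the diff):

// patternFor derives an ffmpeg image2 pattern and start number from a
// frame filename; it mirrors the logic above.
func patternFor(baseName string) (pattern string, start int) {
    re := regexp.MustCompile(`_(\d+)\.`)
    if m := re.FindStringSubmatch(baseName); len(m) > 1 {
        fmt.Sscanf(m[1], "%d", &start)
        return re.ReplaceAllString(baseName, "_%04d."), start
    }
    return baseName, 0
}

// patternFor("frame_0800.exr") returns ("frame_%04d.exr", 800).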

    // Select encoder and build command (software encoding only)
    var encoder encoding.Encoder
    switch outputFormat {
    case "EXR_AV1_MP4":
        encoder = ctx.Encoder.SelectAV1()
    case "EXR_VP9_WEBM":
        encoder = ctx.Encoder.SelectVP9()
    default:
        encoder = ctx.Encoder.SelectH264()
    }

    ctx.Info(fmt.Sprintf("Using encoder: %s (%s)", encoder.Name(), encoder.Codec()))

    // All software encoders use 2-pass encoding for optimal quality
    ctx.Info("Starting 2-pass encode for optimal quality...")

    // Pass 1
    ctx.Info("Pass 1/2: Analyzing content for optimal encode...")
    softEncoder := encoder.(*encoding.SoftwareEncoder)
    // Use HDR if the user explicitly enabled it or HDR content was detected
    preserveHDR := (ctx.ShouldPreserveHDR() || hasHDR) && sourceFormat == "exr"
    if hasHDR && !ctx.ShouldPreserveHDR() {
        ctx.Info("HDR content detected - automatically enabling HDR preservation")
    }
    pass1Cmd := softEncoder.BuildPass1Command(&encoding.EncodeConfig{
        InputPattern: patternPath,
        OutputPath:   outputVideo,
        StartFrame:   startNumber,
        FrameRate:    frameRate,
        WorkDir:      workDir,
        UseAlpha:     useAlpha,
        TwoPass:      true,
        SourceFormat: sourceFormat,
        PreserveHDR:  preserveHDR,
    })
    if err := pass1Cmd.Run(); err != nil {
        ctx.Warn(fmt.Sprintf("Pass 1 completed (warnings expected): %v", err))
    }

    // Pass 2
    ctx.Info("Pass 2/2: Encoding with optimal quality...")

    if preserveHDR {
        if hasHDR && !ctx.ShouldPreserveHDR() {
            ctx.Info("HDR preservation enabled (auto-detected): Using HLG transfer with bt709 primaries")
        } else {
            ctx.Info("HDR preservation enabled: Using HLG transfer with bt709 primaries")
        }
    }

    config := &encoding.EncodeConfig{
        InputPattern: patternPath,
        OutputPath:   outputVideo,
        StartFrame:   startNumber,
        FrameRate:    frameRate,
        WorkDir:      workDir,
        UseAlpha:     useAlpha,
        TwoPass:      true, // Software encoding always uses 2-pass for quality
        SourceFormat: sourceFormat,
        PreserveHDR:  preserveHDR,
    }

    cmd := encoder.BuildCommand(config)
    if cmd == nil {
        return errors.New("failed to build encode command")
    }

    // Set up pipes
    stdoutPipe, err := cmd.StdoutPipe()
    if err != nil {
        return fmt.Errorf("failed to create stdout pipe: %w", err)
    }

    stderrPipe, err := cmd.StderrPipe()
    if err != nil {
        return fmt.Errorf("failed to create stderr pipe: %w", err)
    }

    if err := cmd.Start(); err != nil {
        return fmt.Errorf("failed to start encode command: %w", err)
    }

    ctx.Processes.Track(ctx.TaskID, cmd)
    defer ctx.Processes.Untrack(ctx.TaskID)

    // Stream stdout
    stdoutDone := make(chan bool)
    go func() {
        defer close(stdoutDone)
        scanner := bufio.NewScanner(stdoutPipe)
        for scanner.Scan() {
            line := scanner.Text()
            if line != "" {
                ctx.Info(line)
            }
        }
    }()

    // Stream stderr
    stderrDone := make(chan bool)
    go func() {
        defer close(stderrDone)
        scanner := bufio.NewScanner(stderrPipe)
        for scanner.Scan() {
            line := scanner.Text()
            if line != "" {
                ctx.Warn(line)
            }
        }
    }()

    err = cmd.Wait()
    <-stdoutDone
    <-stderrDone

    if err != nil {
        var errMsg string
        if exitErr, ok := err.(*exec.ExitError); ok {
            if exitErr.ExitCode() == 137 {
                errMsg = "FFmpeg was killed due to excessive memory usage (OOM)"
            } else {
                errMsg = fmt.Sprintf("ffmpeg encoding failed: %v", err)
            }
        } else {
            errMsg = fmt.Sprintf("ffmpeg encoding failed: %v", err)
        }

        if sizeErr := checkFFmpegSizeError(errMsg); sizeErr != nil {
            ctx.Error(sizeErr.Error())
            return sizeErr
        }

        ctx.Error(errMsg)
        return errors.New(errMsg)
    }

    // Verify output
    if _, err := os.Stat(outputVideo); os.IsNotExist(err) {
        err := fmt.Errorf("video %s file not created: %s", outputExt, outputVideo)
        ctx.Error(err.Error())
        return err
    }

    // Clean up 2-pass log files
    os.Remove(filepath.Join(workDir, "ffmpeg2pass-0.log"))
    os.Remove(filepath.Join(workDir, "ffmpeg2pass-0.log.mbtree"))

    ctx.Info(fmt.Sprintf("%s video encoded successfully", strings.ToUpper(outputExt)))

    // Upload video
    ctx.Info(fmt.Sprintf("Uploading encoded %s video...", strings.ToUpper(outputExt)))

    uploadPath := fmt.Sprintf("/api/runner/jobs/%d/upload", ctx.JobID)
    if err := ctx.Manager.UploadFile(uploadPath, ctx.JobToken, outputVideo); err != nil {
        ctx.Error(fmt.Sprintf("Failed to upload %s: %v", strings.ToUpper(outputExt), err))
        return fmt.Errorf("failed to upload %s: %w", strings.ToUpper(outputExt), err)
    }

    ctx.Info(fmt.Sprintf("Successfully uploaded %s: %s", strings.ToUpper(outputExt), filepath.Base(outputVideo)))

    // Delete the file after a successful upload to prevent duplicate uploads
    if err := os.Remove(outputVideo); err != nil {
        log.Printf("Warning: Failed to delete video file %s after upload: %v", outputVideo, err)
        ctx.Warn(fmt.Sprintf("Warning: Failed to delete video file after upload: %v", err))
    }

    log.Printf("Successfully generated and uploaded %s for job %d: %s", strings.ToUpper(outputExt), ctx.JobID, filepath.Base(outputVideo))
    return nil
}

// detectAlphaChannel checks whether an EXR file has an alpha channel using ffprobe.
func detectAlphaChannel(ctx *Context, filePath string) bool {
    // Use ffprobe to check pixel format and stream properties.
    // EXR files with alpha report formats like gbrapf32le (RGBA) vs gbrpf32le (RGB).
    cmd := exec.Command("ffprobe",
        "-v", "error",
        "-select_streams", "v:0",
        "-show_entries", "stream=pix_fmt:stream=codec_name",
        "-of", "default=noprint_wrappers=1",
        filePath,
    )

    output, err := cmd.Output()
    if err != nil {
        // If ffprobe fails, assume no alpha (conservative approach)
        ctx.Warn(fmt.Sprintf("Failed to detect alpha channel in %s: %v", filepath.Base(filePath), err))
        return false
    }

    outputStr := string(output)
    // Pixel formats with alpha carry an 'a' in the name (e.g. gbrapf32le);
    // also match formats that explicitly indicate alpha.
    hasAlpha := strings.Contains(outputStr, "pix_fmt=gbrap") ||
        strings.Contains(outputStr, "pix_fmt=rgba") ||
        strings.Contains(outputStr, "pix_fmt=yuva") ||
        strings.Contains(outputStr, "pix_fmt=abgr")

    if hasAlpha {
        ctx.Info(fmt.Sprintf("Detected alpha channel in EXR file: %s", filepath.Base(filePath)))
    }

    return hasAlpha
}

// detectHDR checks whether an EXR file contains HDR content using ffprobe.
func detectHDR(ctx *Context, filePath string) bool {
    // First, check whether the pixel format can carry HDR (32-bit float)
    cmd := exec.Command("ffprobe",
        "-v", "error",
        "-select_streams", "v:0",
        "-show_entries", "stream=pix_fmt",
        "-of", "default=noprint_wrappers=1:nokey=1",
        filePath,
    )

    output, err := cmd.Output()
    if err != nil {
        // If ffprobe fails, assume no HDR (conservative approach)
        ctx.Warn(fmt.Sprintf("Failed to detect HDR in %s: %v", filepath.Base(filePath), err))
        return false
    }

    pixFmt := strings.TrimSpace(string(output))
    // EXR files in a 32-bit float format (gbrpf32le, gbrapf32le) can contain HDR
    isFloat32 := strings.Contains(pixFmt, "f32")

    if !isFloat32 {
        // Not a float format, so definitely not HDR
        return false
    }

    // For 32-bit float EXR, sample pixels to check whether values exceed the
    // SDR range (> 1.0). Try ffmpeg's signalstats filter first.
    cmd = exec.Command("ffmpeg",
        "-v", "error",
        "-i", filePath,
        "-vf", "signalstats",
        "-f", "null",
        "-",
    )

    output, err = cmd.CombinedOutput()
    if err != nil {
        // If stats extraction fails, fall back to sampling pixels directly
        return detectHDRBySampling(ctx, filePath)
    }

    // signalstats reports YUV statistics (e.g. YMAX/UMAX/VMAX), but for EXR
    // we need per-RGB-channel maxima, so in practice we always fall back to
    // direct pixel sampling.
    outputStr := string(output)
    if strings.Contains(outputStr, "MAX") {
        return detectHDRBySampling(ctx, filePath)
    }

    // Fallback to pixel sampling
    return detectHDRBySampling(ctx, filePath)
}

// detectHDRBySampling samples pixels from multiple regions to detect HDR content.
func detectHDRBySampling(ctx *Context, filePath string) bool {
    // Sample several 10x10 regions from different parts of the image for
    // better coverage than a single sample.
    sampleRegions := []string{
        "crop=10:10:iw/4:ih/4",     // Top-left quadrant
        "crop=10:10:iw*3/4:ih/4",   // Top-right quadrant
        "crop=10:10:iw/4:ih*3/4",   // Bottom-left quadrant
        "crop=10:10:iw*3/4:ih*3/4", // Bottom-right quadrant
        "crop=10:10:iw/2:ih/2",     // Center
    }

    for _, region := range sampleRegions {
        cmd := exec.Command("ffmpeg",
            "-v", "error",
            "-i", filePath,
            "-vf", fmt.Sprintf("%s,scale=1:1", region),
            "-f", "rawvideo",
            "-pix_fmt", "gbrpf32le",
            "-",
        )

        output, err := cmd.Output()
        if err != nil {
            continue // Skip this region if sampling fails
        }

        // Parse the float32 values (4 bytes per float, 3 channels).
        // Note: gbrpf32le is planar (G, B, R plane order), but with a single
        // scaled-down pixel the three floats are consecutive, and the > 1.0
        // test below is channel-order independent anyway.
        if len(output) >= 12 { // At least 3 floats (RGB) = 12 bytes
            for i := 0; i < len(output)-11; i += 12 {
                // Read the three channel values (little-endian float32)
                r := float32FromBytes(output[i : i+4])
                g := float32FromBytes(output[i+4 : i+8])
                b := float32FromBytes(output[i+8 : i+12])

                // Check whether any channel exceeds 1.0 (the SDR range)
                if r > 1.0 || g > 1.0 || b > 1.0 {
                    maxVal := max(r, max(g, b))
                    ctx.Info(fmt.Sprintf("Detected HDR content in EXR file: %s (max value: %.2f)", filepath.Base(filePath), maxVal))
                    return true
                }
            }
        }
    }

    // No sampled region exceeded 1.0, so this is likely SDR content. Since
    // the format is 32-bit float, the user can still enable HDR manually.
    return false
}

// float32FromBytes converts 4 little-endian bytes to a float32.
func float32FromBytes(bytes []byte) float32 {
    if len(bytes) < 4 {
        return 0
    }
    bits := uint32(bytes[0]) | uint32(bytes[1])<<8 | uint32(bytes[2])<<16 | uint32(bytes[3])<<24
    return math.Float32frombits(bits)
}
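
The standard library can express the same little-endian conversion directly; a sketch, assuming "encoding/binary" is added to the imports:

// float32FromBytesStd is equivalent to float32FromBytes, using the
// standard library's byte-order helpers instead of manual shifts.
func float32FromBytesStd(b []byte) float32 {
    if len(b) < 4 {
        return 0
    }
    return math.Float32frombits(binary.LittleEndian.Uint32(b))
}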

// max returns the maximum of two float32 values.
func max(a, b float32) float32 {
    if a > b {
        return a
    }
    return b
}

func extractFrameNumber(filename string) int {
    parts := strings.Split(filepath.Base(filename), "_")
    if len(parts) < 2 {
        return 0
    }
    framePart := strings.Split(parts[1], ".")[0]
    var frameNum int
    fmt.Sscanf(framePart, "%d", &frameNum)
    return frameNum
}

// checkFFmpegSizeError maps known ffmpeg frame-size failures to friendlier errors.
func checkFFmpegSizeError(output string) error {
    outputLower := strings.ToLower(output)

    if strings.Contains(outputLower, "hardware does not support encoding at size") {
        constraintsMatch := regexp.MustCompile(`constraints:\s*width\s+(\d+)-(\d+)\s+height\s+(\d+)-(\d+)`).FindStringSubmatch(output)
        if len(constraintsMatch) == 5 {
            return fmt.Errorf("video frame size is outside hardware encoder limits. Hardware requires: width %s-%s, height %s-%s",
                constraintsMatch[1], constraintsMatch[2], constraintsMatch[3], constraintsMatch[4])
        }
        return fmt.Errorf("video frame size is outside hardware encoder limits")
    }

    if strings.Contains(outputLower, "picture size") && strings.Contains(outputLower, "is invalid") {
        sizeMatch := regexp.MustCompile(`picture size\s+(\d+)x(\d+)`).FindStringSubmatch(output)
        if len(sizeMatch) == 3 {
            return fmt.Errorf("invalid video frame size: %sx%s", sizeMatch[1], sizeMatch[2])
        }
        return fmt.Errorf("invalid video frame size")
    }

    if strings.Contains(outputLower, "error while opening encoder") &&
        (strings.Contains(outputLower, "width") || strings.Contains(outputLower, "height") || strings.Contains(outputLower, "size")) {
        sizeMatch := regexp.MustCompile(`at size\s+(\d+)x(\d+)`).FindStringSubmatch(output)
        if len(sizeMatch) == 3 {
            return fmt.Errorf("hardware encoder cannot encode frame size %sx%s", sizeMatch[1], sizeMatch[2])
        }
        return fmt.Errorf("hardware encoder error: frame size may be invalid")
    }

    if strings.Contains(outputLower, "invalid") &&
        (strings.Contains(outputLower, "width") || strings.Contains(outputLower, "height") || strings.Contains(outputLower, "dimension")) {
        return fmt.Errorf("invalid frame dimensions detected")
    }

    return nil
}
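
A quick usage sketch of the classifier; the message is synthetic, assembled from the substrings the function matches on, not captured ffmpeg output:

// demoSizeError illustrates the classifier on a synthetic message.
func demoSizeError() {
    msg := "picture size 65536x65536 is invalid" // synthetic input
    if err := checkFFmpegSizeError(msg); err != nil {
        fmt.Println(err) // invalid video frame size: 65536x65536
    }
}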
156
internal/runner/tasks/processor.go
Normal file
@@ -0,0 +1,156 @@
// Package tasks provides task processing implementations.
package tasks

import (
    "jiggablend/internal/runner/api"
    "jiggablend/internal/runner/blender"
    "jiggablend/internal/runner/encoding"
    "jiggablend/internal/runner/workspace"
    "jiggablend/pkg/executils"
    "jiggablend/pkg/types"
)

// Processor handles a specific task type.
type Processor interface {
    Process(ctx *Context) error
}

// Context provides task execution context.
type Context struct {
    TaskID   int64
    JobID    int64
    JobName  string
    Frame    int
    TaskType string
    WorkDir  string
    JobToken string
    Metadata *types.BlendMetadata

    Manager   *api.ManagerClient
    JobConn   *api.JobConnection
    Workspace *workspace.Manager
    Blender   *blender.Manager
    Encoder   *encoding.Selector
    Processes *executils.ProcessTracker
}

// NewContext creates a new task context.
func NewContext(
    taskID, jobID int64,
    jobName string,
    frame int,
    taskType string,
    workDir string,
    jobToken string,
    metadata *types.BlendMetadata,
    manager *api.ManagerClient,
    jobConn *api.JobConnection,
    ws *workspace.Manager,
    blenderMgr *blender.Manager,
    encoder *encoding.Selector,
    processes *executils.ProcessTracker,
) *Context {
    return &Context{
        TaskID:    taskID,
        JobID:     jobID,
        JobName:   jobName,
        Frame:     frame,
        TaskType:  taskType,
        WorkDir:   workDir,
        JobToken:  jobToken,
        Metadata:  metadata,
        Manager:   manager,
        JobConn:   jobConn,
        Workspace: ws,
        Blender:   blenderMgr,
        Encoder:   encoder,
        Processes: processes,
    }
}

// Log sends a log entry to the manager.
func (c *Context) Log(level types.LogLevel, message string) {
    if c.JobConn != nil {
        c.JobConn.Log(c.TaskID, level, message)
    }
}

// Info logs an info message.
func (c *Context) Info(message string) {
    c.Log(types.LogLevelInfo, message)
}

// Warn logs a warning message.
func (c *Context) Warn(message string) {
    c.Log(types.LogLevelWarn, message)
}

// Error logs an error message.
func (c *Context) Error(message string) {
    c.Log(types.LogLevelError, message)
}

// Progress sends a progress update.
func (c *Context) Progress(progress float64) {
    if c.JobConn != nil {
        c.JobConn.Progress(c.TaskID, progress)
    }
}

// OutputUploaded notifies that an output file was uploaded.
func (c *Context) OutputUploaded(fileName string) {
    if c.JobConn != nil {
        c.JobConn.OutputUploaded(c.TaskID, fileName)
    }
}

// Complete sends task completion.
func (c *Context) Complete(success bool, errorMsg error) {
    if c.JobConn != nil {
        c.JobConn.Complete(c.TaskID, success, errorMsg)
    }
}

// GetOutputFormat returns the output format from metadata, or the default.
func (c *Context) GetOutputFormat() string {
    if c.Metadata != nil && c.Metadata.RenderSettings.OutputFormat != "" {
        return c.Metadata.RenderSettings.OutputFormat
    }
    return "PNG"
}

// GetFrameRate returns the frame rate from metadata, or the default.
func (c *Context) GetFrameRate() float64 {
    if c.Metadata != nil && c.Metadata.RenderSettings.FrameRate > 0 {
        return c.Metadata.RenderSettings.FrameRate
    }
    return 24.0
}

// GetBlenderVersion returns the Blender version from metadata.
func (c *Context) GetBlenderVersion() string {
    if c.Metadata != nil {
        return c.Metadata.BlenderVersion
    }
    return ""
}

// ShouldUnhideObjects returns whether to unhide objects.
func (c *Context) ShouldUnhideObjects() bool {
    return c.Metadata != nil && c.Metadata.UnhideObjects != nil && *c.Metadata.UnhideObjects
}

// ShouldEnableExecution returns whether to enable auto-execution.
func (c *Context) ShouldEnableExecution() bool {
    return c.Metadata != nil && c.Metadata.EnableExecution != nil && *c.Metadata.EnableExecution
}

// ShouldPreserveHDR returns whether to preserve HDR range for EXR encoding.
func (c *Context) ShouldPreserveHDR() bool {
    return c.Metadata != nil && c.Metadata.PreserveHDR != nil && *c.Metadata.PreserveHDR
}

// ShouldPreserveAlpha returns whether to preserve the alpha channel for EXR encoding.
func (c *Context) ShouldPreserveAlpha() bool {
    return c.Metadata != nil && c.Metadata.PreserveAlpha != nil && *c.Metadata.PreserveAlpha
}
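
The Should* getters lean on *bool metadata fields so that "unset" is distinguishable from an explicit false. A minimal standalone sketch of the tri-state pattern (hypothetical type, not part of the diff):

// settings shows the idea: a *bool field is a tri-state
// (nil = unset, pointer to false, pointer to true).
type settings struct {
    PreserveHDR *bool
}

func (s settings) shouldPreserveHDR() bool {
    return s.PreserveHDR != nil && *s.PreserveHDR // nil-safe: unset reads as false
}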
301
internal/runner/tasks/render.go
Normal file
@@ -0,0 +1,301 @@
package tasks

import (
    "bufio"
    "encoding/json"
    "errors"
    "fmt"
    "log"
    "os"
    "os/exec"
    "path/filepath"
    "strings"

    "jiggablend/internal/runner/blender"
    "jiggablend/internal/runner/workspace"
    "jiggablend/pkg/scripts"
    "jiggablend/pkg/types"
)

// RenderProcessor handles render tasks.
type RenderProcessor struct{}

// NewRenderProcessor creates a new render processor.
func NewRenderProcessor() *RenderProcessor {
    return &RenderProcessor{}
}

// Process executes a render task.
func (p *RenderProcessor) Process(ctx *Context) error {
    ctx.Info(fmt.Sprintf("Starting task: job %d, frame %d, format: %s",
        ctx.JobID, ctx.Frame, ctx.GetOutputFormat()))
    log.Printf("Processing task %d: job %d, frame %d", ctx.TaskID, ctx.JobID, ctx.Frame)

    // Find the .blend file
    blendFile, err := workspace.FindFirstBlendFile(ctx.WorkDir)
    if err != nil {
        return fmt.Errorf("failed to find blend file: %w", err)
    }

    // Get the Blender binary
    blenderBinary := "blender"
    if version := ctx.GetBlenderVersion(); version != "" {
        ctx.Info(fmt.Sprintf("Job requires Blender %s", version))
        binaryPath, err := ctx.Blender.GetBinaryPath(version)
        if err != nil {
            ctx.Warn(fmt.Sprintf("Could not get Blender %s, using system blender: %v", version, err))
        } else {
            blenderBinary = binaryPath
            ctx.Info(fmt.Sprintf("Using Blender binary: %s", blenderBinary))
        }
    } else {
        ctx.Info("No Blender version specified, using system blender")
    }

    // Create the output directory
    outputDir := filepath.Join(ctx.WorkDir, "output")
    if err := os.MkdirAll(outputDir, 0755); err != nil {
        return fmt.Errorf("failed to create output directory: %w", err)
    }

    // Create a home directory for Blender inside the workspace
    blenderHome := filepath.Join(ctx.WorkDir, "home")
    if err := os.MkdirAll(blenderHome, 0755); err != nil {
        return fmt.Errorf("failed to create Blender home directory: %w", err)
    }

    // Determine the render format
    outputFormat := ctx.GetOutputFormat()
    renderFormat := outputFormat
    if outputFormat == "EXR_264_MP4" || outputFormat == "EXR_AV1_MP4" || outputFormat == "EXR_VP9_WEBM" {
        renderFormat = "EXR" // Render EXR for maximum quality; encoding happens later
    }

    // Create the render script
    if err := p.createRenderScript(ctx, renderFormat); err != nil {
        return err
    }

    // Render
    ctx.Info(fmt.Sprintf("Starting Blender render for frame %d...", ctx.Frame))
    if err := p.runBlender(ctx, blenderBinary, blendFile, outputDir, renderFormat, blenderHome); err != nil {
        ctx.Error(fmt.Sprintf("Blender render failed: %v", err))
        return err
    }

    // Verify output
    if _, err := p.findOutputFile(ctx, outputDir, renderFormat); err != nil {
        ctx.Error(fmt.Sprintf("Output verification failed: %v", err))
        return err
    }
    ctx.Info(fmt.Sprintf("Blender render completed for frame %d", ctx.Frame))

    return nil
}

func (p *RenderProcessor) createRenderScript(ctx *Context, renderFormat string) error {
    formatFilePath := filepath.Join(ctx.WorkDir, "output_format.txt")
    renderSettingsFilePath := filepath.Join(ctx.WorkDir, "render_settings.json")

    // Build unhide code conditionally
    unhideCode := ""
    if ctx.ShouldUnhideObjects() {
        unhideCode = scripts.UnhideObjects
    }

    // Load the template and replace placeholders; %q quotes each path as a
    // string literal so it can be spliced into the Python script safely.
    scriptContent := scripts.RenderBlenderTemplate
    scriptContent = strings.ReplaceAll(scriptContent, "{{UNHIDE_CODE}}", unhideCode)
    scriptContent = strings.ReplaceAll(scriptContent, "{{FORMAT_FILE_PATH}}", fmt.Sprintf("%q", formatFilePath))
    scriptContent = strings.ReplaceAll(scriptContent, "{{RENDER_SETTINGS_FILE}}", fmt.Sprintf("%q", renderSettingsFilePath))

    scriptPath := filepath.Join(ctx.WorkDir, "enable_gpu.py")
    if err := os.WriteFile(scriptPath, []byte(scriptContent), 0644); err != nil {
        errMsg := fmt.Sprintf("failed to create GPU enable script: %v", err)
        ctx.Error(errMsg)
        return errors.New(errMsg)
    }

    // Write the output format
    outputFormat := ctx.GetOutputFormat()
    ctx.Info(fmt.Sprintf("Writing output format '%s' to format file", outputFormat))
    if err := os.WriteFile(formatFilePath, []byte(outputFormat), 0644); err != nil {
        errMsg := fmt.Sprintf("failed to create format file: %v", err)
        ctx.Error(errMsg)
        return errors.New(errMsg)
    }

    // Write render settings if available
    if ctx.Metadata != nil && ctx.Metadata.RenderSettings.EngineSettings != nil {
        settingsJSON, err := json.Marshal(ctx.Metadata.RenderSettings)
        if err == nil {
            if err := os.WriteFile(renderSettingsFilePath, settingsJSON, 0644); err != nil {
                ctx.Warn(fmt.Sprintf("Failed to write render settings file: %v", err))
            }
        }
    }

    return nil
}
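
For concreteness: %q renders a double-quoted, backslash-escaped Go string literal, and for ordinary paths that text is also a valid Python string literal, which is why it can be spliced into the generated script. A tiny sketch (hypothetical helper):

// demoQuotedPath shows what %q produces for a path placeholder.
func demoQuotedPath() {
    fmt.Println(fmt.Sprintf("%q", "/job/42/output_format.txt"))
    // Output: "/job/42/output_format.txt" (with the surrounding quotes)
}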

func (p *RenderProcessor) runBlender(ctx *Context, blenderBinary, blendFile, outputDir, renderFormat, blenderHome string) error {
    scriptPath := filepath.Join(ctx.WorkDir, "enable_gpu.py")

    args := []string{"-b", blendFile, "--python", scriptPath}
    if ctx.ShouldEnableExecution() {
        args = append(args, "--enable-autoexec")
    }

    // Output pattern
    outputPattern := filepath.Join(outputDir, fmt.Sprintf("frame_####.%s", strings.ToLower(renderFormat)))
    outputAbsPattern, _ := filepath.Abs(outputPattern)
    args = append(args, "-o", outputAbsPattern)

    args = append(args, "-f", fmt.Sprintf("%d", ctx.Frame))

    // Wrap with xvfb-run for a virtual display
    xvfbArgs := []string{"-a", "-s", "-screen 0 800x600x24", blenderBinary}
    xvfbArgs = append(xvfbArgs, args...)
    cmd := exec.Command("xvfb-run", xvfbArgs...)
    cmd.Dir = ctx.WorkDir

    // Set up the environment with a custom HOME directory: drop any existing
    // HOME and substitute the per-job one so Blender's config stays inside
    // the workspace.
    env := os.Environ()
    newEnv := make([]string, 0, len(env)+1)
    for _, e := range env {
        if !strings.HasPrefix(e, "HOME=") {
            newEnv = append(newEnv, e)
        }
    }
    newEnv = append(newEnv, fmt.Sprintf("HOME=%s", blenderHome))
    cmd.Env = newEnv

    // Set up pipes
    stdoutPipe, err := cmd.StdoutPipe()
    if err != nil {
        return fmt.Errorf("failed to create stdout pipe: %w", err)
    }

    stderrPipe, err := cmd.StderrPipe()
    if err != nil {
        return fmt.Errorf("failed to create stderr pipe: %w", err)
    }

    if err := cmd.Start(); err != nil {
        return fmt.Errorf("failed to start blender: %w", err)
    }

    // Track process
    ctx.Processes.Track(ctx.TaskID, cmd)
    defer ctx.Processes.Untrack(ctx.TaskID)

    // Stream stdout
    stdoutDone := make(chan bool)
    go func() {
        defer close(stdoutDone)
        scanner := bufio.NewScanner(stdoutPipe)
        for scanner.Scan() {
            line := scanner.Text()
            if line != "" {
                shouldFilter, logLevel := blender.FilterLog(line)
                if !shouldFilter {
                    ctx.Log(logLevel, line)
                }
            }
        }
    }()

    // Stream stderr, promoting info-level lines to warnings
    stderrDone := make(chan bool)
    go func() {
        defer close(stderrDone)
        scanner := bufio.NewScanner(stderrPipe)
        for scanner.Scan() {
            line := scanner.Text()
            if line != "" {
                shouldFilter, logLevel := blender.FilterLog(line)
                if !shouldFilter {
                    if logLevel == types.LogLevelInfo {
                        logLevel = types.LogLevelWarn
                    }
                    ctx.Log(logLevel, line)
                }
            }
        }
    }()

    // Wait for completion
    err = cmd.Wait()
    <-stdoutDone
    <-stderrDone

    if err != nil {
        if exitErr, ok := err.(*exec.ExitError); ok {
            if exitErr.ExitCode() == 137 {
                return errors.New("Blender was killed due to excessive memory usage (OOM)")
            }
        }
        return fmt.Errorf("blender failed: %w", err)
    }

    return nil
}

func (p *RenderProcessor) findOutputFile(ctx *Context, outputDir, renderFormat string) (string, error) {
    entries, err := os.ReadDir(outputDir)
    if err != nil {
        return "", fmt.Errorf("failed to read output directory: %w", err)
    }

    ctx.Info("Checking output directory for files...")

    // Try exact match first
    expectedFile := filepath.Join(outputDir, fmt.Sprintf("frame_%04d.%s", ctx.Frame, strings.ToLower(renderFormat)))
    if _, err := os.Stat(expectedFile); err == nil {
        ctx.Info(fmt.Sprintf("Found output file: %s", filepath.Base(expectedFile)))
        return expectedFile, nil
|
||||
}
|
||||
|
||||
// Try without zero padding
|
||||
altFile := filepath.Join(outputDir, fmt.Sprintf("frame_%d.%s", ctx.Frame, strings.ToLower(renderFormat)))
|
||||
if _, err := os.Stat(altFile); err == nil {
|
||||
ctx.Info(fmt.Sprintf("Found output file: %s", filepath.Base(altFile)))
|
||||
return altFile, nil
|
||||
}
|
||||
|
||||
// Try just frame number
|
||||
altFile2 := filepath.Join(outputDir, fmt.Sprintf("%04d.%s", ctx.Frame, strings.ToLower(renderFormat)))
|
||||
if _, err := os.Stat(altFile2); err == nil {
|
||||
ctx.Info(fmt.Sprintf("Found output file: %s", filepath.Base(altFile2)))
|
||||
return altFile2, nil
|
||||
}
|
||||
|
||||
// Search through all files
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() {
|
||||
fileName := entry.Name()
|
||||
if strings.Contains(fileName, "%04d") || strings.Contains(fileName, "%d") {
|
||||
ctx.Warn(fmt.Sprintf("Skipping file with literal pattern: %s", fileName))
|
||||
continue
|
||||
}
|
||||
frameStr := fmt.Sprintf("%d", ctx.Frame)
|
||||
frameStrPadded := fmt.Sprintf("%04d", ctx.Frame)
|
||||
if strings.Contains(fileName, frameStrPadded) ||
|
||||
(strings.Contains(fileName, frameStr) && strings.HasSuffix(strings.ToLower(fileName), strings.ToLower(renderFormat))) {
|
||||
outputFile := filepath.Join(outputDir, fileName)
|
||||
ctx.Info(fmt.Sprintf("Found output file: %s", fileName))
|
||||
return outputFile, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Not found
|
||||
fileList := []string{}
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() {
|
||||
fileList = append(fileList, entry.Name())
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("output file not found: %s\nFiles in output directory: %v", expectedFile, fileList)
|
||||
}
146
internal/runner/workspace/archive.go
Normal file
@@ -0,0 +1,146 @@
package workspace

import (
	"archive/tar"
	"fmt"
	"io"
	"log"
	"os"
	"path/filepath"
	"strings"
)

// ExtractTar extracts a tar archive from a reader to a directory.
func ExtractTar(reader io.Reader, destDir string) error {
	if err := os.MkdirAll(destDir, 0755); err != nil {
		return fmt.Errorf("failed to create destination directory: %w", err)
	}

	tarReader := tar.NewReader(reader)

	for {
		header, err := tarReader.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			return fmt.Errorf("failed to read tar header: %w", err)
		}

		// Sanitize path to prevent directory traversal
		targetPath := filepath.Join(destDir, header.Name)
		if !strings.HasPrefix(filepath.Clean(targetPath), filepath.Clean(destDir)+string(os.PathSeparator)) {
			return fmt.Errorf("invalid file path in tar: %s", header.Name)
		}

		switch header.Typeflag {
		case tar.TypeDir:
			if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil {
				return fmt.Errorf("failed to create directory: %w", err)
			}

		case tar.TypeReg:
			if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
				return fmt.Errorf("failed to create parent directory: %w", err)
			}

			outFile, err := os.Create(targetPath)
			if err != nil {
				return fmt.Errorf("failed to create file: %w", err)
			}

			if _, err := io.Copy(outFile, tarReader); err != nil {
				outFile.Close()
				return fmt.Errorf("failed to write file: %w", err)
			}
			outFile.Close()

			if err := os.Chmod(targetPath, os.FileMode(header.Mode)); err != nil {
				log.Printf("Warning: failed to set file permissions: %v", err)
			}
		}
	}

	return nil
}
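
// Usage sketch (editor's illustration, not part of the package): extracting a
// job context archive received from the manager into a job directory. The
// file name and destination path below are assumptions.
//
//	f, err := os.Open("context.tar")
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer f.Close()
//	if err := workspace.ExtractTar(f, "/tmp/jiggablend-workspaces/runner/job-42"); err != nil {
//		log.Fatal(err)
//	}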

// ExtractTarStripPrefix extracts a tar archive, stripping the top-level directory.
// Useful for Blender archives like "blender-4.2.3-linux-x64/".
func ExtractTarStripPrefix(reader io.Reader, destDir string) error {
	if err := os.MkdirAll(destDir, 0755); err != nil {
		return err
	}

	tarReader := tar.NewReader(reader)
	stripPrefix := ""

	for {
		header, err := tarReader.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			return err
		}

		// Determine strip prefix from first entry (e.g., "blender-4.2.3-linux-x64/")
		if stripPrefix == "" {
			parts := strings.SplitN(header.Name, "/", 2)
			if len(parts) > 0 {
				stripPrefix = parts[0] + "/"
			}
		}

		// Strip the top-level directory
		name := strings.TrimPrefix(header.Name, stripPrefix)
		if name == "" {
			continue
		}

		targetPath := filepath.Join(destDir, name)

		switch header.Typeflag {
		case tar.TypeDir:
			if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil {
				return err
			}

		case tar.TypeReg:
			if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
				return err
			}
			outFile, err := os.OpenFile(targetPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode))
			if err != nil {
				return err
			}
			if _, err := io.Copy(outFile, tarReader); err != nil {
				outFile.Close()
				return err
			}
			outFile.Close()

		case tar.TypeSymlink:
			if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
				return err
			}
			os.Remove(targetPath) // Remove existing symlink if present
			if err := os.Symlink(header.Linkname, targetPath); err != nil {
				return err
			}
		}
	}

	return nil
}
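
// Editor's note (sketch, not in the original): regular entries above are
// checked against directory traversal by ExtractTar's callers, but symlink
// targets here are written unchecked, so a crafted archive could point a
// link outside destDir. A minimal guard, assuming only relative links are
// legitimate:
//
//	linkDest := filepath.Clean(filepath.Join(filepath.Dir(name), header.Linkname))
//	if filepath.IsAbs(header.Linkname) || strings.HasPrefix(linkDest, "..") {
//		return fmt.Errorf("symlink escapes destination: %s -> %s", name, header.Linkname)
//	}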

// ExtractTarFile extracts a tar file to a directory.
func ExtractTarFile(tarPath, destDir string) error {
	file, err := os.Open(tarPath)
	if err != nil {
		return fmt.Errorf("failed to open tar file: %w", err)
	}
	defer file.Close()

	return ExtractTar(file, destDir)
}
217
internal/runner/workspace/workspace.go
Normal file
@@ -0,0 +1,217 @@
// Package workspace manages runner workspace directories.
package workspace

import (
	"fmt"
	"log"
	"os"
	"path/filepath"
	"strings"
)

// Manager handles workspace directory operations.
type Manager struct {
	baseDir    string
	runnerName string
}

// NewManager creates a new workspace manager.
func NewManager(runnerName string) *Manager {
	m := &Manager{
		runnerName: sanitizeName(runnerName),
	}
	m.init()
	return m
}

func sanitizeName(name string) string {
	name = strings.ReplaceAll(name, " ", "_")
	name = strings.ReplaceAll(name, "/", "_")
	name = strings.ReplaceAll(name, "\\", "_")
	name = strings.ReplaceAll(name, ":", "_")
	return name
}

func (m *Manager) init() {
	// Prefer current directory if writable, otherwise use temp
	baseDir := os.TempDir()
	if cwd, err := os.Getwd(); err == nil {
		baseDir = cwd
	}

	m.baseDir = filepath.Join(baseDir, "jiggablend-workspaces", m.runnerName)
	if err := os.MkdirAll(m.baseDir, 0755); err != nil {
		log.Printf("Warning: Failed to create workspace directory %s: %v", m.baseDir, err)
		// Fallback to temp directory
		m.baseDir = filepath.Join(os.TempDir(), "jiggablend-workspaces", m.runnerName)
		if err := os.MkdirAll(m.baseDir, 0755); err != nil {
			log.Printf("Error: Failed to create fallback workspace directory: %v", err)
			// Last resort
			m.baseDir = filepath.Join(os.TempDir(), "jiggablend-runner")
			os.MkdirAll(m.baseDir, 0755)
		}
	}
	log.Printf("Runner workspace initialized at: %s", m.baseDir)
}

// BaseDir returns the base workspace directory.
func (m *Manager) BaseDir() string {
	return m.baseDir
}

// JobDir returns the directory for a specific job.
func (m *Manager) JobDir(jobID int64) string {
	return filepath.Join(m.baseDir, fmt.Sprintf("job-%d", jobID))
}

// VideoDir returns the directory for encoding.
func (m *Manager) VideoDir(jobID int64) string {
	return filepath.Join(m.baseDir, fmt.Sprintf("job-%d-video", jobID))
}

// BlenderDir returns the directory for Blender installations.
func (m *Manager) BlenderDir() string {
	return filepath.Join(m.baseDir, "blender-versions")
}

// CreateJobDir creates and returns the job directory.
func (m *Manager) CreateJobDir(jobID int64) (string, error) {
	dir := m.JobDir(jobID)
	if err := os.MkdirAll(dir, 0755); err != nil {
		return "", fmt.Errorf("failed to create job directory: %w", err)
	}
	return dir, nil
}

// CreateVideoDir creates and returns the encode directory.
func (m *Manager) CreateVideoDir(jobID int64) (string, error) {
	dir := m.VideoDir(jobID)
	if err := os.MkdirAll(dir, 0755); err != nil {
		return "", fmt.Errorf("failed to create video directory: %w", err)
	}
	return dir, nil
}
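
// Usage sketch (editor's illustration; the runner name and job ID are made
// up): each runner gets a sanitized per-runner workspace, and every task
// works inside its own job directory.
//
//	m := workspace.NewManager("gpu-node 1") // sanitized to "gpu-node_1"
//	jobDir, err := m.CreateJobDir(42)       // <base>/jiggablend-workspaces/gpu-node_1/job-42
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer m.CleanupJobDir(42)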

// CleanupJobDir removes a job directory.
func (m *Manager) CleanupJobDir(jobID int64) error {
	dir := m.JobDir(jobID)
	return os.RemoveAll(dir)
}

// CleanupVideoDir removes an encode directory.
func (m *Manager) CleanupVideoDir(jobID int64) error {
	dir := m.VideoDir(jobID)
	return os.RemoveAll(dir)
}

// Cleanup removes the entire workspace directory.
func (m *Manager) Cleanup() {
	if m.baseDir != "" {
		log.Printf("Cleaning up workspace directory: %s", m.baseDir)
		if err := os.RemoveAll(m.baseDir); err != nil {
			log.Printf("Warning: Failed to remove workspace directory %s: %v", m.baseDir, err)
		} else {
			log.Printf("Successfully removed workspace directory: %s", m.baseDir)
		}
	}

	// Also clean up any orphaned jiggablend directories
	cleanupOrphanedWorkspaces()
}

// cleanupOrphanedWorkspaces removes any jiggablend workspace directories
// that might be left behind from previous runs or crashes.
func cleanupOrphanedWorkspaces() {
	log.Printf("Cleaning up orphaned jiggablend workspace directories...")

	dirsToCheck := []string{".", os.TempDir()}
	for _, baseDir := range dirsToCheck {
		workspaceDir := filepath.Join(baseDir, "jiggablend-workspaces")
		if _, err := os.Stat(workspaceDir); err == nil {
			log.Printf("Removing orphaned workspace directory: %s", workspaceDir)
			if err := os.RemoveAll(workspaceDir); err != nil {
				log.Printf("Warning: Failed to remove workspace directory %s: %v", workspaceDir, err)
			} else {
				log.Printf("Successfully removed workspace directory: %s", workspaceDir)
			}
		}
	}
}

// FindBlendFiles finds all .blend files in a directory.
func FindBlendFiles(dir string) ([]string, error) {
	var blendFiles []string

	err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".blend") {
			// Check it's not a Blender save file (.blend1, .blend2, etc.)
			lower := strings.ToLower(info.Name())
			idx := strings.LastIndex(lower, ".blend")
			if idx != -1 {
				suffix := lower[idx+len(".blend"):]
				isSaveFile := false
				if len(suffix) > 0 {
					isSaveFile = true
					for _, r := range suffix {
						if r < '0' || r > '9' {
							isSaveFile = false
							break
						}
					}
				}
				if !isSaveFile {
					relPath, _ := filepath.Rel(dir, path)
					blendFiles = append(blendFiles, relPath)
				}
			}
		}
		return nil
	})

	return blendFiles, err
}

// FindFirstBlendFile finds the first .blend file in a directory.
func FindFirstBlendFile(dir string) (string, error) {
	var blendFile string

	err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".blend") {
			lower := strings.ToLower(info.Name())
			idx := strings.LastIndex(lower, ".blend")
			if idx != -1 {
				suffix := lower[idx+len(".blend"):]
				isSaveFile := false
				if len(suffix) > 0 {
					isSaveFile = true
					for _, r := range suffix {
						if r < '0' || r > '9' {
							isSaveFile = false
							break
						}
					}
				}
				if !isSaveFile {
					blendFile = path
					return filepath.SkipAll
				}
			}
		}
		return nil
	})

	if err != nil {
		return "", err
	}
	if blendFile == "" {
		return "", fmt.Errorf("no .blend file found in %s", dir)
	}
	return blendFile, nil
}
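
// Editor's note (sketch, not in the original): because both walkers are
// guarded by strings.HasSuffix(lower, ".blend"), the digit-suffix scan above
// can never mark a file as a save file (the suffix after the last ".blend"
// is always empty), so ".blend1"/".blend2" files are already excluded by the
// guard itself. The duplicated check could collapse into one shared helper,
// mirroring the storage package's isBlenderSaveFile:
//
//	// isBlendSaveFile reports whether name looks like "x.blend1", "x.blend27", etc.
//	func isBlendSaveFile(name string) bool {
//		lower := strings.ToLower(name)
//		idx := strings.LastIndex(lower, ".blend")
//		if idx == -1 {
//			return false
//		}
//		suffix := lower[idx+len(".blend"):]
//		if suffix == "" {
//			return false
//		}
//		for _, r := range suffix {
//			if r < '0' || r > '9' {
//				return false
//			}
//		}
//		return true
//	}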
@@ -1,9 +1,11 @@
package storage

import (
	"archive/tar"
	"archive/zip"
	"fmt"
	"io"
	"log"
	"os"
	"path/filepath"
	"strings"
@@ -29,6 +31,7 @@ func (s *Storage) init() error {
		s.basePath,
		s.uploadsPath(),
		s.outputsPath(),
		s.tempPath(),
	}

	for _, dir := range dirs {
@@ -40,6 +43,28 @@ func (s *Storage) init() error {
	return nil
}

// tempPath returns the path for temporary files
func (s *Storage) tempPath() string {
	return filepath.Join(s.basePath, "temp")
}

// BasePath returns the storage base path (for cleanup tasks)
func (s *Storage) BasePath() string {
	return s.basePath
}

// TempDir creates a temporary directory under the storage base path
// Returns the path to the temporary directory
func (s *Storage) TempDir(pattern string) (string, error) {
	// Ensure temp directory exists
	if err := os.MkdirAll(s.tempPath(), 0755); err != nil {
		return "", fmt.Errorf("failed to create temp directory: %w", err)
	}

	// Create temp directory under storage base path
	return os.MkdirTemp(s.tempPath(), pattern)
}

// uploadsPath returns the path for uploads
func (s *Storage) uploadsPath() string {
	return filepath.Join(s.basePath, "uploads")
@@ -140,6 +165,13 @@ func (s *Storage) GetFileSize(filePath string) (int64, error) {
// ExtractZip extracts a ZIP file to the destination directory
// Returns a list of all extracted file paths
func (s *Storage) ExtractZip(zipPath, destDir string) ([]string, error) {
	log.Printf("Extracting ZIP archive: %s -> %s", zipPath, destDir)

	// Ensure destination directory exists
	if err := os.MkdirAll(destDir, 0755); err != nil {
		return nil, fmt.Errorf("failed to create destination directory: %w", err)
	}

	r, err := zip.OpenReader(zipPath)
	if err != nil {
		return nil, fmt.Errorf("failed to open ZIP file: %w", err)
@@ -147,12 +179,20 @@ func (s *Storage) ExtractZip(zipPath, destDir string) ([]string, error) {
	defer r.Close()

	var extractedFiles []string
	fileCount := 0
	dirCount := 0

	log.Printf("ZIP contains %d entries", len(r.File))

	for _, f := range r.File {
		// Sanitize file path to prevent directory traversal
		destPath := filepath.Join(destDir, f.Name)
		cleanDestPath := filepath.Clean(destPath)
		cleanDestDir := filepath.Clean(destDir)
		if !strings.HasPrefix(cleanDestPath, cleanDestDir+string(os.PathSeparator)) && cleanDestPath != cleanDestDir {
			log.Printf("ERROR: Invalid file path in ZIP - target: %s, destDir: %s", cleanDestPath, cleanDestDir)
			return nil, fmt.Errorf("invalid file path in ZIP: %s (target: %s, destDir: %s)", f.Name, cleanDestPath, cleanDestDir)
		}

		// Create directory structure
@@ -160,6 +200,7 @@
		if err := os.MkdirAll(destPath, 0755); err != nil {
			return nil, fmt.Errorf("failed to create directory: %w", err)
		}
		dirCount++
		continue
	}

@@ -189,8 +230,381 @@
		}

		extractedFiles = append(extractedFiles, destPath)
		fileCount++
	}

	log.Printf("ZIP extraction complete: %d files, %d directories extracted to %s", fileCount, dirCount, destDir)
	return extractedFiles, nil
}

// findCommonPrefix finds the common leading directory prefix if all paths share the same first-level directory
// Returns the prefix to strip (with trailing slash) or empty string if no common prefix
func findCommonPrefix(relPaths []string) string {
	if len(relPaths) == 0 {
		return ""
	}

	// Get the first path component of each path
	firstComponents := make([]string, 0, len(relPaths))
	for _, path := range relPaths {
		parts := strings.Split(filepath.ToSlash(path), "/")
		if len(parts) >= 2 && parts[0] != "" {
			firstComponents = append(firstComponents, parts[0])
		} else {
			// A path without a directory component sits at root level,
			// so there is no common directory prefix to strip
			return ""
		}
	}

	// Check if all first components are the same
	if len(firstComponents) == 0 {
		return ""
	}

	commonFirst := firstComponents[0]
	for _, comp := range firstComponents {
		if comp != commonFirst {
			// Not all paths share the same first directory
			return ""
		}
	}

	// All paths share the same first directory - return it with trailing slash
	return commonFirst + "/"
}
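
// Behavior sketch (editor's illustration with made-up paths):
//
//	findCommonPrefix([]string{"proj/scene.blend", "proj/tex/wood.png"}) == "proj/"
//	findCommonPrefix([]string{"proj/scene.blend", "other/tex.png"})     == ""
//	findCommonPrefix([]string{"scene.blend"})                           == "" (root-level file)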

// isBlenderSaveFile checks if a filename is a Blender save file (.blend1, .blend2, etc.)
// Returns true for files like "file.blend1", "file.blend2", but false for "file.blend"
func isBlenderSaveFile(filename string) bool {
	lower := strings.ToLower(filename)
	// Check if it ends with .blend followed by one or more digits
	// Pattern: *.blend[digits]
	if !strings.HasSuffix(lower, ".blend") {
		// Doesn't end with .blend, check if it ends with .blend + digits
		idx := strings.LastIndex(lower, ".blend")
		if idx == -1 {
			return false
		}
		// Check if there are digits after .blend
		suffix := lower[idx+len(".blend"):]
		if len(suffix) == 0 {
			return false
		}
		// All remaining characters must be digits
		for _, r := range suffix {
			if r < '0' || r > '9' {
				return false
			}
		}
		return true
	}
	// Ends with .blend exactly - this is a regular blend file, not a save file
	return false
}

// CreateJobContext creates a tar archive containing all job input files
// Filters out Blender save files (.blend1, .blend2, etc.)
// Streams file contents into the archive to handle large files efficiently
func (s *Storage) CreateJobContext(jobID int64) (string, error) {
	jobPath := s.JobPath(jobID)
	contextPath := filepath.Join(jobPath, "context.tar")

	// Collect all files from job directory, excluding the context file itself and Blender save files
	var filesToInclude []string
	err := filepath.Walk(jobPath, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}

		// Skip directories
		if info.IsDir() {
			return nil
		}

		// Skip the context file itself if it exists
		if path == contextPath {
			return nil
		}

		// Skip Blender save files
		if isBlenderSaveFile(info.Name()) {
			return nil
		}

		// Get relative path from job directory
		relPath, err := filepath.Rel(jobPath, path)
		if err != nil {
			return err
		}

		// Sanitize path - ensure it doesn't escape the job directory
		cleanRelPath := filepath.Clean(relPath)
		if strings.HasPrefix(cleanRelPath, "..") {
			return fmt.Errorf("invalid file path: %s", relPath)
		}

		filesToInclude = append(filesToInclude, path)
		return nil
	})
	if err != nil {
		return "", fmt.Errorf("failed to walk job directory: %w", err)
	}

	if len(filesToInclude) == 0 {
		return "", fmt.Errorf("no files found to include in context")
	}

	// Create the tar file using streaming
	contextFile, err := os.Create(contextPath)
	if err != nil {
		return "", fmt.Errorf("failed to create context file: %w", err)
	}
	defer contextFile.Close()

	tarWriter := tar.NewWriter(contextFile)
	defer tarWriter.Close()

	// Add each file to the tar archive
	for _, filePath := range filesToInclude {
		file, err := os.Open(filePath)
		if err != nil {
			return "", fmt.Errorf("failed to open file %s: %w", filePath, err)
		}

		// Use a function closure to ensure file is closed even on error
		err = func() error {
			defer file.Close()

			info, err := file.Stat()
			if err != nil {
				return fmt.Errorf("failed to stat file %s: %w", filePath, err)
			}

			// Get relative path for tar header
			relPath, err := filepath.Rel(jobPath, filePath)
			if err != nil {
				return fmt.Errorf("failed to get relative path for %s: %w", filePath, err)
			}

			// Normalize path separators for tar (use forward slashes)
			tarPath := filepath.ToSlash(relPath)

			// Create tar header
			header, err := tar.FileInfoHeader(info, "")
			if err != nil {
				return fmt.Errorf("failed to create tar header for %s: %w", filePath, err)
			}
			header.Name = tarPath

			// Write header
			if err := tarWriter.WriteHeader(header); err != nil {
				return fmt.Errorf("failed to write tar header for %s: %w", filePath, err)
			}

			// Copy file contents using streaming
			if _, err := io.Copy(tarWriter, file); err != nil {
				return fmt.Errorf("failed to write file %s to tar: %w", filePath, err)
			}

			return nil
		}()

		if err != nil {
			return "", err
		}
	}

	// Ensure all data is flushed
	if err := tarWriter.Close(); err != nil {
		return "", fmt.Errorf("failed to close tar writer: %w", err)
	}
	if err := contextFile.Close(); err != nil {
		return "", fmt.Errorf("failed to close context file: %w", err)
	}

	return contextPath, nil
}
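
// Round-trip sketch (editor's illustration; the variable names, job ID, and
// runner-side path are made up): the manager builds context.tar once per
// job, and a runner streams it back through workspace.ExtractTar into its
// own job directory.
//
//	ctxPath, err := store.CreateJobContext(42)
//	if err != nil {
//		log.Fatal(err)
//	}
//	f, err := os.Open(ctxPath)
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer f.Close()
//	_ = workspace.ExtractTar(f, runnerJobDir)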

// CreateJobContextFromDir creates a context archive (tar) from files in a source directory
// This is used during upload to immediately create the context archive as the primary artifact
// excludeFiles is a set of relative paths (from sourceDir) to exclude from the context
func (s *Storage) CreateJobContextFromDir(sourceDir string, jobID int64, excludeFiles ...string) (string, error) {
	jobPath := s.JobPath(jobID)
	contextPath := filepath.Join(jobPath, "context.tar")

	// Ensure job directory exists
	if err := os.MkdirAll(jobPath, 0755); err != nil {
		return "", fmt.Errorf("failed to create job directory: %w", err)
	}

	// Build set of files to exclude (normalize paths)
	excludeSet := make(map[string]bool)
	for _, excludeFile := range excludeFiles {
		// Normalize the exclude path
		excludePath := filepath.Clean(excludeFile)
		excludeSet[excludePath] = true
		// Also add with forward slash for cross-platform compatibility
		excludeSet[filepath.ToSlash(excludePath)] = true
	}

	// Collect all files from source directory, excluding Blender save files and excluded files
	var filesToInclude []string
	err := filepath.Walk(sourceDir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}

		// Skip directories
		if info.IsDir() {
			return nil
		}

		// Skip Blender save files
		if isBlenderSaveFile(info.Name()) {
			return nil
		}

		// Get relative path from source directory
		relPath, err := filepath.Rel(sourceDir, path)
		if err != nil {
			return err
		}

		// Sanitize path - ensure it doesn't escape the source directory
		cleanRelPath := filepath.Clean(relPath)
		if strings.HasPrefix(cleanRelPath, "..") {
			return fmt.Errorf("invalid file path: %s", relPath)
		}

		// Check if this file should be excluded
		if excludeSet[cleanRelPath] || excludeSet[filepath.ToSlash(cleanRelPath)] {
			return nil
		}

		filesToInclude = append(filesToInclude, path)
		return nil
	})
	if err != nil {
		return "", fmt.Errorf("failed to walk source directory: %w", err)
	}

	if len(filesToInclude) == 0 {
		return "", fmt.Errorf("no files found to include in context archive")
	}

	// Collect relative paths to find common prefix
	relPaths := make([]string, 0, len(filesToInclude))
	for _, filePath := range filesToInclude {
		relPath, err := filepath.Rel(sourceDir, filePath)
		if err != nil {
			return "", fmt.Errorf("failed to get relative path for %s: %w", filePath, err)
		}
		relPaths = append(relPaths, relPath)
	}

	// Find and strip common leading directory if all files share one
	commonPrefix := findCommonPrefix(relPaths)

	// Validate that there's exactly one .blend file at the root level after prefix stripping
	blendFilesAtRoot := 0
	for _, relPath := range relPaths {
		tarPath := filepath.ToSlash(relPath)
		// Strip common prefix if present
		if commonPrefix != "" && strings.HasPrefix(tarPath, commonPrefix) {
			tarPath = strings.TrimPrefix(tarPath, commonPrefix)
		}

		// Check if it's a .blend file at root (no path separators after prefix stripping)
		if strings.HasSuffix(strings.ToLower(tarPath), ".blend") {
			// Check if it's at root level (no directory separators)
			if !strings.Contains(tarPath, "/") {
				blendFilesAtRoot++
			}
		}
	}

	if blendFilesAtRoot == 0 {
		return "", fmt.Errorf("no .blend file found at root level in context archive - .blend files must be at the root level of the uploaded archive, not in subdirectories")
	}
	if blendFilesAtRoot > 1 {
		return "", fmt.Errorf("multiple .blend files found at root level in context archive (found %d, expected 1)", blendFilesAtRoot)
	}

	// Create the tar file using streaming
	contextFile, err := os.Create(contextPath)
	if err != nil {
		return "", fmt.Errorf("failed to create context file: %w", err)
	}
	defer contextFile.Close()

	tarWriter := tar.NewWriter(contextFile)
	defer tarWriter.Close()

	// Add each file to the tar archive
	for i, filePath := range filesToInclude {
		file, err := os.Open(filePath)
		if err != nil {
			return "", fmt.Errorf("failed to open file %s: %w", filePath, err)
		}

		// Use a function closure to ensure file is closed even on error
		err = func() error {
			defer file.Close()

			info, err := file.Stat()
			if err != nil {
				return fmt.Errorf("failed to stat file %s: %w", filePath, err)
			}

			// Get relative path and strip common prefix if present
			relPath := relPaths[i]
			tarPath := filepath.ToSlash(relPath)

			// Strip common prefix if found
			if commonPrefix != "" && strings.HasPrefix(tarPath, commonPrefix) {
				tarPath = strings.TrimPrefix(tarPath, commonPrefix)
			}

			// Create tar header
			header, err := tar.FileInfoHeader(info, "")
			if err != nil {
				return fmt.Errorf("failed to create tar header for %s: %w", filePath, err)
			}
			header.Name = tarPath

			// Write header
			if err := tarWriter.WriteHeader(header); err != nil {
				return fmt.Errorf("failed to write tar header for %s: %w", filePath, err)
			}

			// Copy file contents using streaming
			if _, err := io.Copy(tarWriter, file); err != nil {
				return fmt.Errorf("failed to write file %s to tar: %w", filePath, err)
			}

			return nil
		}()

		if err != nil {
			return "", err
		}
	}

	// Ensure all data is flushed
	if err := tarWriter.Close(); err != nil {
		return "", fmt.Errorf("failed to close tar writer: %w", err)
	}
	if err := contextFile.Close(); err != nil {
		return "", fmt.Errorf("failed to close context file: %w", err)
	}

	return contextPath, nil
}
BIN
jiggablend
Executable file
Binary file not shown.
366
pkg/executils/exec.go
Normal file
@@ -0,0 +1,366 @@
package executils

import (
	"bufio"
	"errors"
	"fmt"
	"io"
	"os"
	"os/exec"
	"sync"
	"time"

	"jiggablend/pkg/types"
)

// DefaultTracker is the global default process tracker
// Use this for processes that should be tracked globally and killed on shutdown
var DefaultTracker = NewProcessTracker()

// ProcessTracker tracks running processes for cleanup
type ProcessTracker struct {
	processes sync.Map // map[int64]*exec.Cmd - tracks running processes by task ID
}

// NewProcessTracker creates a new process tracker
func NewProcessTracker() *ProcessTracker {
	return &ProcessTracker{}
}

// Track registers a process for tracking
func (pt *ProcessTracker) Track(taskID int64, cmd *exec.Cmd) {
	pt.processes.Store(taskID, cmd)
}

// Untrack removes a process from tracking
func (pt *ProcessTracker) Untrack(taskID int64) {
	pt.processes.Delete(taskID)
}

// Get returns the command for a task ID if it exists
func (pt *ProcessTracker) Get(taskID int64) (*exec.Cmd, bool) {
	if val, ok := pt.processes.Load(taskID); ok {
		return val.(*exec.Cmd), true
	}
	return nil, false
}

// Kill kills a specific process by task ID
// Returns true if the process was found and killed
func (pt *ProcessTracker) Kill(taskID int64) bool {
	cmd, ok := pt.Get(taskID)
	if !ok || cmd.Process == nil {
		return false
	}

	// Try graceful kill first (SIGINT)
	if err := cmd.Process.Signal(os.Interrupt); err != nil {
		// If SIGINT fails, try SIGKILL
		cmd.Process.Kill()
	} else {
		// Give it a moment to clean up gracefully
		time.Sleep(100 * time.Millisecond)
		// Force kill if still running
		cmd.Process.Kill()
	}

	pt.Untrack(taskID)
	return true
}

// KillAll kills all tracked processes
// Returns the number of processes killed
func (pt *ProcessTracker) KillAll() int {
	var killedCount int
	pt.processes.Range(func(key, value interface{}) bool {
		taskID := key.(int64)
		cmd := value.(*exec.Cmd)
		if cmd.Process != nil {
			// Try graceful kill first (SIGINT)
			if err := cmd.Process.Signal(os.Interrupt); err == nil {
				// Give it a moment to clean up
				time.Sleep(100 * time.Millisecond)
			}
			// Force kill
			cmd.Process.Kill()
			killedCount++
		}
		pt.processes.Delete(taskID)
		return true
	})
	return killedCount
}

// Count returns the number of tracked processes
func (pt *ProcessTracker) Count() int {
	count := 0
	pt.processes.Range(func(key, value interface{}) bool {
		count++
		return true
	})
	return count
}
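
// Usage sketch (editor's illustration; the command and task ID are made up):
//
//	cmd := exec.Command("sleep", "60")
//	if err := cmd.Start(); err != nil {
//		log.Fatal(err)
//	}
//	executils.Track(7, cmd)
//	defer executils.Untrack(7)
//	// ... later, e.g. on task cancellation or shutdown:
//	executils.Kill(7) // SIGINT first, then SIGKILL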

// CommandResult holds the output from a command execution
type CommandResult struct {
	Stdout   string
	Stderr   string
	ExitCode int
}

// RunCommand executes a command and returns the output
// If tracker is provided, the process will be registered for tracking
// This is useful for commands where you need to capture output (like metadata extraction)
func RunCommand(
	cmdPath string,
	args []string,
	dir string,
	env []string,
	taskID int64,
	tracker *ProcessTracker,
) (*CommandResult, error) {
	cmd := exec.Command(cmdPath, args...)
	cmd.Dir = dir
	if env != nil {
		cmd.Env = env
	}

	stdoutPipe, err := cmd.StdoutPipe()
	if err != nil {
		return nil, fmt.Errorf("failed to create stdout pipe: %w", err)
	}

	stderrPipe, err := cmd.StderrPipe()
	if err != nil {
		return nil, fmt.Errorf("failed to create stderr pipe: %w", err)
	}

	if err := cmd.Start(); err != nil {
		return nil, fmt.Errorf("failed to start command: %w", err)
	}

	// Track the process if tracker is provided
	if tracker != nil {
		tracker.Track(taskID, cmd)
		defer tracker.Untrack(taskID)
	}

	// Collect stdout and stderr concurrently
	var stdoutBuf, stderrBuf []byte
	var stdoutErr, stderrErr error
	var wg sync.WaitGroup
	wg.Add(2)

	go func() {
		defer wg.Done()
		stdoutBuf, stdoutErr = readAll(stdoutPipe)
	}()

	go func() {
		defer wg.Done()
		stderrBuf, stderrErr = readAll(stderrPipe)
	}()

	// Drain both pipes before calling Wait; Wait closes the pipes,
	// so calling it first could truncate output
	wg.Wait()
	waitErr := cmd.Wait()

	// Check for read errors
	if stdoutErr != nil {
		return nil, fmt.Errorf("failed to read stdout: %w", stdoutErr)
	}
	if stderrErr != nil {
		return nil, fmt.Errorf("failed to read stderr: %w", stderrErr)
	}

	result := &CommandResult{
		Stdout: string(stdoutBuf),
		Stderr: string(stderrBuf),
	}

	if waitErr != nil {
		if exitErr, ok := waitErr.(*exec.ExitError); ok {
			result.ExitCode = exitErr.ExitCode()
		} else {
			result.ExitCode = -1
		}
		return result, waitErr
	}

	result.ExitCode = 0
	return result, nil
}
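
// Usage sketch (editor's illustration; binary path, arguments, and task ID
// are made up; Run is the DefaultTracker convenience wrapper defined below):
//
//	res, err := executils.Run("/usr/bin/blender", []string{"--version"}, "", nil, 7)
//	if err == nil {
//		fmt.Print(res.Stdout)
//	}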

// readAll reads all data from a reader until EOF
func readAll(r io.Reader) ([]byte, error) {
	var buf []byte
	tmp := make([]byte, 4096)
	for {
		n, err := r.Read(tmp)
		if n > 0 {
			buf = append(buf, tmp[:n]...)
		}
		if err != nil {
			if errors.Is(err, io.EOF) {
				break
			}
			return buf, err
		}
	}
	return buf, nil
}

// LogSender is a function type for sending logs
type LogSender func(taskID int, level types.LogLevel, message string, stepName string)

// LineFilter is a function that processes a line and returns whether to filter it out and the log level
type LineFilter func(line string) (shouldFilter bool, level types.LogLevel)

// RunCommandWithStreaming executes a command with streaming output and OOM detection
// If tracker is provided, the process will be registered for tracking
func RunCommandWithStreaming(
	cmdPath string,
	args []string,
	dir string,
	env []string,
	taskID int,
	stepName string,
	logSender LogSender,
	stdoutFilter LineFilter,
	stderrFilter LineFilter,
	oomMessage string,
	tracker *ProcessTracker,
) error {
	cmd := exec.Command(cmdPath, args...)
	cmd.Dir = dir
	cmd.Env = env

	stdoutPipe, err := cmd.StdoutPipe()
	if err != nil {
		errMsg := fmt.Sprintf("failed to create stdout pipe: %v", err)
		logSender(taskID, types.LogLevelError, errMsg, stepName)
		return errors.New(errMsg)
	}

	stderrPipe, err := cmd.StderrPipe()
	if err != nil {
		errMsg := fmt.Sprintf("failed to create stderr pipe: %v", err)
		logSender(taskID, types.LogLevelError, errMsg, stepName)
		return errors.New(errMsg)
	}

	if err := cmd.Start(); err != nil {
		errMsg := fmt.Sprintf("failed to start command: %v", err)
		logSender(taskID, types.LogLevelError, errMsg, stepName)
		return errors.New(errMsg)
	}

	// Track the process if tracker is provided
	if tracker != nil {
		tracker.Track(int64(taskID), cmd)
		defer tracker.Untrack(int64(taskID))
	}

	var wg sync.WaitGroup
	wg.Add(2)

	go func() {
		defer wg.Done()
		scanner := bufio.NewScanner(stdoutPipe)
		for scanner.Scan() {
			line := scanner.Text()
			if line != "" {
				shouldFilter, level := stdoutFilter(line)
				if !shouldFilter {
					logSender(taskID, level, line, stepName)
				}
			}
		}
	}()

	go func() {
		defer wg.Done()
		scanner := bufio.NewScanner(stderrPipe)
		for scanner.Scan() {
			line := scanner.Text()
			if line != "" {
				shouldFilter, level := stderrFilter(line)
				if !shouldFilter {
					logSender(taskID, level, line, stepName)
				}
			}
		}
	}()

	// Drain both pipes before calling Wait; Wait closes the pipes,
	// so calling it first could truncate output
	wg.Wait()
	err = cmd.Wait()

	if err != nil {
		var errMsg string
		if exitErr, ok := err.(*exec.ExitError); ok {
			if exitErr.ExitCode() == 137 {
				errMsg = oomMessage
			} else {
				errMsg = fmt.Sprintf("command failed: %v", err)
			}
		} else {
			errMsg = fmt.Sprintf("command failed: %v", err)
		}
		logSender(taskID, types.LogLevelError, errMsg, stepName)
		return errors.New(errMsg)
	}

	return nil
}

// ============================================================================
// Helper functions using DefaultTracker
// ============================================================================

// Run executes a command using the default tracker and returns the output
// This is a convenience wrapper around RunCommand that uses DefaultTracker
func Run(cmdPath string, args []string, dir string, env []string, taskID int64) (*CommandResult, error) {
	return RunCommand(cmdPath, args, dir, env, taskID, DefaultTracker)
}

// RunStreaming executes a command with streaming output using the default tracker
// This is a convenience wrapper around RunCommandWithStreaming that uses DefaultTracker
func RunStreaming(
	cmdPath string,
	args []string,
	dir string,
	env []string,
	taskID int,
	stepName string,
	logSender LogSender,
	stdoutFilter LineFilter,
	stderrFilter LineFilter,
	oomMessage string,
) error {
	return RunCommandWithStreaming(cmdPath, args, dir, env, taskID, stepName, logSender, stdoutFilter, stderrFilter, oomMessage, DefaultTracker)
}

// KillAll kills all processes tracked by the default tracker
// Returns the number of processes killed
func KillAll() int {
	return DefaultTracker.KillAll()
}

// Kill kills a specific process by task ID using the default tracker
// Returns true if the process was found and killed
func Kill(taskID int64) bool {
	return DefaultTracker.Kill(taskID)
}

// Track registers a process with the default tracker
func Track(taskID int64, cmd *exec.Cmd) {
	DefaultTracker.Track(taskID, cmd)
}

// Untrack removes a process from the default tracker
func Untrack(taskID int64) {
	DefaultTracker.Untrack(taskID)
}

// GetTrackedCount returns the number of processes tracked by the default tracker
func GetTrackedCount() int {
	return DefaultTracker.Count()
}
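
// Usage sketch (editor's illustration; the log sender and filters are
// minimal stand-ins, and the ffmpeg invocation is made up):
//
//	logLine := func(taskID int, level types.LogLevel, msg, step string) {
//		log.Printf("[task %d][%s] %s", taskID, step, msg)
//	}
//	passAll := func(line string) (bool, types.LogLevel) { return false, types.LogLevelInfo }
//	err := executils.RunStreaming("ffmpeg", []string{"-i", "in.mp4", "out.webm"},
//		"", os.Environ(), 7, "encode", logLine, passAll, passAll,
//		"ffmpeg was killed due to excessive memory usage (OOM)")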
13
pkg/scripts/scripts.go
Normal file
@@ -0,0 +1,13 @@
package scripts

import _ "embed"

//go:embed scripts/extract_metadata.py
var ExtractMetadata string

//go:embed scripts/unhide_objects.py
var UnhideObjects string

//go:embed scripts/render_blender.py.template
var RenderBlenderTemplate string
370
pkg/scripts/scripts/extract_metadata.py
Normal file
@@ -0,0 +1,370 @@
import bpy
import json
import sys

# Make all file paths relative to the blend file location FIRST
# This must be done immediately after file load, before any other operations
# to prevent Blender from trying to access external files with absolute paths
try:
    bpy.ops.file.make_paths_relative()
    print("Made all file paths relative to blend file")
except Exception as e:
    print(f"Warning: Could not make paths relative: {e}")

# Check for missing addons that the blend file requires
# Blender marks missing addons with "_missing" suffix in preferences
missing_files_info = {
    "checked": False,
    "has_missing": False,
    "missing_files": [],
    "missing_addons": []
}

try:
    missing = []
    for mod in bpy.context.preferences.addons:
        if mod.module.endswith("_missing"):
            missing.append(mod.module.rsplit("_", 1)[0])

    missing_files_info["checked"] = True
    if missing:
        missing_files_info["has_missing"] = True
        missing_files_info["missing_addons"] = missing
        print("Missing add-ons required by this .blend:")
        for name in missing:
            print("  -", name)
    else:
        print("No missing add-ons detected - file is headless-safe")
except Exception as e:
    print(f"Warning: Could not check for missing addons: {e}")
    missing_files_info["error"] = str(e)

# Get scene
scene = bpy.context.scene

# Extract frame range from scene settings
frame_start = scene.frame_start
frame_end = scene.frame_end

# Check for negative frames (not supported)
has_negative_start = frame_start < 0
has_negative_end = frame_end < 0

# Also check for the actual animation range (keyframes):
# find the earliest and latest keyframes across all objects
animation_start = None
animation_end = None

for obj in scene.objects:
    if obj.animation_data and obj.animation_data.action:
        action = obj.animation_data.action
        # Check if action has fcurves attribute (varies by Blender version/context)
        try:
            fcurves = action.fcurves if hasattr(action, 'fcurves') else None
            if fcurves:
                for fcurve in fcurves:
                    if fcurve.keyframe_points:
                        for keyframe in fcurve.keyframe_points:
                            frame = int(keyframe.co[0])
                            if animation_start is None or frame < animation_start:
                                animation_start = frame
                            if animation_end is None or frame > animation_end:
                                animation_end = frame
        except (AttributeError, TypeError):
            # Action doesn't have fcurves or fcurves is not iterable - skip this object
            pass

# Use the animation range if available, otherwise keep the scene frame range.
# If the scene range looks wrong (start == end), prefer the animation range.
if animation_start is not None and animation_end is not None:
    if frame_start == frame_end or (animation_start < frame_start or animation_end > frame_end):
        # Use animation range if scene range is invalid or animation extends beyond it
        frame_start = animation_start
        frame_end = animation_end

# Re-check for negative frames after the range may have changed (not supported)
has_negative_start = frame_start < 0
has_negative_end = frame_end < 0
has_negative_animation = (animation_start is not None and animation_start < 0) or (animation_end is not None and animation_end < 0)

# Extract render settings
render = scene.render
resolution_x = render.resolution_x
resolution_y = render.resolution_y
frame_rate = render.fps / render.fps_base if render.fps_base != 0 else render.fps
engine = scene.render.engine.upper()

# Determine output format from file format
output_format = render.image_settings.file_format

# Extract engine-specific settings
engine_settings = {}

if engine == 'CYCLES':
    cycles = scene.cycles
    # Get denoiser settings - in Blender 3.0+ it's on the view layer
    denoiser = 'OPENIMAGEDENOISE'  # Default
    denoising_use_gpu = False
    denoising_input_passes = 'RGB_ALBEDO_NORMAL'  # Default: Albedo and Normal
    denoising_prefilter = 'ACCURATE'  # Default
    denoising_quality = 'HIGH'  # Default (for OpenImageDenoise)
    try:
        view_layer = bpy.context.view_layer
        if hasattr(view_layer, 'cycles'):
            vl_cycles = view_layer.cycles
            denoiser = getattr(vl_cycles, 'denoiser', 'OPENIMAGEDENOISE')
            denoising_use_gpu = getattr(vl_cycles, 'denoising_use_gpu', False)
            denoising_input_passes = getattr(vl_cycles, 'denoising_input_passes', 'RGB_ALBEDO_NORMAL')
            denoising_prefilter = getattr(vl_cycles, 'denoising_prefilter', 'ACCURATE')
            # Quality is only for OpenImageDenoise in Blender 4.0+
            denoising_quality = getattr(vl_cycles, 'denoising_quality', 'HIGH')
    except Exception:
        pass

    engine_settings = {
        # Sampling settings
        "samples": getattr(cycles, 'samples', 4096),  # Max Samples
        "adaptive_min_samples": getattr(cycles, 'adaptive_min_samples', 0),  # Min Samples
        "use_adaptive_sampling": getattr(cycles, 'use_adaptive_sampling', True),  # Noise Threshold enabled
        "adaptive_threshold": getattr(cycles, 'adaptive_threshold', 0.01),  # Noise Threshold value
        "time_limit": getattr(cycles, 'time_limit', 0.0),  # Time Limit (0 = disabled)

        # Denoising settings
        "use_denoising": getattr(cycles, 'use_denoising', False),
        "denoiser": denoiser,
        "denoising_use_gpu": denoising_use_gpu,
        "denoising_input_passes": denoising_input_passes,
        "denoising_prefilter": denoising_prefilter,
        "denoising_quality": denoising_quality,

        # Path Guiding settings
        "use_guiding": getattr(cycles, 'use_guiding', False),
        "guiding_training_samples": getattr(cycles, 'guiding_training_samples', 128),
        "use_surface_guiding": getattr(cycles, 'use_surface_guiding', True),
        "use_volume_guiding": getattr(cycles, 'use_volume_guiding', True),

        # Lights settings
        "use_light_tree": getattr(cycles, 'use_light_tree', True),
        "light_sampling_threshold": getattr(cycles, 'light_sampling_threshold', 0.01),

        # Device
        "device": getattr(cycles, 'device', 'CPU'),

        # Advanced/Seed settings
        "seed": getattr(cycles, 'seed', 0),
        "use_animated_seed": getattr(cycles, 'use_animated_seed', False),
        "sampling_pattern": getattr(cycles, 'sampling_pattern', 'AUTOMATIC'),
        "scrambling_distance": getattr(cycles, 'scrambling_distance', 1.0),
        "auto_scrambling_distance_multiplier": getattr(cycles, 'auto_scrambling_distance_multiplier', 1.0),
        "preview_scrambling_distance": getattr(cycles, 'preview_scrambling_distance', False),
        "min_light_bounces": getattr(cycles, 'min_light_bounces', 0),
        "min_transparent_bounces": getattr(cycles, 'min_transparent_bounces', 0),

        # Clamping
        "sample_clamp_direct": getattr(cycles, 'sample_clamp_direct', 0.0),
        "sample_clamp_indirect": getattr(cycles, 'sample_clamp_indirect', 0.0),

        # Light Paths / Bounces
        "max_bounces": getattr(cycles, 'max_bounces', 12),
        "diffuse_bounces": getattr(cycles, 'diffuse_bounces', 4),
        "glossy_bounces": getattr(cycles, 'glossy_bounces', 4),
        "transmission_bounces": getattr(cycles, 'transmission_bounces', 12),
        "volume_bounces": getattr(cycles, 'volume_bounces', 0),
        "transparent_max_bounces": getattr(cycles, 'transparent_max_bounces', 8),

        # Caustics
        "caustics_reflective": getattr(cycles, 'caustics_reflective', False),
        "caustics_refractive": getattr(cycles, 'caustics_refractive', False),
        "blur_glossy": getattr(cycles, 'blur_glossy', 0.0),  # Filter Glossy

        # Fast GI Approximation
        "use_fast_gi": getattr(cycles, 'use_fast_gi', False),
        "fast_gi_method": getattr(cycles, 'fast_gi_method', 'REPLACE'),  # REPLACE or ADD
        "ao_bounces": getattr(cycles, 'ao_bounces', 1),  # Viewport bounces
        "ao_bounces_render": getattr(cycles, 'ao_bounces_render', 1),  # Render bounces

        # Volumes
        "volume_step_rate": getattr(cycles, 'volume_step_rate', 1.0),
        "volume_preview_step_rate": getattr(cycles, 'volume_preview_step_rate', 1.0),
        "volume_max_steps": getattr(cycles, 'volume_max_steps', 1024),

        # Film
        "film_exposure": getattr(cycles, 'film_exposure', 1.0),
        "film_transparent": getattr(cycles, 'film_transparent', False),
        "film_transparent_glass": getattr(cycles, 'film_transparent_glass', False),
        "film_transparent_roughness": getattr(cycles, 'film_transparent_roughness', 0.1),
        "filter_type": getattr(cycles, 'filter_type', 'BLACKMAN_HARRIS'),  # BOX, GAUSSIAN, BLACKMAN_HARRIS
        "filter_width": getattr(cycles, 'filter_width', 1.5),
        "pixel_filter_type": getattr(cycles, 'pixel_filter_type', 'BLACKMAN_HARRIS'),

        # Performance
        "use_auto_tile": getattr(cycles, 'use_auto_tile', True),
        "tile_size": getattr(cycles, 'tile_size', 2048),
        "use_persistent_data": getattr(cycles, 'use_persistent_data', False),

        # Hair/Curves
        "use_hair": getattr(cycles, 'use_hair', True),
        "hair_subdivisions": getattr(cycles, 'hair_subdivisions', 2),
        "hair_shape": getattr(cycles, 'hair_shape', 'THICK'),  # ROUND, RIBBONS, THICK

        # Simplify (from scene.render)
        "use_simplify": getattr(scene.render, 'use_simplify', False),
        "simplify_subdivision_render": getattr(scene.render, 'simplify_subdivision_render', 6),
        "simplify_child_particles_render": getattr(scene.render, 'simplify_child_particles_render', 1.0),

        # Other
        "use_light_linking": getattr(cycles, 'use_light_linking', False),
        "use_layer_samples": getattr(cycles, 'use_layer_samples', False),
    }
elif engine == 'EEVEE' or engine == 'EEVEE_NEXT':
    # Treat EEVEE_NEXT as EEVEE (modern Blender uses EEVEE for what was EEVEE_NEXT)
    eevee = scene.eevee
    engine_settings = {
        # Sampling
        "taa_render_samples": getattr(eevee, 'taa_render_samples', 64),
        "taa_samples": getattr(eevee, 'taa_samples', 16),  # Viewport samples
        "use_taa_reprojection": getattr(eevee, 'use_taa_reprojection', True),

        # Clamping
        "clamp_surface_direct": getattr(eevee, 'clamp_surface_direct', 0.0),
        "clamp_surface_indirect": getattr(eevee, 'clamp_surface_indirect', 0.0),
        "clamp_volume_direct": getattr(eevee, 'clamp_volume_direct', 0.0),
        "clamp_volume_indirect": getattr(eevee, 'clamp_volume_indirect', 0.0),

        # Shadows
        "shadow_cube_size": getattr(eevee, 'shadow_cube_size', '512'),
        "shadow_cascade_size": getattr(eevee, 'shadow_cascade_size', '1024'),
        "use_shadow_high_bitdepth": getattr(eevee, 'use_shadow_high_bitdepth', False),
        "use_soft_shadows": getattr(eevee, 'use_soft_shadows', True),
        "light_threshold": getattr(eevee, 'light_threshold', 0.01),

        # Raytracing (EEVEE Next / modern EEVEE)
        "use_raytracing": getattr(eevee, 'use_raytracing', False),
        "ray_tracing_method": getattr(eevee, 'ray_tracing_method', 'SCREEN'),  # SCREEN or PROBE
        "ray_tracing_options_trace_max_roughness": getattr(
            getattr(eevee, 'ray_tracing_options', None), 'trace_max_roughness', 0.5),

        # Screen Space Reflections (legacy/fallback)
        "use_ssr": getattr(eevee, 'use_ssr', False),
        "use_ssr_refraction": getattr(eevee, 'use_ssr_refraction', False),
        "use_ssr_halfres": getattr(eevee, 'use_ssr_halfres', True),
        "ssr_quality": getattr(eevee, 'ssr_quality', 0.25),
        "ssr_max_roughness": getattr(eevee, 'ssr_max_roughness', 0.5),
        "ssr_thickness": getattr(eevee, 'ssr_thickness', 0.2),
        "ssr_border_fade": getattr(eevee, 'ssr_border_fade', 0.075),
        "ssr_firefly_fac": getattr(eevee, 'ssr_firefly_fac', 10.0),

        # Ambient Occlusion
        "use_gtao": getattr(eevee, 'use_gtao', False),
        "gtao_distance": getattr(eevee, 'gtao_distance', 0.2),
        "gtao_factor": getattr(eevee, 'gtao_factor', 1.0),
        "gtao_quality": getattr(eevee, 'gtao_quality', 0.25),
        "use_gtao_bent_normals": getattr(eevee, 'use_gtao_bent_normals', True),
        "use_gtao_bounce": getattr(eevee, 'use_gtao_bounce', True),

        # Bloom
        "use_bloom": getattr(eevee, 'use_bloom', False),
        "bloom_threshold": getattr(eevee, 'bloom_threshold', 0.8),
        "bloom_knee": getattr(eevee, 'bloom_knee', 0.5),
        "bloom_radius": getattr(eevee, 'bloom_radius', 6.5),
        "bloom_color": list(getattr(eevee, 'bloom_color', (1.0, 1.0, 1.0))),
        "bloom_intensity": getattr(eevee, 'bloom_intensity', 0.05),
        "bloom_clamp": getattr(eevee, 'bloom_clamp', 0.0),

        # Depth of Field
        "bokeh_max_size": getattr(eevee, 'bokeh_max_size', 100.0),
        "bokeh_threshold": getattr(eevee, 'bokeh_threshold', 1.0),
        "bokeh_neighbor_max": getattr(eevee, 'bokeh_neighbor_max', 10.0),
        "bokeh_denoise_fac": getattr(eevee, 'bokeh_denoise_fac', 0.75),
        "use_bokeh_high_quality_slight_defocus": getattr(eevee, 'use_bokeh_high_quality_slight_defocus', False),
        "use_bokeh_jittered": getattr(eevee, 'use_bokeh_jittered', False),
        "bokeh_overblur": getattr(eevee, 'bokeh_overblur', 5.0),

        # Subsurface Scattering
        "sss_samples": getattr(eevee, 'sss_samples', 7),
        "sss_jitter_threshold": getattr(eevee, 'sss_jitter_threshold', 0.3),

        # Volumetrics
        "use_volumetric_lights": getattr(eevee, 'use_volumetric_lights', True),
        "use_volumetric_shadows": getattr(eevee, 'use_volumetric_shadows', False),
        "volumetric_start": getattr(eevee, 'volumetric_start', 0.1),
        "volumetric_end": getattr(eevee, 'volumetric_end', 100.0),
        "volumetric_tile_size": getattr(eevee, 'volumetric_tile_size', '8'),
        "volumetric_samples": getattr(eevee, 'volumetric_samples', 64),
        "volumetric_sample_distribution": getattr(eevee, 'volumetric_sample_distribution', 0.8),
        "volumetric_ray_depth": getattr(eevee, 'volumetric_ray_depth', 16),

        # Motion Blur
|
||||
"use_motion_blur": getattr(eevee, 'use_motion_blur', False),
|
||||
"motion_blur_position": getattr(eevee, 'motion_blur_position', 'CENTER'),
|
||||
"motion_blur_shutter": getattr(eevee, 'motion_blur_shutter', 0.5),
|
||||
"motion_blur_depth_scale": getattr(eevee, 'motion_blur_depth_scale', 100.0),
|
||||
"motion_blur_max": getattr(eevee, 'motion_blur_max', 32),
|
||||
"motion_blur_steps": getattr(eevee, 'motion_blur_steps', 1),
|
||||
|
||||
# Film
|
||||
"use_overscan": getattr(eevee, 'use_overscan', False),
|
||||
"overscan_size": getattr(eevee, 'overscan_size', 3.0),
|
||||
|
||||
# Indirect Lighting
|
||||
"gi_diffuse_bounces": getattr(eevee, 'gi_diffuse_bounces', 3),
|
||||
"gi_cubemap_resolution": getattr(eevee, 'gi_cubemap_resolution', '512'),
|
||||
"gi_visibility_resolution": getattr(eevee, 'gi_visibility_resolution', '32'),
|
||||
"gi_irradiance_smoothing": getattr(eevee, 'gi_irradiance_smoothing', 0.1),
|
||||
"gi_glossy_clamp": getattr(eevee, 'gi_glossy_clamp', 0.0),
|
||||
"gi_filter_quality": getattr(eevee, 'gi_filter_quality', 3.0),
|
||||
"gi_show_irradiance": getattr(eevee, 'gi_show_irradiance', False),
|
||||
"gi_show_cubemaps": getattr(eevee, 'gi_show_cubemaps', False),
|
||||
"gi_auto_bake": getattr(eevee, 'gi_auto_bake', False),
|
||||
|
||||
# Hair/Curves
|
||||
"hair_type": getattr(eevee, 'hair_type', 'STRIP'), # STRIP or STRAND
|
||||
|
||||
# Performance
|
||||
"use_shadow_jitter_viewport": getattr(eevee, 'use_shadow_jitter_viewport', True),
|
||||
|
||||
# Simplify (from scene.render)
|
||||
"use_simplify": getattr(scene.render, 'use_simplify', False),
|
||||
"simplify_subdivision_render": getattr(scene.render, 'simplify_subdivision_render', 6),
|
||||
"simplify_child_particles_render": getattr(scene.render, 'simplify_child_particles_render', 1.0),
|
||||
}
|
||||
else:
    # For other engines, extract basic samples if available
    # (getattr with a default already covers the missing-attribute case)
    engine_settings = {
        "samples": getattr(scene, 'samples', 128)
    }

# Extract scene info
camera_count = len([obj for obj in scene.objects if obj.type == 'CAMERA'])
object_count = len(scene.objects)
material_count = len(bpy.data.materials)

# Extract Blender version info
# bpy.data.version gives the version the file was saved with
blender_version = ".".join(map(str, bpy.data.version)) if hasattr(bpy.data, 'version') else bpy.app.version_string

# Build metadata dictionary
metadata = {
    "frame_start": frame_start,
    "frame_end": frame_end,
    "has_negative_frames": has_negative_start or has_negative_end or has_negative_animation,
    "blender_version": blender_version,
    "render_settings": {
        "resolution_x": resolution_x,
        "resolution_y": resolution_y,
        "frame_rate": frame_rate,
        "output_format": output_format,
        "engine": engine.lower(),
        "engine_settings": engine_settings
    },
    "scene_info": {
        "camera_count": camera_count,
        "object_count": object_count,
        "material_count": material_count
    },
    "missing_files_info": missing_files_info
}

# Output as JSON
print(json.dumps(metadata))
sys.stdout.flush()
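
The extraction script's only machine-readable output is that final print(json.dumps(metadata)) line; everything else on stdout is logging. A minimal sketch of how the Go side might pick the JSON object out of the mixed output (hypothetical helper; the real manager-side parsing is not part of this diff):

package manager

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
)

// parseMetadataLine scans Blender's mixed log output for the last JSON
// object printed by the extraction script and decodes it.
func parseMetadataLine(stdout []byte) (map[string]interface{}, error) {
	var meta map[string]interface{}
	sc := bufio.NewScanner(bytes.NewReader(stdout))
	sc.Buffer(make([]byte, 0, 1<<20), 1<<20) // engine_settings can make the line large
	for sc.Scan() {
		line := bytes.TrimSpace(sc.Bytes())
		if len(line) == 0 || line[0] != '{' {
			continue // skip Blender's ordinary log chatter
		}
		var m map[string]interface{}
		if err := json.Unmarshal(line, &m); err == nil {
			meta = m // keep the last parseable JSON object
		}
	}
	if err := sc.Err(); err != nil {
		return nil, err
	}
	if meta == nil {
		return nil, fmt.Errorf("no metadata JSON found in Blender output")
	}
	return meta, nil
}
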
||||
653
pkg/scripts/scripts/render_blender.py.template
Normal file
@@ -0,0 +1,653 @@
import bpy
import sys
import os
import json

# Make all file paths relative to the blend file location FIRST
# This must be done immediately after file load, before any other operations
# to prevent Blender from trying to access external files with absolute paths
try:
    bpy.ops.file.make_paths_relative()
    print("Made all file paths relative to blend file")
except Exception as e:
    print(f"Warning: Could not make paths relative: {e}")

# Auto-enable addons from blender_addons folder in context
# Supports .zip files (installed via Blender API) and already-extracted addons
blend_dir = os.path.dirname(bpy.data.filepath) if bpy.data.filepath else os.getcwd()
addons_dir = os.path.join(blend_dir, "blender_addons")

if os.path.isdir(addons_dir):
    print(f"Found blender_addons folder: {addons_dir}")

    for item in os.listdir(addons_dir):
        item_path = os.path.join(addons_dir, item)

        try:
            if item.endswith('.zip'):
                # Install and enable zip addon using Blender's API
                bpy.ops.preferences.addon_install(filepath=item_path)
                # Get module name from zip (usually the folder name inside)
                import zipfile
                with zipfile.ZipFile(item_path, 'r') as zf:
                    # Find the top-level module name
                    names = zf.namelist()
                    if names:
                        module_name = names[0].split('/')[0]
                        if module_name.endswith('.py'):
                            module_name = module_name[:-3]
                        bpy.ops.preferences.addon_enable(module=module_name)
                        print(f"  Installed and enabled addon: {module_name}")

            elif item.endswith('.py') and not item.startswith('__'):
                # Single-file addon
                bpy.ops.preferences.addon_install(filepath=item_path)
                module_name = item[:-3]
                bpy.ops.preferences.addon_enable(module=module_name)
                print(f"  Installed and enabled addon: {module_name}")

            elif os.path.isdir(item_path) and os.path.exists(os.path.join(item_path, '__init__.py')):
                # Multi-file addon directory - add to path and enable
                if addons_dir not in sys.path:
                    sys.path.insert(0, addons_dir)
                bpy.ops.preferences.addon_enable(module=item)
                print(f"  Enabled addon: {item}")

        except Exception as e:
            print(f"  Error with addon {item}: {e}")
else:
    print(f"No blender_addons folder found at: {addons_dir}")

{{UNHIDE_CODE}}
# Read output format from file (created by Go code)
format_file_path = {{FORMAT_FILE_PATH}}
output_format_override = None
if os.path.exists(format_file_path):
    try:
        with open(format_file_path, 'r') as f:
            output_format_override = f.read().strip().upper()
        print(f"Read output format from file: '{output_format_override}'")
    except Exception as e:
        print(f"Warning: Could not read output format file: {e}")
else:
    print(f"Warning: Output format file does not exist: {format_file_path}")

# Read render settings from JSON file (created by Go code)
render_settings_file = {{RENDER_SETTINGS_FILE}}
render_settings_override = None
if os.path.exists(render_settings_file):
    try:
        with open(render_settings_file, 'r') as f:
            render_settings_override = json.load(f)
        print(f"Loaded render settings from job metadata")
    except Exception as e:
        print(f"Warning: Could not read render settings file: {e}")

# Get current scene settings (preserve blend file preferences)
scene = bpy.context.scene
current_engine = scene.render.engine
current_device = scene.cycles.device if hasattr(scene, 'cycles') and scene.cycles else None
current_output_format = scene.render.image_settings.file_format

print(f"Blend file render engine: {current_engine}")
if current_device:
    print(f"Blend file device setting: {current_device}")
print(f"Blend file output format: {current_output_format}")

# Override output format if specified
# The format file always takes precedence (it's written specifically for this job)
if output_format_override:
    print(f"Overriding output format from '{current_output_format}' to '{output_format_override}'")
    # Map common format names to Blender's format constants
    # For video formats, we render as appropriate frame format first
    format_to_use = output_format_override.upper()
    if format_to_use in ['EXR_264_MP4', 'EXR_AV1_MP4', 'EXR_VP9_WEBM']:
        format_to_use = 'EXR'  # Render as EXR for EXR video formats

    format_map = {
        'PNG': 'PNG',
        'JPEG': 'JPEG',
        'JPG': 'JPEG',
        'EXR': 'OPEN_EXR',
        'OPEN_EXR': 'OPEN_EXR',
        'TARGA': 'TARGA',
        'TIFF': 'TIFF',
        'BMP': 'BMP',
    }
    blender_format = format_map.get(format_to_use, format_to_use)
    try:
        scene.render.image_settings.file_format = blender_format
        print(f"Successfully set output format to: {blender_format}")
    except Exception as e:
        print(f"Warning: Could not set output format to {blender_format}: {e}")
        print(f"Using blend file's format: {current_output_format}")
else:
    print(f"Using blend file's output format: {current_output_format}")

# Apply render settings from job metadata if provided
# Note: output_format is NOT applied from render_settings_override - it's already set from format file above
if render_settings_override:
    engine_override = render_settings_override.get('engine', '').upper()
    engine_settings = render_settings_override.get('engine_settings', {})

    # Switch engine if specified
    if engine_override and engine_override != current_engine.upper():
        print(f"Switching render engine from '{current_engine}' to '{engine_override}'")
        try:
            scene.render.engine = engine_override
            current_engine = engine_override
            print(f"Successfully switched to {engine_override} engine")
        except Exception as e:
            print(f"Warning: Could not switch engine to {engine_override}: {e}")
            print(f"Using blend file's engine: {current_engine}")

    # Apply engine-specific settings
    if engine_settings:
        if current_engine.upper() == 'CYCLES':
            cycles = scene.cycles
            print("Applying Cycles render settings from job metadata...")
            for key, value in engine_settings.items():
                try:
                    if hasattr(cycles, key):
                        setattr(cycles, key, value)
                        print(f"  Set Cycles.{key} = {value}")
                    else:
                        print(f"  Warning: Cycles has no attribute '{key}'")
                except Exception as e:
                    print(f"  Warning: Could not set Cycles.{key} = {value}: {e}")
        elif current_engine.upper() in ['EEVEE', 'EEVEE_NEXT']:
            eevee = scene.eevee
            print("Applying EEVEE render settings from job metadata...")
            for key, value in engine_settings.items():
                try:
                    if hasattr(eevee, key):
                        setattr(eevee, key, value)
                        print(f"  Set EEVEE.{key} = {value}")
                    else:
                        print(f"  Warning: EEVEE has no attribute '{key}'")
                except Exception as e:
                    print(f"  Warning: Could not set EEVEE.{key} = {value}: {e}")

    # Apply resolution if specified
    if 'resolution_x' in render_settings_override:
        try:
            scene.render.resolution_x = render_settings_override['resolution_x']
            print(f"Set resolution_x = {render_settings_override['resolution_x']}")
        except Exception as e:
            print(f"Warning: Could not set resolution_x: {e}")
    if 'resolution_y' in render_settings_override:
        try:
            scene.render.resolution_y = render_settings_override['resolution_y']
            print(f"Set resolution_y = {render_settings_override['resolution_y']}")
        except Exception as e:
            print(f"Warning: Could not set resolution_y: {e}")

# Only override device selection if using Cycles (other engines handle GPU differently)
if current_engine == 'CYCLES':
    # Check if CPU rendering is forced
    force_cpu = False
    if render_settings_override and render_settings_override.get('force_cpu'):
        force_cpu = render_settings_override.get('force_cpu', False)
        print("Force CPU rendering is enabled - skipping GPU detection")

    # Ensure Cycles addon is enabled
    try:
        if 'cycles' not in bpy.context.preferences.addons:
            bpy.ops.preferences.addon_enable(module='cycles')
            print("Enabled Cycles addon")
    except Exception as e:
        print(f"Warning: Could not enable Cycles addon: {e}")

    # If CPU is forced, skip GPU detection and set CPU directly
    if force_cpu:
        scene.cycles.device = 'CPU'
        print("Forced CPU rendering (skipping GPU detection)")
    else:
        # Access Cycles preferences
        prefs = bpy.context.preferences
        try:
            cycles_prefs = prefs.addons['cycles'].preferences
        except (KeyError, AttributeError):
            try:
                cycles_addon = prefs.addons.get('cycles')
                if cycles_addon:
                    cycles_prefs = cycles_addon.preferences
                else:
                    raise Exception("Cycles addon not found")
            except Exception as e:
                print(f"ERROR: Could not access Cycles preferences: {e}")
                import traceback
                traceback.print_exc()
                sys.exit(1)

        # Check all devices and choose the best GPU type
        # Device type preference order (most performant first)
        device_type_preference = ['OPTIX', 'CUDA', 'HIP', 'ONEAPI', 'METAL']
        gpu_available = False
        best_device_type = None
        best_gpu_devices = []
        devices_by_type = {}  # {device_type: [devices]}
        seen_device_ids = set()  # Track device IDs to avoid duplicates

        print("Checking for GPU availability...")

        # Try to get all devices - try each device type to see what's available
        for device_type in device_type_preference:
            try:
                cycles_prefs.compute_device_type = device_type
                cycles_prefs.refresh_devices()

                # Get devices for this type
                devices = None
                if hasattr(cycles_prefs, 'devices'):
                    try:
                        devices_prop = cycles_prefs.devices
                        if devices_prop:
                            devices = list(devices_prop) if hasattr(devices_prop, '__iter__') else [devices_prop]
                    except Exception as e:
                        pass

                if not devices or len(devices) == 0:
                    try:
                        devices = cycles_prefs.get_devices()
                    except Exception as e:
                        pass

                if devices and len(devices) > 0:
                    # Categorize devices by their type attribute, avoiding duplicates
                    for device in devices:
                        if hasattr(device, 'type'):
                            device_type_str = str(device.type).upper()
                            device_id = getattr(device, 'id', None)

                            # Use device ID to avoid duplicates (same device appears when checking different compute_device_types)
                            if device_id and device_id in seen_device_ids:
                                continue

                            if device_id:
                                seen_device_ids.add(device_id)

                            if device_type_str not in devices_by_type:
                                devices_by_type[device_type_str] = []
                            devices_by_type[device_type_str].append(device)
            except (ValueError, AttributeError, KeyError, TypeError):
                # Device type not supported, continue
                continue
            except Exception as e:
                # Other errors - log but continue
                print(f"  Error checking {device_type}: {e}")
                continue

        # Print what we found
        print(f"Found devices by type: {list(devices_by_type.keys())}")
        for dev_type, dev_list in devices_by_type.items():
            print(f"  {dev_type}: {len(dev_list)} device(s)")
            for device in dev_list:
                device_name = getattr(device, 'name', 'Unknown')
                print(f"    - {device_name}")

        # Choose the best GPU type based on preference
        for preferred_type in device_type_preference:
            if preferred_type in devices_by_type:
                gpu_devices = [d for d in devices_by_type[preferred_type] if preferred_type in ['CUDA', 'OPENCL', 'OPTIX', 'HIP', 'METAL', 'ONEAPI']]
                if gpu_devices:
                    best_device_type = preferred_type
                    best_gpu_devices = [(d, preferred_type) for d in gpu_devices]
                    print(f"Selected {preferred_type} as best GPU type with {len(gpu_devices)} device(s)")
                    break

        # Second pass: Enable the best GPU we found
        if best_device_type and best_gpu_devices:
            print(f"\nEnabling GPU devices for {best_device_type}...")
            try:
                # Set the device type again
                cycles_prefs.compute_device_type = best_device_type
                cycles_prefs.refresh_devices()

                # First, disable all CPU devices to ensure only GPU is used
                print(f"  Disabling CPU devices...")
                all_devices = cycles_prefs.devices if hasattr(cycles_prefs, 'devices') else cycles_prefs.get_devices()
                if all_devices:
                    for device in all_devices:
                        if hasattr(device, 'type') and str(device.type).upper() == 'CPU':
                            try:
                                device.use = False
                                device_name = getattr(device, 'name', 'Unknown')
                                print(f"  Disabled CPU: {device_name}")
                            except Exception as e:
                                print(f"  Warning: Could not disable CPU device {getattr(device, 'name', 'Unknown')}: {e}")

                # Enable all GPU devices
                enabled_count = 0
                for device, device_type in best_gpu_devices:
                    try:
                        device.use = True
                        enabled_count += 1
                        device_name = getattr(device, 'name', 'Unknown')
                        print(f"  Enabled: {device_name}")
                    except Exception as e:
                        print(f"  Warning: Could not enable device {getattr(device, 'name', 'Unknown')}: {e}")

                # Enable ray tracing acceleration for supported device types
                try:
                    if best_device_type == 'HIP':
                        # HIPRT (HIP Ray Tracing) for AMD GPUs
                        if hasattr(cycles_prefs, 'use_hiprt'):
                            cycles_prefs.use_hiprt = True
                            print(f"  Enabled HIPRT (HIP Ray Tracing) for faster rendering")
                        elif hasattr(scene.cycles, 'use_hiprt'):
                            scene.cycles.use_hiprt = True
                            print(f"  Enabled HIPRT (HIP Ray Tracing) for faster rendering")
                        else:
                            print(f"  HIPRT not available (requires Blender 4.0+)")
                    elif best_device_type == 'OPTIX':
                        # OptiX is already enabled when using OPTIX device type
                        # But we can check if there are any OptiX-specific settings
                        if hasattr(scene.cycles, 'use_optix_denoising'):
                            scene.cycles.use_optix_denoising = True
                            print(f"  Enabled OptiX denoising")
                        print(f"  OptiX ray tracing is active (using OPTIX device type)")
                    elif best_device_type == 'CUDA':
                        # CUDA can use OptiX if available, but it's usually automatic
                        # Check if we can prefer OptiX over CUDA
                        if hasattr(scene.cycles, 'use_optix_denoising'):
                            scene.cycles.use_optix_denoising = True
                            print(f"  Enabled OptiX denoising (if OptiX available)")
                        print(f"  CUDA ray tracing active")
                    elif best_device_type == 'METAL':
                        # MetalRT for Apple Silicon (if available)
                        if hasattr(scene.cycles, 'use_metalrt'):
                            scene.cycles.use_metalrt = True
                            print(f"  Enabled MetalRT (Metal Ray Tracing) for faster rendering")
                        elif hasattr(cycles_prefs, 'use_metalrt'):
                            cycles_prefs.use_metalrt = True
                            print(f"  Enabled MetalRT (Metal Ray Tracing) for faster rendering")
                        else:
                            print(f"  MetalRT not available")
                    elif best_device_type == 'ONEAPI':
                        # Intel oneAPI - Embree might be available
                        if hasattr(scene.cycles, 'use_embree'):
                            scene.cycles.use_embree = True
                            print(f"  Enabled Embree for faster CPU ray tracing")
                        print(f"  oneAPI ray tracing active")
                except Exception as e:
                    print(f"  Could not enable ray tracing acceleration: {e}")

                print(f"SUCCESS: Enabled {enabled_count} GPU device(s) for {best_device_type}")
                gpu_available = True
            except Exception as e:
                print(f"ERROR: Failed to enable GPU devices: {e}")
                import traceback
                traceback.print_exc()

        # Set device based on availability (prefer GPU, fallback to CPU)
        if gpu_available:
            scene.cycles.device = 'GPU'
            print(f"Using GPU for rendering (blend file had: {current_device})")

            # Auto-enable GPU denoising when using GPU (OpenImageDenoise supports all GPUs)
            try:
                view_layer = bpy.context.view_layer
                if hasattr(view_layer, 'cycles') and hasattr(view_layer.cycles, 'denoising_use_gpu'):
                    view_layer.cycles.denoising_use_gpu = True
                    print("Auto-enabled GPU denoising (OpenImageDenoise)")
            except Exception as e:
                print(f"Could not auto-enable GPU denoising: {e}")
        else:
            scene.cycles.device = 'CPU'
            print(f"GPU not available, using CPU for rendering (blend file had: {current_device})")

            # Ensure GPU denoising is disabled when using CPU
            try:
                view_layer = bpy.context.view_layer
                if hasattr(view_layer, 'cycles') and hasattr(view_layer.cycles, 'denoising_use_gpu'):
                    view_layer.cycles.denoising_use_gpu = False
                    print("Using CPU denoising")
            except Exception as e:
                pass

# Verify device setting
if current_engine == 'CYCLES':
    final_device = scene.cycles.device
    print(f"Final Cycles device: {final_device}")
else:
    # For other engines (EEVEE, etc.), respect blend file settings
    print(f"Using {current_engine} engine - respecting blend file settings")

    # Enable GPU acceleration for EEVEE viewport rendering (if using EEVEE)
    if current_engine == 'EEVEE' or current_engine == 'EEVEE_NEXT':
        try:
            if hasattr(bpy.context.preferences.system, 'gpu_backend'):
                bpy.context.preferences.system.gpu_backend = 'OPENGL'
                print("Enabled OpenGL GPU backend for EEVEE")
        except Exception as e:
            print(f"Could not set EEVEE GPU backend: {e}")

# Enable GPU acceleration for compositing (if compositing is enabled)
try:
    if scene.use_nodes and hasattr(scene, 'node_tree') and scene.node_tree:
        if hasattr(scene.node_tree, 'use_gpu_compositing'):
            scene.node_tree.use_gpu_compositing = True
            print("Enabled GPU compositing")
except Exception as e:
    print(f"Could not enable GPU compositing: {e}")

# CRITICAL: Initialize headless rendering to prevent black images
# This ensures the render engine is properly initialized before rendering
print("Initializing headless rendering context...")
try:
    # Ensure world exists and has proper settings
    if not scene.world:
        # Create a default world if none exists
        world = bpy.data.worlds.new("World")
        scene.world = world
        print("Created default world")

    # Ensure world has a background shader (not just black)
    if scene.world:
        # Enable nodes if not already enabled
        if not scene.world.use_nodes:
            scene.world.use_nodes = True
            print("Enabled world nodes")

        world_nodes = scene.world.node_tree
        if world_nodes:
            # Find or create background shader
            bg_shader = None
            for node in world_nodes.nodes:
                if node.type == 'BACKGROUND':
                    bg_shader = node
                    break

            if not bg_shader:
                bg_shader = world_nodes.nodes.new(type='ShaderNodeBackground')
                # Connect to output
                output = world_nodes.nodes.get('World Output')
                if not output:
                    output = world_nodes.nodes.new(type='ShaderNodeOutputWorld')
                    output.name = 'World Output'
                if output and bg_shader:
                    # Connect background to surface input
                    if 'Surface' in output.inputs and 'Background' in bg_shader.outputs:
                        world_nodes.links.new(bg_shader.outputs['Background'], output.inputs['Surface'])
                        print("Created background shader for world")

            # Ensure background has some color (not pure black)
            if bg_shader:
                # Only set if it's pure black (0,0,0)
                if hasattr(bg_shader.inputs, 'Color'):
                    color = bg_shader.inputs['Color'].default_value
                    if len(color) >= 3 and color[0] == 0.0 and color[1] == 0.0 and color[2] == 0.0:
                        # Set to a very dark gray instead of pure black
                        bg_shader.inputs['Color'].default_value = (0.01, 0.01, 0.01, 1.0)
                        print("Adjusted world background color to prevent black renders")
        else:
            # Fallback: use legacy world color if nodes aren't working
            if hasattr(scene.world, 'color'):
                color = scene.world.color
                if len(color) >= 3 and color[0] == 0.0 and color[1] == 0.0 and color[2] == 0.0:
                    scene.world.color = (0.01, 0.01, 0.01)
                    print("Adjusted legacy world color to prevent black renders")

    # For EEVEE, force viewport update to initialize render engine
    if current_engine in ['EEVEE', 'EEVEE_NEXT']:
        # Force EEVEE to update its internal state
        try:
            # Update depsgraph to ensure everything is initialized
            depsgraph = bpy.context.evaluated_depsgraph_get()
            if depsgraph:
                # Force update
                depsgraph.update()
                print("Forced EEVEE depsgraph update for headless rendering")
        except Exception as e:
            print(f"Warning: Could not force EEVEE update: {e}")

        # Ensure EEVEE settings are applied
        try:
            # Force a material update to ensure shaders are compiled
            for obj in scene.objects:
                if obj.type == 'MESH' and obj.data.materials:
                    for mat in obj.data.materials:
                        if mat and mat.use_nodes:
                            # Touch the material to force update
                            mat.use_nodes = mat.use_nodes
            print("Forced material updates for EEVEE")
        except Exception as e:
            print(f"Warning: Could not update materials: {e}")

    # For Cycles, ensure proper initialization
    if current_engine == 'CYCLES':
        # Ensure samples are set (even if 1 for preview)
        if not hasattr(scene.cycles, 'samples') or scene.cycles.samples < 1:
            scene.cycles.samples = 1
            print("Set minimum Cycles samples")

        # Check for lights in the scene
        lights = [obj for obj in scene.objects if obj.type == 'LIGHT']
        print(f"Found {len(lights)} light(s) in scene")
        if len(lights) == 0:
            print("WARNING: No lights found in scene - rendering may be black!")
            print("  Consider adding lights or ensuring world background emits light")

        # Ensure world background emits light (critical for Cycles)
        if scene.world and scene.world.use_nodes:
            world_nodes = scene.world.node_tree
            if world_nodes:
                bg_shader = None
                for node in world_nodes.nodes:
                    if node.type == 'BACKGROUND':
                        bg_shader = node
                        break

                if bg_shader:
                    # Check and set strength - Cycles needs this to emit light!
                    if hasattr(bg_shader.inputs, 'Strength'):
                        strength = bg_shader.inputs['Strength'].default_value
                        if strength <= 0.0:
                            bg_shader.inputs['Strength'].default_value = 1.0
                            print("Set world background strength to 1.0 for Cycles lighting")
                        else:
                            print(f"World background strength: {strength}")
                    # Also ensure color is not pure black
                    if hasattr(bg_shader.inputs, 'Color'):
                        color = bg_shader.inputs['Color'].default_value
                        if len(color) >= 3 and color[0] == 0.0 and color[1] == 0.0 and color[2] == 0.0:
                            bg_shader.inputs['Color'].default_value = (1.0, 1.0, 1.0, 1.0)
                            print("Set world background color to white for Cycles lighting")

        # Check film_transparent setting - if enabled, background will be transparent/black
        if hasattr(scene.cycles, 'film_transparent') and scene.cycles.film_transparent:
            print("WARNING: film_transparent is enabled - background will be transparent")
            print("  If you see black renders, try disabling film_transparent")

        # Force Cycles to update/compile materials and shaders
        try:
            # Update depsgraph to ensure everything is initialized
            depsgraph = bpy.context.evaluated_depsgraph_get()
            if depsgraph:
                depsgraph.update()
                print("Forced Cycles depsgraph update")

            # Force material updates to ensure shaders are compiled
            for obj in scene.objects:
                if obj.type == 'MESH' and obj.data.materials:
                    for mat in obj.data.materials:
                        if mat and mat.use_nodes:
                            # Force material update
                            mat.use_nodes = mat.use_nodes
            print("Forced Cycles material updates")
        except Exception as e:
            print(f"Warning: Could not force Cycles updates: {e}")

        # Verify device is actually set correctly
        if hasattr(scene.cycles, 'device'):
            actual_device = scene.cycles.device
            print(f"Cycles device setting: {actual_device}")
            if actual_device == 'GPU':
                # Try to verify GPU is actually available
                try:
                    prefs = bpy.context.preferences
                    cycles_prefs = prefs.addons['cycles'].preferences
                    devices = cycles_prefs.devices
                    enabled_devices = [d for d in devices if d.use]
                    if len(enabled_devices) == 0:
                        print("WARNING: GPU device set but no GPU devices are enabled!")
                        print("  Falling back to CPU may cause issues")
                except Exception as e:
                    print(f"Could not verify GPU devices: {e}")

    # Ensure camera exists and is active
    if scene.camera is None:
        # Find first camera in scene
        for obj in scene.objects:
            if obj.type == 'CAMERA':
                scene.camera = obj
                print(f"Set active camera: {obj.name}")
                break

    print("Headless rendering initialization complete")
except Exception as e:
    print(f"Warning: Headless rendering initialization had issues: {e}")
    import traceback
    traceback.print_exc()

# Final verification before rendering
print("\n=== Pre-render verification ===")
try:
    scene = bpy.context.scene
    print(f"Render engine: {scene.render.engine}")
    print(f"Active camera: {scene.camera.name if scene.camera else 'None'}")

    if scene.render.engine == 'CYCLES':
        print(f"Cycles device: {scene.cycles.device}")
        print(f"Cycles samples: {scene.cycles.samples}")
        lights = [obj for obj in scene.objects if obj.type == 'LIGHT']
        print(f"Lights in scene: {len(lights)}")
        if scene.world:
            if scene.world.use_nodes:
                world_nodes = scene.world.node_tree
                if world_nodes:
                    bg_shader = None
                    for node in world_nodes.nodes:
                        if node.type == 'BACKGROUND':
                            bg_shader = node
                            break
                    if bg_shader:
                        if hasattr(bg_shader.inputs, 'Strength'):
                            strength = bg_shader.inputs['Strength'].default_value
                            print(f"World background strength: {strength}")
                        if hasattr(bg_shader.inputs, 'Color'):
                            color = bg_shader.inputs['Color'].default_value
                            print(f"World background color: ({color[0]:.2f}, {color[1]:.2f}, {color[2]:.2f})")
            else:
                print("World exists but nodes are disabled")
        else:
            print("WARNING: No world in scene!")

    print("=== Verification complete ===\n")
except Exception as e:
    print(f"Warning: Verification failed: {e}")

print("Device configuration complete - blend file settings preserved, device optimized")
sys.stdout.flush()
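
This template only configures the scene; the actual render is driven by Blender's command line. A hypothetical sketch of the runner-side invocation (the real runner code is not part of this diff; --enable-autoexec is the flag the EnableExecution option refers to, and Blender processes arguments in order, so the script and output pattern must precede -f):

package runner

import (
	"context"
	"os"
	"os/exec"
	"strconv"
)

// renderFrame renders a single frame headlessly with the generated script.
func renderFrame(ctx context.Context, blendPath, scriptPath, outPattern string, frame int, enableAutoexec bool) error {
	args := []string{"-b", blendPath}
	if enableAutoexec {
		args = append(args, "--enable-autoexec") // mirrors EnableExecution in CreateJobRequest
	}
	args = append(args, "-P", scriptPath, "-o", outPattern, "-f", strconv.Itoa(frame))
	cmd := exec.CommandContext(ctx, "blender", args...)
	cmd.Stdout = os.Stdout // the script's prints double as task logs
	cmd.Stderr = os.Stderr
	return cmd.Run()
}
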
29
pkg/scripts/scripts/unhide_objects.py
Normal file
@@ -0,0 +1,29 @@
# Fix objects and collections hidden from render
# (Injected into the render template via {{UNHIDE_CODE}}, where bpy is already imported)
vl = bpy.context.view_layer

# 1. Objects hidden in view layer
print("Checking for objects hidden from render that need to be enabled...")
try:
    for obj in bpy.data.objects:
        if obj.hide_get(view_layer=vl):
            if any(k in obj.name.lower() for k in ["scrotum", "cage", "genital", "penis", "dick", "collision", "body.001", "couch"]):
                obj.hide_set(False, view_layer=vl)
                print("Enabled object:", obj.name)
except Exception as e:
    print(f"Warning: Could not check/fix hidden render objects: {e}")

# 2. Collections disabled in renders OR set to Holdout (the final killer)
print("Checking for collections hidden from render that need to be enabled...")
try:
    for col in bpy.data.collections:
        if col.hide_render or (vl.layer_collection.children.get(col.name) and vl.layer_collection.children[col.name].exclude):
            if any(k in col.name.lower() for k in ["genital", "nsfw", "dick", "private", "hidden", "cage", "scrotum", "collision"]):
                col.hide_render = False
                if col.name in vl.layer_collection.children:
                    vl.layer_collection.children[col.name].exclude = False
                    vl.layer_collection.children[col.name].holdout = False
                    vl.layer_collection.children[col.name].indirect_only = False
                print("Enabled collection:", col.name)
except Exception as e:
    print(f"Warning: Could not check/fix hidden render collections: {e}")

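A minimal sketch of how the {{...}} placeholders in render_blender.py.template might be filled in on the Go side (hypothetical helper; the actual substitution code is not shown in this diff). Paths must be injected as quoted Python string literals, and {{UNHIDE_CODE}} becomes either the unhide script body above or an empty string:

package scripts

import "strings"

// buildRenderScript fills the template placeholders before the script is
// handed to Blender via -P.
func buildRenderScript(tmpl, unhideCode, formatFile, settingsFile string) string {
	r := strings.NewReplacer(
		"{{UNHIDE_CODE}}", unhideCode,
		"{{FORMAT_FILE_PATH}}", pyQuote(formatFile),
		"{{RENDER_SETTINGS_FILE}}", pyQuote(settingsFile),
	)
	return r.Replace(tmpl)
}

// pyQuote renders a Go string as a (simplified) Python string literal.
func pyQuote(s string) string {
	s = strings.ReplaceAll(s, `\`, `\\`)
	s = strings.ReplaceAll(s, `"`, `\"`)
	return `"` + s + `"`
}
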
@@ -27,28 +27,25 @@ const (
type JobType string

const (
	JobTypeMetadata JobType = "metadata" // Metadata extraction job - only needs blend file
	JobTypeRender JobType = "render" // Render job - needs frame range, format, etc.
	JobTypeRender JobType = "render" // Render job - needs frame range, format, etc.
)

// Job represents a job (metadata extraction or render)
// Job represents a render job
type Job struct {
	ID int64 `json:"id"`
	UserID int64 `json:"user_id"`
	JobType JobType `json:"job_type"` // "metadata" or "render"
	Name string `json:"name"`
	Status JobStatus `json:"status"`
	Progress float64 `json:"progress"` // 0.0 to 100.0
	FrameStart *int `json:"frame_start,omitempty"` // Only for render jobs
	FrameEnd *int `json:"frame_end,omitempty"` // Only for render jobs
	OutputFormat *string `json:"output_format,omitempty"` // Only for render jobs - PNG, JPEG, EXR, etc.
	AllowParallelRunners *bool `json:"allow_parallel_runners,omitempty"` // Only for render jobs
	TimeoutSeconds int `json:"timeout_seconds"` // Job-level timeout (24 hours default)
	BlendMetadata *BlendMetadata `json:"blend_metadata,omitempty"` // Extracted metadata from blend file
	CreatedAt time.Time `json:"created_at"`
	StartedAt *time.Time `json:"started_at,omitempty"`
	CompletedAt *time.Time `json:"completed_at,omitempty"`
	ErrorMessage string `json:"error_message,omitempty"`
	ID int64 `json:"id"`
	UserID int64 `json:"user_id"`
	JobType JobType `json:"job_type"` // "render"
	Name string `json:"name"`
	Status JobStatus `json:"status"`
	Progress float64 `json:"progress"` // 0.0 to 100.0
	FrameStart *int `json:"frame_start,omitempty"` // Only for render jobs
	FrameEnd *int `json:"frame_end,omitempty"` // Only for render jobs
	OutputFormat *string `json:"output_format,omitempty"` // Only for render jobs - PNG, JPEG, EXR, etc.
	BlendMetadata *BlendMetadata `json:"blend_metadata,omitempty"` // Extracted metadata from blend file
	CreatedAt time.Time `json:"created_at"`
	StartedAt *time.Time `json:"started_at,omitempty"`
	CompletedAt *time.Time `json:"completed_at,omitempty"`
	ErrorMessage string `json:"error_message,omitempty"`
}

// RunnerStatus represents the status of a runner
@@ -87,9 +84,8 @@ const (
type TaskType string

const (
	TaskTypeRender TaskType = "render"
	TaskTypeMetadata TaskType = "metadata"
	TaskTypeVideoGeneration TaskType = "video_generation"
	TaskTypeRender TaskType = "render"
	TaskTypeEncode TaskType = "encode"
)

// Task represents a render task assigned to a runner
@@ -97,8 +93,7 @@ type Task struct {
	ID int64 `json:"id"`
	JobID int64 `json:"job_id"`
	RunnerID *int64 `json:"runner_id,omitempty"`
	FrameStart int `json:"frame_start"`
	FrameEnd int `json:"frame_end"`
	Frame int `json:"frame"`
	TaskType TaskType `json:"task_type"`
	Status TaskStatus `json:"status"`
	CurrentStep string `json:"current_step,omitempty"`
@@ -133,13 +128,18 @@ type JobFile struct {

// CreateJobRequest represents a request to create a new job
type CreateJobRequest struct {
	JobType JobType `json:"job_type"` // "metadata" or "render"
	Name string `json:"name"`
	FrameStart *int `json:"frame_start,omitempty"` // Required for render jobs
	FrameEnd *int `json:"frame_end,omitempty"` // Required for render jobs
	OutputFormat *string `json:"output_format,omitempty"` // Required for render jobs
	AllowParallelRunners *bool `json:"allow_parallel_runners,omitempty"` // Optional for render jobs, defaults to true
	MetadataJobID *int64 `json:"metadata_job_id,omitempty"` // Optional: ID of metadata job to copy input files from
	JobType JobType `json:"job_type"` // "render"
	Name string `json:"name"`
	FrameStart *int `json:"frame_start,omitempty"` // Required for render jobs
	FrameEnd *int `json:"frame_end,omitempty"` // Required for render jobs
	OutputFormat *string `json:"output_format,omitempty"` // Required for render jobs
	RenderSettings *RenderSettings `json:"render_settings,omitempty"` // Optional: Override blend file render settings
	UploadSessionID *string `json:"upload_session_id,omitempty"` // Optional: Session ID from file upload
	UnhideObjects *bool `json:"unhide_objects,omitempty"` // Optional: Enable unhide tweaks for objects/collections
	EnableExecution *bool `json:"enable_execution,omitempty"` // Optional: Enable auto-execution in Blender (adds --enable-autoexec flag, defaults to false)
	BlenderVersion *string `json:"blender_version,omitempty"` // Optional: Override Blender version (e.g., "4.2" or "4.2.3")
	PreserveHDR *bool `json:"preserve_hdr,omitempty"` // Optional: Preserve HDR range for EXR encoding (uses HLG with bt709 primaries)
	PreserveAlpha *bool `json:"preserve_alpha,omitempty"` // Optional: Preserve alpha channel for EXR encoding (requires AV1 or VP9 codec)
}

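For reference, an illustrative request body for the new CreateJobRequest shape (all values here are made up; field names come from the struct above):

{
  "job_type": "render",
  "name": "shot_010",
  "frame_start": 1,
  "frame_end": 250,
  "output_format": "PNG",
  "render_settings": {
    "engine": "cycles",
    "resolution_x": 1920,
    "resolution_y": 1080,
    "engine_settings": {"samples": 256}
  },
  "upload_session_id": "sess-8f2c190a",
  "unhide_objects": false,
  "enable_execution": false,
  "blender_version": "4.2",
  "preserve_hdr": false,
  "preserve_alpha": false
}
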
// UpdateJobProgressRequest represents a request to update job progress
@@ -151,7 +151,7 @@ type UpdateJobProgressRequest struct {
type RegisterRunnerRequest struct {
	Name string `json:"name"`
	Hostname string `json:"hostname"`
	IPAddress string `json:"ip_address"`
	IPAddress string `json:"ip_address,omitempty"` // Optional, extracted from request by manager
	Capabilities string `json:"capabilities"`
	Priority *int `json:"priority,omitempty"` // Optional, defaults to 100 if not provided
}
@@ -225,19 +225,37 @@ type TaskLogEntry struct {

// BlendMetadata represents extracted metadata from a blend file
type BlendMetadata struct {
	FrameStart int `json:"frame_start"`
	FrameEnd int `json:"frame_end"`
	RenderSettings RenderSettings `json:"render_settings"`
	SceneInfo SceneInfo `json:"scene_info"`
	FrameStart int `json:"frame_start"`
	FrameEnd int `json:"frame_end"`
	HasNegativeFrames bool `json:"has_negative_frames"` // True if blend file has negative frame numbers (not supported)
	RenderSettings RenderSettings `json:"render_settings"`
	SceneInfo SceneInfo `json:"scene_info"`
	MissingFilesInfo *MissingFilesInfo `json:"missing_files_info,omitempty"`
	UnhideObjects *bool `json:"unhide_objects,omitempty"` // Enable unhide tweaks for objects/collections
	EnableExecution *bool `json:"enable_execution,omitempty"` // Enable auto-execution in Blender (adds --enable-autoexec flag, defaults to false)
	BlenderVersion string `json:"blender_version,omitempty"` // Detected or overridden Blender version (e.g., "4.2" or "4.2.3")
	PreserveHDR *bool `json:"preserve_hdr,omitempty"` // Preserve HDR range for EXR encoding (uses HLG with bt709 primaries)
	PreserveAlpha *bool `json:"preserve_alpha,omitempty"` // Preserve alpha channel for EXR encoding (requires AV1 or VP9 codec)
}

// MissingFilesInfo represents information about missing files/addons
type MissingFilesInfo struct {
	Checked bool `json:"checked"`
	HasMissing bool `json:"has_missing"`
	MissingFiles []string `json:"missing_files,omitempty"`
	MissingAddons []string `json:"missing_addons,omitempty"`
	Error string `json:"error,omitempty"`
}

// RenderSettings represents render settings from a blend file
type RenderSettings struct {
	ResolutionX int `json:"resolution_x"`
	ResolutionY int `json:"resolution_y"`
	Samples int `json:"samples"`
	OutputFormat string `json:"output_format"`
	Engine string `json:"engine"`
	ResolutionX int `json:"resolution_x"`
	ResolutionY int `json:"resolution_y"`
	FrameRate float64 `json:"frame_rate"`
	Samples int `json:"samples,omitempty"` // Deprecated, use EngineSettings
	OutputFormat string `json:"output_format"`
	Engine string `json:"engine"`
	EngineSettings map[string]interface{} `json:"engine_settings,omitempty"`
}

// SceneInfo represents scene information from a blend file
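
Because EngineSettings is a free-form map, the per-engine values the extraction script emits pass through JSON without any Go schema change, but JSON numbers decode as float64. A small sketch of reading one value back (illustrative only; no such helper exists in this diff):

// eeveeSamples pulls one engine-specific value out of decoded metadata.
func eeveeSamples(raw []byte) (int, bool) {
	var rs RenderSettings
	if err := json.Unmarshal(raw, &rs); err != nil {
		return 0, false
	}
	v, ok := rs.EngineSettings["taa_render_samples"].(float64) // JSON numbers decode as float64
	return int(v), ok
}
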
16
version/version.go
Normal file
@@ -0,0 +1,16 @@
// version/version.go
package version

import "time"

var Version string
var Date string

func init() {
	if Version == "" {
		Version = "0.0.0-dev"
	}
	if Date == "" {
		Date = time.Now().Format("2006-01-02 15:04:05")
	}
}
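These package-level vars are meant to be stamped at build time via -ldflags -X; a sketch of how a caller might surface them (no such call site is shown in this diff, and the jiggablend module path is assumed):

package main

import (
	"fmt"

	"jiggablend/version"
)

func main() {
	fmt.Printf("jiggablend %s (built %s)\n", version.Version, version.Date)
}
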
@@ -241,7 +241,6 @@ function displayRunners(runners) {
      <h3>${escapeHtml(runner.name)}</h3>
      <div class="runner-info">
        <span>Hostname: ${escapeHtml(runner.hostname)}</span>
        <span>IP: ${escapeHtml(runner.ip_address)}</span>
        <span>Last heartbeat: ${lastHeartbeat.toLocaleString()}</span>
      </div>
      <div class="runner-status ${isOnline ? 'online' : 'offline'}">
45
web/embed.go
Normal file
@@ -0,0 +1,45 @@
package web

import (
	"embed"
	"io/fs"
	"net/http"
	"strings"
)

//go:embed dist/*
var distFS embed.FS

// GetFileSystem returns an http.FileSystem for the embedded web UI files
func GetFileSystem() http.FileSystem {
	subFS, err := fs.Sub(distFS, "dist")
	if err != nil {
		panic(err)
	}
	return http.FS(subFS)
}

// SPAHandler returns an http.Handler that serves the embedded SPA
// It serves static files if they exist, otherwise falls back to index.html
func SPAHandler() http.Handler {
	fsys := GetFileSystem()
	fileServer := http.FileServer(fsys)

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path

		// Try to open the file
		f, err := fsys.Open(strings.TrimPrefix(path, "/"))
		if err != nil {
			// File doesn't exist, serve index.html for SPA routing
			r.URL.Path = "/"
			fileServer.ServeHTTP(w, r)
			return
		}
		f.Close()

		// File exists, serve it
		fileServer.ServeHTTP(w, r)
	})
}

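A sketch of wiring the embedded UI into a mux, with API routes registered first so they take precedence (the manager's real routing is not shown in this diff; the jiggablend module path is assumed):

package main

import (
	"log"
	"net/http"

	"jiggablend/web"
)

func main() {
	mux := http.NewServeMux()
	// mux.Handle("/api/", apiHandler) // hypothetical API routes take precedence
	mux.Handle("/", web.SPAHandler()) // everything else falls through to the embedded SPA
	log.Fatal(http.ListenAndServe(":8080", mux))
}
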
@@ -5,6 +5,8 @@ import Layout from './components/Layout';
import JobList from './components/JobList';
import JobSubmission from './components/JobSubmission';
import AdminPanel from './components/AdminPanel';
import ErrorBoundary from './components/ErrorBoundary';
import LoadingSpinner from './components/LoadingSpinner';
import './styles/index.css';

function App() {
@@ -17,7 +19,7 @@ function App() {
  if (loading) {
    return (
      <div className="min-h-screen flex items-center justify-center bg-gray-900">
        <div className="animate-spin rounded-full h-12 w-12 border-b-2 border-orange-500"></div>
        <LoadingSpinner size="md" />
      </div>
    );
  }
@@ -26,26 +28,20 @@ function App() {
    return loginComponent;
  }

  // Wrapper to check auth before changing tabs
  const handleTabChange = async (newTab) => {
    // Check auth before allowing navigation
    try {
      await refresh();
      // If refresh succeeds, user is still authenticated
      setActiveTab(newTab);
    } catch (error) {
      // Auth check failed, user will be set to null and login will show
      console.error('Auth check failed on navigation:', error);
    }
  // Wrapper to change tabs - only check auth on mount, not on every navigation
  const handleTabChange = (newTab) => {
    setActiveTab(newTab);
  };

  return (
    <Layout activeTab={activeTab} onTabChange={handleTabChange}>
      {activeTab === 'jobs' && <JobList />}
      {activeTab === 'submit' && (
        <JobSubmission onSuccess={() => handleTabChange('jobs')} />
      )}
      {activeTab === 'admin' && <AdminPanel />}
      <ErrorBoundary>
        {activeTab === 'jobs' && <JobList />}
        {activeTab === 'submit' && (
          <JobSubmission onSuccess={() => handleTabChange('jobs')} />
        )}
        {activeTab === 'admin' && <AdminPanel />}
      </ErrorBoundary>
    </Layout>
  );
}

@@ -1,41 +1,136 @@
import { useState, useEffect } from 'react';
import { admin } from '../utils/api';
import { useState, useEffect, useRef } from 'react';
import { admin, jobs, normalizeArrayResponse } from '../utils/api';
import { wsManager } from '../utils/websocket';
import UserJobs from './UserJobs';
import PasswordChange from './PasswordChange';
import LoadingSpinner from './LoadingSpinner';

export default function AdminPanel() {
  const [activeSection, setActiveSection] = useState('tokens');
  const [tokens, setTokens] = useState([]);
  const [activeSection, setActiveSection] = useState('api-keys');
  const [apiKeys, setApiKeys] = useState([]);
  const [runners, setRunners] = useState([]);
  const [users, setUsers] = useState([]);
  const [loading, setLoading] = useState(false);
  const [newTokenExpires, setNewTokenExpires] = useState(24);
  const [newToken, setNewToken] = useState(null);
  const [newAPIKeyName, setNewAPIKeyName] = useState('');
  const [newAPIKeyDescription, setNewAPIKeyDescription] = useState('');
  const [newAPIKeyScope, setNewAPIKeyScope] = useState('user'); // Default to user scope
  const [newAPIKey, setNewAPIKey] = useState(null);
  const [selectedUser, setSelectedUser] = useState(null);
  const [registrationEnabled, setRegistrationEnabled] = useState(true);
  const [passwordChangeUser, setPasswordChangeUser] = useState(null);
  const listenerIdRef = useRef(null); // Listener ID for shared WebSocket
  const subscribedChannelsRef = useRef(new Set()); // Track confirmed subscribed channels
  const pendingSubscriptionsRef = useRef(new Set()); // Track pending subscriptions (waiting for confirmation)

  // Connect to shared WebSocket on mount
  useEffect(() => {
    listenerIdRef.current = wsManager.subscribe('adminpanel', {
      open: () => {
        console.log('AdminPanel: Shared WebSocket connected');
        // Subscribe to runners if already viewing runners section
        if (activeSection === 'runners') {
          subscribeToRunners();
        }
      },
      message: (data) => {
        // Handle subscription responses - update both local refs and wsManager
        if (data.type === 'subscribed' && data.channel) {
          pendingSubscriptionsRef.current.delete(data.channel);
          subscribedChannelsRef.current.add(data.channel);
          wsManager.confirmSubscription(data.channel);
          console.log('Successfully subscribed to channel:', data.channel);
        } else if (data.type === 'subscription_error' && data.channel) {
          pendingSubscriptionsRef.current.delete(data.channel);
          subscribedChannelsRef.current.delete(data.channel);
          wsManager.failSubscription(data.channel);
          console.error('Subscription failed for channel:', data.channel, data.error);
        }

        // Handle runners channel messages
        if (data.channel === 'runners' && data.type === 'runner_status') {
          // Update runner in list
          setRunners(prev => {
            const index = prev.findIndex(r => r.id === data.runner_id);
            if (index >= 0 && data.data) {
              const updated = [...prev];
              updated[index] = { ...updated[index], ...data.data };
              return updated;
            }
            return prev;
          });
        }
      },
      error: (error) => {
        console.error('AdminPanel: Shared WebSocket error:', error);
      },
      close: (event) => {
        console.log('AdminPanel: Shared WebSocket closed:', event);
        subscribedChannelsRef.current.clear();
        pendingSubscriptionsRef.current.clear();
      }
    });

    // Ensure connection is established
    wsManager.connect();

    return () => {
      // Unsubscribe from all channels before unmounting
      unsubscribeFromRunners();
      if (listenerIdRef.current) {
        wsManager.unsubscribe(listenerIdRef.current);
        listenerIdRef.current = null;
      }
    };
  }, []);

  const subscribeToRunners = () => {
    const channel = 'runners';
    // Don't subscribe if already subscribed or pending
    if (subscribedChannelsRef.current.has(channel) || pendingSubscriptionsRef.current.has(channel)) {
      return;
    }
    wsManager.subscribeToChannel(channel);
    subscribedChannelsRef.current.add(channel);
    pendingSubscriptionsRef.current.add(channel);
    console.log('Subscribing to runners channel');
  };

  const unsubscribeFromRunners = () => {
    const channel = 'runners';
    if (!subscribedChannelsRef.current.has(channel)) {
      return; // Not subscribed
    }
    wsManager.unsubscribeFromChannel(channel);
    subscribedChannelsRef.current.delete(channel);
    pendingSubscriptionsRef.current.delete(channel);
    console.log('Unsubscribed from runners channel');
  };

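The message handler above implies a small WebSocket frame vocabulary (subscribed, subscription_error, runner_status). A sketch of how the Go manager might define the frame, derived only from the fields this client reads (the server side is not part of this diff):

type WSMessage struct {
	Type     string          `json:"type"`                // "subscribed", "subscription_error", "runner_status"
	Channel  string          `json:"channel,omitempty"`   // e.g. "runners"
	RunnerID int64           `json:"runner_id,omitempty"` // set on runner_status frames
	Data     json.RawMessage `json:"data,omitempty"`      // partial runner object merged client-side
	Error    string          `json:"error,omitempty"`     // set on subscription_error frames
}
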
useEffect(() => {
|
||||
if (activeSection === 'tokens') {
|
||||
loadTokens();
|
||||
if (activeSection === 'api-keys') {
|
||||
loadAPIKeys();
|
||||
unsubscribeFromRunners();
|
||||
} else if (activeSection === 'runners') {
|
||||
loadRunners();
|
||||
subscribeToRunners();
|
||||
} else if (activeSection === 'users') {
|
||||
loadUsers();
|
||||
unsubscribeFromRunners();
|
||||
} else if (activeSection === 'settings') {
|
||||
loadSettings();
|
||||
unsubscribeFromRunners();
|
||||
}
|
||||
}, [activeSection]);
|
||||
|
||||
const loadTokens = async () => {
|
||||
const loadAPIKeys = async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const data = await admin.listTokens();
|
||||
setTokens(Array.isArray(data) ? data : []);
|
||||
const data = await admin.listAPIKeys();
|
||||
setApiKeys(normalizeArrayResponse(data));
|
||||
} catch (error) {
|
||||
console.error('Failed to load tokens:', error);
|
||||
setTokens([]);
|
||||
alert('Failed to load tokens');
|
||||
console.error('Failed to load API keys:', error);
|
||||
setApiKeys([]);
|
||||
alert('Failed to load API keys');
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
@@ -45,7 +140,7 @@ export default function AdminPanel() {
|
||||
setLoading(true);
|
||||
try {
|
||||
const data = await admin.listRunners();
|
||||
setRunners(Array.isArray(data) ? data : []);
|
||||
setRunners(normalizeArrayResponse(data));
|
||||
} catch (error) {
|
||||
console.error('Failed to load runners:', error);
|
||||
setRunners([]);
|
||||
@@ -59,7 +154,7 @@ export default function AdminPanel() {
|
||||
setLoading(true);
|
||||
try {
|
||||
const data = await admin.listUsers();
|
||||
setUsers(Array.isArray(data) ? data : []);
|
||||
setUsers(normalizeArrayResponse(data));
|
||||
} catch (error) {
|
||||
console.error('Failed to load users:', error);
|
||||
setUsers([]);
|
||||
@@ -97,54 +192,61 @@ export default function AdminPanel() {
    }
  };

-  const generateToken = async () => {
+  const generateAPIKey = async () => {
+    if (!newAPIKeyName.trim()) {
+      alert('API key name is required');
+      return;
+    }
+
    setLoading(true);
    try {
-      const data = await admin.generateToken(newTokenExpires);
-      setNewToken(data.token);
-      await loadTokens();
+      const data = await admin.generateAPIKey(newAPIKeyName.trim(), newAPIKeyDescription.trim() || undefined, newAPIKeyScope);
+      setNewAPIKey(data);
+      setNewAPIKeyName('');
+      setNewAPIKeyDescription('');
+      setNewAPIKeyScope('user');
+      await loadAPIKeys();
    } catch (error) {
-      console.error('Failed to generate token:', error);
-      alert('Failed to generate token');
+      console.error('Failed to generate API key:', error);
+      alert('Failed to generate API key');
    } finally {
      setLoading(false);
    }
  };

-  const revokeToken = async (tokenId) => {
-    if (!confirm('Are you sure you want to revoke this token?')) {
+  const [deletingKeyId, setDeletingKeyId] = useState(null);
+  const [deletingRunnerId, setDeletingRunnerId] = useState(null);
+
+  const revokeAPIKey = async (keyId) => {
+    if (!confirm('Are you sure you want to delete this API key? This action cannot be undone.')) {
      return;
    }
+    setDeletingKeyId(keyId);
    try {
-      await admin.revokeToken(tokenId);
-      await loadTokens();
+      await admin.deleteAPIKey(keyId);
+      await loadAPIKeys();
    } catch (error) {
-      console.error('Failed to revoke token:', error);
-      alert('Failed to revoke token');
+      console.error('Failed to delete API key:', error);
+      alert('Failed to delete API key');
+    } finally {
+      setDeletingKeyId(null);
    }
  };

  const verifyRunner = async (runnerId) => {
    try {
      await admin.verifyRunner(runnerId);
      await loadRunners();
      alert('Runner verified');
    } catch (error) {
      console.error('Failed to verify runner:', error);
      alert('Failed to verify runner');
    }
  };

  const deleteRunner = async (runnerId) => {
    if (!confirm('Are you sure you want to delete this runner?')) {
      return;
    }
+    setDeletingRunnerId(runnerId);
    try {
      await admin.deleteRunner(runnerId);
      await loadRunners();
    } catch (error) {
      console.error('Failed to delete runner:', error);
      alert('Failed to delete runner');
+    } finally {
+      setDeletingRunnerId(null);
    }
  };

@@ -153,12 +255,8 @@ export default function AdminPanel() {
    alert('Copied to clipboard!');
  };

-  const isTokenExpired = (expiresAt) => {
-    return new Date(expiresAt) < new Date();
-  };
-
-  const isTokenUsed = (used) => {
-    return used;
+  const isAPIKeyActive = (isActive) => {
+    return isActive;
  };

  return (
@@ -166,16 +264,16 @@ export default function AdminPanel() {
      <div className="flex space-x-4 border-b border-gray-700">
        <button
          onClick={() => {
-            setActiveSection('tokens');
+            setActiveSection('api-keys');
            setSelectedUser(null);
          }}
          className={`py-2 px-4 border-b-2 font-medium ${
-            activeSection === 'tokens'
+            activeSection === 'api-keys'
              ? 'border-orange-500 text-orange-500'
              : 'border-transparent text-gray-400 hover:text-gray-300'
          }`}
        >
-          Registration Tokens
+          API Keys
        </button>
        <button
          onClick={() => {
@@ -218,76 +316,112 @@ export default function AdminPanel() {
        </button>
      </div>

-      {activeSection === 'tokens' && (
+      {activeSection === 'api-keys' && (
        <div className="space-y-6">
          <div className="bg-gray-800 rounded-lg shadow-md p-6 border border-gray-700">
-            <h2 className="text-xl font-semibold mb-4 text-gray-100">Generate Registration Token</h2>
-            <div className="flex gap-4 items-end">
-              <div>
-                <label className="block text-sm font-medium text-gray-300 mb-2">
-                  Expires in (hours)
-                </label>
-                <input
-                  type="number"
-                  min="1"
-                  max="168"
-                  value={newTokenExpires}
-                  onChange={(e) => setNewTokenExpires(parseInt(e.target.value) || 24)}
-                  className="w-32 px-3 py-2 bg-gray-900 border border-gray-600 rounded-lg text-gray-100 focus:ring-2 focus:ring-orange-500 focus:border-transparent"
-                />
+            <h2 className="text-xl font-semibold mb-4 text-gray-100">Generate API Key</h2>
+            <div className="space-y-4">
+              <div className="grid grid-cols-1 md:grid-cols-3 gap-4">
+                <div>
+                  <label className="block text-sm font-medium text-gray-300 mb-2">
+                    Name *
+                  </label>
+                  <input
+                    type="text"
+                    value={newAPIKeyName}
+                    onChange={(e) => setNewAPIKeyName(e.target.value)}
+                    placeholder="e.g., production-runner-01"
+                    className="w-full px-3 py-2 bg-gray-900 border border-gray-600 rounded-lg text-gray-100 focus:ring-2 focus:ring-orange-500 focus:border-transparent"
+                    required
+                  />
+                </div>
+                <div>
+                  <label className="block text-sm font-medium text-gray-300 mb-2">
+                    Description
+                  </label>
+                  <input
+                    type="text"
+                    value={newAPIKeyDescription}
+                    onChange={(e) => setNewAPIKeyDescription(e.target.value)}
+                    placeholder="Optional description"
+                    className="w-full px-3 py-2 bg-gray-900 border border-gray-600 rounded-lg text-gray-100 focus:ring-2 focus:ring-orange-500 focus:border-transparent"
+                  />
+                </div>
+                <div>
+                  <label className="block text-sm font-medium text-gray-300 mb-2">
+                    Scope
+                  </label>
+                  <select
+                    value={newAPIKeyScope}
+                    onChange={(e) => setNewAPIKeyScope(e.target.value)}
+                    className="w-full px-3 py-2 bg-gray-900 border border-gray-600 rounded-lg text-gray-100 focus:ring-2 focus:ring-orange-500 focus:border-transparent"
+                  >
+                    <option value="user">User - Only jobs from API key owner</option>
+                    <option value="manager">Manager - All jobs from any user</option>
+                  </select>
+                </div>
+              </div>
+              <div className="flex justify-end">
+                <button
+                  onClick={generateAPIKey}
+                  disabled={loading || !newAPIKeyName.trim()}
+                  className="px-6 py-2 bg-orange-600 text-white rounded-lg hover:bg-orange-500 disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
+                >
+                  Generate API Key
+                </button>
+              </div>
-              <button
-                onClick={generateToken}
-                disabled={loading}
-                className="px-6 py-2 bg-orange-600 text-white rounded-lg hover:bg-orange-500 disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
-              >
-                Generate Token
-              </button>
            </div>

-            {newToken && (
+            {newAPIKey && (
              <div className="mt-4 p-4 bg-green-400/20 border border-green-400/50 rounded-lg">
-                <p className="text-sm font-medium text-green-400 mb-2">New Token Generated:</p>
-                <div className="flex items-center gap-2">
-                  <code className="flex-1 px-3 py-2 bg-gray-900 border border-green-400/50 rounded text-sm font-mono break-all text-gray-100">
-                    {newToken}
-                  </code>
-                  <button
-                    onClick={() => copyToClipboard(newToken)}
-                    className="px-4 py-2 bg-green-600 text-white rounded hover:bg-green-500 transition-colors text-sm"
-                  >
-                    Copy
-                  </button>
+                <p className="text-sm font-medium text-green-400 mb-2">New API Key Generated:</p>
+                <div className="space-y-2">
+                  <div className="flex items-center gap-2">
+                    <code className="flex-1 px-3 py-2 bg-gray-900 border border-green-400/50 rounded text-sm font-mono break-all text-gray-100">
+                      {newAPIKey.key}
+                    </code>
+                    <button
+                      onClick={() => copyToClipboard(newAPIKey.key)}
+                      className="px-4 py-2 bg-green-600 text-white rounded hover:bg-green-500 transition-colors text-sm whitespace-nowrap"
+                    >
+                      Copy Key
+                    </button>
+                  </div>
+                  <div className="text-xs text-green-400/80">
+                    <p><strong>Name:</strong> {newAPIKey.name}</p>
+                    {newAPIKey.description && <p><strong>Description:</strong> {newAPIKey.description}</p>}
+                  </div>
+                  <p className="text-xs text-green-400/80 mt-2">
+                    ⚠️ Save this API key securely. It will not be shown again.
+                  </p>
+                </div>
-                <p className="text-xs text-green-400/80 mt-2">
-                  Save this token securely. It will not be shown again.
-                </p>
              </div>
            )}
          </div>

<div className="bg-gray-800 rounded-lg shadow-md p-6 border border-gray-700">
|
||||
<h2 className="text-xl font-semibold mb-4 text-gray-100">Active Tokens</h2>
|
||||
<h2 className="text-xl font-semibold mb-4 text-gray-100">API Keys</h2>
|
||||
{loading ? (
|
||||
<div className="flex justify-center py-8">
|
||||
<div className="animate-spin rounded-full h-8 w-8 border-b-2 border-orange-500"></div>
|
||||
</div>
|
||||
) : !tokens || tokens.length === 0 ? (
|
||||
<p className="text-gray-400 text-center py-8">No tokens generated yet.</p>
|
||||
<LoadingSpinner size="sm" className="py-8" />
|
||||
) : !apiKeys || apiKeys.length === 0 ? (
|
||||
<p className="text-gray-400 text-center py-8">No API keys generated yet.</p>
|
||||
) : (
|
||||
<div className="overflow-x-auto">
|
||||
<table className="min-w-full divide-y divide-gray-700">
|
||||
<thead className="bg-gray-900">
|
||||
<tr>
|
||||
<th className="px-6 py-3 text-left text-xs font-medium text-gray-400 uppercase tracking-wider">
|
||||
Token
|
||||
Name
|
||||
</th>
|
||||
<th className="px-6 py-3 text-left text-xs font-medium text-gray-400 uppercase tracking-wider">
|
||||
Scope
|
||||
</th>
|
||||
<th className="px-6 py-3 text-left text-xs font-medium text-gray-400 uppercase tracking-wider">
|
||||
Key Prefix
|
||||
</th>
|
||||
<th className="px-6 py-3 text-left text-xs font-medium text-gray-400 uppercase tracking-wider">
|
||||
Status
|
||||
</th>
|
||||
<th className="px-6 py-3 text-left text-xs font-medium text-gray-400 uppercase tracking-wider">
|
||||
Expires At
|
||||
</th>
|
||||
<th className="px-6 py-3 text-left text-xs font-medium text-gray-400 uppercase tracking-wider">
|
||||
Created At
|
||||
</th>
|
||||
@@ -297,46 +431,54 @@ export default function AdminPanel() {
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody className="bg-gray-800 divide-y divide-gray-700">
|
||||
{tokens.map((token) => {
|
||||
const expired = isTokenExpired(token.expires_at);
|
||||
const used = isTokenUsed(token.used);
|
||||
return (
|
||||
<tr key={token.id}>
|
||||
<td className="px-6 py-4 whitespace-nowrap">
|
||||
<code className="text-sm font-mono text-gray-100">
|
||||
{token.token.substring(0, 16)}...
|
||||
</code>
|
||||
</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap">
|
||||
{expired ? (
|
||||
<span className="px-2 py-1 text-xs font-medium rounded-full bg-red-400/20 text-red-400">
|
||||
Expired
|
||||
</span>
|
||||
) : used ? (
|
||||
<span className="px-2 py-1 text-xs font-medium rounded-full bg-yellow-400/20 text-yellow-400">
|
||||
Used
|
||||
</span>
|
||||
) : (
|
||||
<span className="px-2 py-1 text-xs font-medium rounded-full bg-green-400/20 text-green-400">
|
||||
Active
|
||||
</span>
|
||||
)}
|
||||
</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap text-sm text-gray-400">
|
||||
{new Date(token.expires_at).toLocaleString()}
|
||||
</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap text-sm text-gray-400">
|
||||
{new Date(token.created_at).toLocaleString()}
|
||||
</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap text-sm">
|
||||
{!used && !expired && (
|
||||
<button
|
||||
onClick={() => revokeToken(token.id)}
|
||||
className="text-red-400 hover:text-red-300 font-medium"
|
||||
>
|
||||
Revoke
|
||||
</button>
|
||||
{apiKeys.map((key) => {
|
||||
return (
|
||||
<tr key={key.id}>
|
||||
<td className="px-6 py-4 whitespace-nowrap">
|
||||
<div>
|
||||
<div className="text-sm font-medium text-gray-100">{key.name}</div>
|
||||
{key.description && (
|
||||
<div className="text-sm text-gray-400">{key.description}</div>
|
||||
)}
|
||||
</div>
|
||||
</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap">
|
||||
<span className={`px-2 py-1 text-xs font-medium rounded-full ${
|
||||
key.scope === 'manager'
|
||||
? 'bg-purple-400/20 text-purple-400'
|
||||
: 'bg-blue-400/20 text-blue-400'
|
||||
}`}>
|
||||
{key.scope === 'manager' ? 'Manager' : 'User'}
|
||||
</span>
|
||||
</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap">
|
||||
<code className="text-sm font-mono text-gray-300">
|
||||
{key.key_prefix}
|
||||
</code>
|
||||
</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap">
|
||||
{!key.is_active ? (
|
||||
<span className="px-2 py-1 text-xs font-medium rounded-full bg-gray-500/20 text-gray-400">
|
||||
Revoked
|
||||
</span>
|
||||
) : (
|
||||
<span className="px-2 py-1 text-xs font-medium rounded-full bg-green-400/20 text-green-400">
|
||||
Active
|
||||
</span>
|
||||
)}
|
||||
</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap text-sm text-gray-400">
|
||||
{new Date(key.created_at).toLocaleString()}
|
||||
</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap text-sm space-x-2">
|
||||
<button
|
||||
onClick={() => revokeAPIKey(key.id)}
|
||||
disabled={deletingKeyId === key.id}
|
||||
className="text-red-400 hover:text-red-300 font-medium disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
title="Delete API key"
|
||||
>
|
||||
{deletingKeyId === key.id ? 'Deleting...' : 'Delete'}
|
||||
</button>
|
||||
</td>
|
||||
</tr>
|
||||
);
|
||||
@@ -353,9 +495,7 @@ export default function AdminPanel() {
          <div className="bg-gray-800 rounded-lg shadow-md p-6 border border-gray-700">
            <h2 className="text-xl font-semibold mb-4 text-gray-100">Runner Management</h2>
            {loading ? (
-              <div className="flex justify-center py-8">
-                <div className="animate-spin rounded-full h-8 w-8 border-b-2 border-orange-500"></div>
-              </div>
+              <LoadingSpinner size="sm" className="py-8" />
            ) : !runners || runners.length === 0 ? (
              <p className="text-gray-400 text-center py-8">No runners registered.</p>
            ) : (
@@ -369,14 +509,11 @@ export default function AdminPanel() {
                      <th className="px-6 py-3 text-left text-xs font-medium text-gray-400 uppercase tracking-wider">
                        Hostname
                      </th>
-                      <th className="px-6 py-3 text-left text-xs font-medium text-gray-400 uppercase tracking-wider">
-                        IP Address
-                      </th>
                      <th className="px-6 py-3 text-left text-xs font-medium text-gray-400 uppercase tracking-wider">
                        Status
                      </th>
                      <th className="px-6 py-3 text-left text-xs font-medium text-gray-400 uppercase tracking-wider">
-                        Verified
+                        API Key
                      </th>
                      <th className="px-6 py-3 text-left text-xs font-medium text-gray-400 uppercase tracking-wider">
                        Priority
@@ -403,9 +540,6 @@ export default function AdminPanel() {
                          <td className="px-6 py-4 whitespace-nowrap text-sm text-gray-400">
                            {runner.hostname}
                          </td>
-                          <td className="px-6 py-4 whitespace-nowrap text-sm text-gray-400">
-                            {runner.ip_address}
-                          </td>
                          <td className="px-6 py-4 whitespace-nowrap">
                            <span
                              className={`px-2 py-1 text-xs font-medium rounded-full ${
@@ -417,16 +551,10 @@ export default function AdminPanel() {
                              {isOnline ? 'Online' : 'Offline'}
                            </span>
                          </td>
-                          <td className="px-6 py-4 whitespace-nowrap">
-                            <span
-                              className={`px-2 py-1 text-xs font-medium rounded-full ${
-                                runner.verified
-                                  ? 'bg-green-400/20 text-green-400'
-                                  : 'bg-yellow-400/20 text-yellow-400'
-                              }`}
-                            >
-                              {runner.verified ? 'Verified' : 'Unverified'}
-                            </span>
+                          <td className="px-6 py-4 whitespace-nowrap text-sm text-gray-400">
+                            <code className="text-xs font-mono bg-gray-900 px-2 py-1 rounded">
+                              jk_r{runner.id % 10}_...
+                            </code>
                          </td>
                          <td className="px-6 py-4 whitespace-nowrap text-sm text-gray-400">
                            {runner.priority}
@@ -452,20 +580,13 @@ export default function AdminPanel() {
                          <td className="px-6 py-4 whitespace-nowrap text-sm text-gray-400">
                            {new Date(runner.last_heartbeat).toLocaleString()}
                          </td>
-                          <td className="px-6 py-4 whitespace-nowrap text-sm space-x-2">
-                            {!runner.verified && (
-                              <button
-                                onClick={() => verifyRunner(runner.id)}
-                                className="text-orange-400 hover:text-orange-300 font-medium"
-                              >
-                                Verify
-                              </button>
-                            )}
+                          <td className="px-6 py-4 whitespace-nowrap text-sm">
                            <button
                              onClick={() => deleteRunner(runner.id)}
-                              className="text-red-400 hover:text-red-300 font-medium"
+                              disabled={deletingRunnerId === runner.id}
+                              className="text-red-400 hover:text-red-300 font-medium disabled:opacity-50 disabled:cursor-not-allowed"
                            >
-                              Delete
+                              {deletingRunnerId === runner.id ? 'Deleting...' : 'Delete'}
                            </button>
                          </td>
                        </tr>
@@ -515,9 +636,7 @@ export default function AdminPanel() {
          <div className="bg-gray-800 rounded-lg shadow-md p-6 border border-gray-700">
            <h2 className="text-xl font-semibold mb-4 text-gray-100">User Management</h2>
            {loading ? (
-              <div className="flex justify-center py-8">
-                <div className="animate-spin rounded-full h-8 w-8 border-b-2 border-orange-500"></div>
-              </div>
+              <LoadingSpinner size="sm" className="py-8" />
            ) : !users || users.length === 0 ? (
              <p className="text-gray-400 text-center py-8">No users found.</p>
            ) : (
41
web/src/components/ErrorBoundary.jsx
Normal file
@@ -0,0 +1,41 @@
import React from 'react';

class ErrorBoundary extends React.Component {
  constructor(props) {
    super(props);
    this.state = { hasError: false, error: null };
  }

  static getDerivedStateFromError(error) {
    return { hasError: true, error };
  }

  componentDidCatch(error, errorInfo) {
    console.error('ErrorBoundary caught an error:', error, errorInfo);
  }

  render() {
    if (this.state.hasError) {
      return (
        <div className="p-6 bg-red-400/20 border border-red-400/50 rounded-lg text-red-400">
          <h2 className="text-xl font-semibold mb-2">Something went wrong</h2>
          <p className="mb-4">{this.state.error?.message || 'An unexpected error occurred'}</p>
          <button
            onClick={() => {
              this.setState({ hasError: false, error: null });
              window.location.reload();
            }}
            className="px-4 py-2 bg-red-600 text-white rounded-lg hover:bg-red-500 transition-colors"
          >
            Reload Page
          </button>
        </div>
      );
    }

    return this.props.children;
  }
}

export default ErrorBoundary;
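As a usage sketch (hypothetical — this comparison does not show the app entry point; `App` and `web/src/main.jsx` are assumed names), the boundary wraps the component tree so a render-time crash swaps the subtree for the fallback card instead of blanking the page:

import React from 'react';
import ReactDOM from 'react-dom/client';
import ErrorBoundary from './components/ErrorBoundary';
import App from './App';

// Errors thrown during render anywhere under <ErrorBoundary> trigger
// getDerivedStateFromError above and show the "Something went wrong" card.
ReactDOM.createRoot(document.getElementById('root')).render(
  <ErrorBoundary>
    <App />
  </ErrorBoundary>
);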
26
web/src/components/ErrorMessage.jsx
Normal file
@@ -0,0 +1,26 @@
import React from 'react';

/**
 * Shared ErrorMessage component for consistent error display
 * Sanitizes error messages to prevent XSS
 */
export default function ErrorMessage({ error, className = '' }) {
  if (!error) return null;

  // Sanitize error message - escape HTML entities
  const sanitize = (text) => {
    const div = document.createElement('div');
    div.textContent = text;
    return div.innerHTML;
  };

  const sanitizedError = typeof error === 'string' ? sanitize(error) : sanitize(error.message || 'An error occurred');

  return (
    <div className={`p-4 bg-red-400/20 border border-red-400/50 rounded-lg text-red-400 ${className}`}>
      <p className="font-semibold">Error:</p>
      <p dangerouslySetInnerHTML={{ __html: sanitizedError }} />
    </div>
  );
}
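A minimal usage sketch (the surrounding component is illustrative, not from this comparison):

import { useState } from 'react';
import ErrorMessage from './components/ErrorMessage';

function SaveButton({ onSave }) {
  const [err, setErr] = useState(null);
  // ErrorMessage renders nothing while err is null and shows the card once set.
  return (
    <>
      <ErrorMessage error={err} className="mb-4" />
      <button onClick={() => onSave().catch(setErr)}>Save</button>
    </>
  );
}

On the escape-then-inject design: because the message is run through textContent/innerHTML first, the string handed to dangerouslySetInnerHTML is already inert markup; rendering the message directly as a text child would give the same XSS safety with less machinery.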
191
web/src/components/FileExplorer.jsx
Normal file
@@ -0,0 +1,191 @@
import { useState } from 'react';

export default function FileExplorer({ files, onDownload, onPreview, onVideoPreview, isImageFile }) {
  const [expandedPaths, setExpandedPaths] = useState(new Set()); // Root folder collapsed by default

  // Build directory tree from file paths
  const buildTree = (files) => {
    const tree = {};

    files.forEach(file => {
      const path = file.file_name;
      // Handle both paths with slashes and single filenames
      const parts = path.includes('/') ? path.split('/').filter(p => p) : [path];

      // If it's a single file at root (no slashes), treat it specially
      if (parts.length === 1 && !path.includes('/')) {
        tree[parts[0]] = {
          name: parts[0],
          isFile: true,
          file: file,
          children: {},
          path: parts[0]
        };
        return;
      }

      let current = tree;
      parts.forEach((part, index) => {
        if (!current[part]) {
          current[part] = {
            name: part,
            isFile: index === parts.length - 1,
            file: index === parts.length - 1 ? file : null,
            children: {},
            path: parts.slice(0, index + 1).join('/')
          };
        }
        current = current[part].children;
      });
    });

    return tree;
  };

  const togglePath = (path) => {
    const newExpanded = new Set(expandedPaths);
    if (newExpanded.has(path)) {
      newExpanded.delete(path);
    } else {
      newExpanded.add(path);
    }
    setExpandedPaths(newExpanded);
  };

  const renderTree = (node, level = 0, parentPath = '') => {
    const items = Object.values(node).sort((a, b) => {
      // Directories first, then files
      if (a.isFile !== b.isFile) {
        return a.isFile ? 1 : -1;
      }
      return a.name.localeCompare(b.name);
    });

    return items.map((item) => {
      const fullPath = parentPath ? `${parentPath}/${item.name}` : item.name;
      const isExpanded = expandedPaths.has(fullPath);
      const indent = level * 20;

      if (item.isFile) {
        const file = item.file;
        const isImage = isImageFile && isImageFile(file.file_name);
        const isVideo = file.file_name.toLowerCase().endsWith('.mp4');
        const sizeMB = (file.file_size / 1024 / 1024).toFixed(2);
        const isArchive = file.file_name.endsWith('.tar') || file.file_name.endsWith('.zip');

        return (
          <div key={fullPath} className="flex items-center justify-between py-1.5 hover:bg-gray-800/50 rounded px-2" style={{ paddingLeft: `${indent + 8}px` }}>
            <div className="flex items-center gap-2 flex-1 min-w-0">
              <span className="text-gray-500 text-sm">{isArchive ? '📦' : isVideo ? '🎬' : '📄'}</span>
              <span className="text-gray-200 text-sm truncate" title={item.name}>
                {item.name}
              </span>
              <span className="text-gray-500 text-xs ml-2">{sizeMB} MB</span>
            </div>
            <div className="flex gap-2 ml-4 shrink-0">
              {isVideo && onVideoPreview && (
                <button
                  onClick={() => onVideoPreview(file)}
                  className="px-2 py-1 bg-purple-600 text-white rounded text-xs hover:bg-purple-500 transition-colors"
                  title="Play Video"
                >
                  ▶
                </button>
              )}
              {isImage && onPreview && (
                <button
                  onClick={() => onPreview(file)}
                  className="px-2 py-1 bg-blue-600 text-white rounded text-xs hover:bg-blue-500 transition-colors"
                  title="Preview"
                >
                  👁
                </button>
              )}
              {onDownload && file.id && (
                <button
                  onClick={() => onDownload(file.id, file.file_name)}
                  className="px-2 py-1 bg-orange-600 text-white rounded text-xs hover:bg-orange-500 transition-colors"
                  title="Download"
                >
                  ⬇
                </button>
              )}
            </div>
          </div>
        );
      } else {
        const hasChildren = Object.keys(item.children).length > 0;
        return (
          <div key={fullPath}>
            <div
              className="flex items-center gap-2 py-1.5 hover:bg-gray-800/50 rounded px-2 cursor-pointer select-none"
              style={{ paddingLeft: `${indent + 8}px` }}
              onClick={() => hasChildren && togglePath(fullPath)}
            >
              <span className="text-gray-400 text-xs w-4 flex items-center justify-center">
                {hasChildren ? (isExpanded ? '▼' : '▶') : '•'}
              </span>
              <span className="text-gray-500 text-sm">
                {hasChildren ? (isExpanded ? '📂' : '📁') : '📁'}
              </span>
              <span className="text-gray-300 text-sm font-medium">{item.name}</span>
              {hasChildren && (
                <span className="text-gray-500 text-xs ml-2">
                  ({Object.keys(item.children).length})
                </span>
              )}
            </div>
            {hasChildren && isExpanded && (
              <div className="ml-2">
                {renderTree(item.children, level + 1, fullPath)}
              </div>
            )}
          </div>
        );
      }
    });
  };

  const tree = buildTree(files);

  if (Object.keys(tree).length === 0) {
    return (
      <div className="text-gray-400 text-sm py-4 text-center">
        No files
      </div>
    );
  }

  // Wrap tree in a root folder
  const rootExpanded = expandedPaths.has('');

  return (
    <div className="bg-gray-900 rounded-lg border border-gray-700 p-3">
      <div className="space-y-1">
        <div>
          <div
            className="flex items-center gap-2 py-1.5 hover:bg-gray-800/50 rounded px-2 cursor-pointer select-none"
            onClick={() => togglePath('')}
          >
            <span className="text-gray-400 text-xs w-4 flex items-center justify-center">
              {rootExpanded ? '▼' : '▶'}
            </span>
            <span className="text-gray-500 text-sm">
              {rootExpanded ? '📂' : '📁'}
            </span>
            <span className="text-gray-300 text-sm font-medium">Files</span>
            <span className="text-gray-500 text-xs ml-2">
              ({Object.keys(tree).length})
            </span>
          </div>
          {rootExpanded && (
            <div className="ml-2">
              {renderTree(tree)}
            </div>
          )}
        </div>
      </div>
    </div>
  );
}
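A feeding sketch (the `files` shape — `file_name`, `file_size`, `id` — is inferred from the component above; the handler bodies are placeholders, not part of this comparison):

import FileExplorer from './components/FileExplorer';

const files = [
  { id: 1, file_name: 'frames/0001.png', file_size: 2400000 },
  { id: 2, file_name: 'output.mp4', file_size: 104857600 },
];

<FileExplorer
  files={files}
  isImageFile={(name) => /\.(png|jpe?g|webp)$/i.test(name)}
  onDownload={(id, name) => console.log('download', id, name)}
  onPreview={(file) => console.log('preview', file.file_name)}
  onVideoPreview={(file) => console.log('play', file.file_name)}
/>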
File diff suppressed because it is too large
@@ -1,29 +1,142 @@
-import { useState, useEffect } from 'react';
-import { jobs } from '../utils/api';
+import { useState, useEffect, useRef } from 'react';
+import { jobs, normalizeArrayResponse } from '../utils/api';
+import { wsManager } from '../utils/websocket';
import JobDetails from './JobDetails';
+import LoadingSpinner from './LoadingSpinner';

export default function JobList() {
  const [jobList, setJobList] = useState([]);
  const [loading, setLoading] = useState(true);
  const [selectedJob, setSelectedJob] = useState(null);
+  const [pagination, setPagination] = useState({ total: 0, limit: 50, offset: 0 });
+  const [hasMore, setHasMore] = useState(true);
+  const listenerIdRef = useRef(null);

  useEffect(() => {
-    loadJobs();
-    const interval = setInterval(loadJobs, 5000);
-    return () => clearInterval(interval);
+    // Use shared WebSocket manager for real-time updates
+    listenerIdRef.current = wsManager.subscribe('joblist', {
+      open: () => {
+        console.log('JobList: Shared WebSocket connected');
+        // Load initial job list via HTTP to get current state
+        loadJobs();
+      },
+      message: (data) => {
+        console.log('JobList: Client WebSocket message received:', data.type, data.channel, data);
+        // Handle jobs channel messages (always broadcasted)
+        if (data.channel === 'jobs') {
+          if (data.type === 'job_update' && data.data) {
+            console.log('JobList: Updating job:', data.job_id, data.data);
+            // Update job in list
+            setJobList(prev => {
+              const prevArray = Array.isArray(prev) ? prev : [];
+              const index = prevArray.findIndex(j => j.id === data.job_id);
+              if (index >= 0) {
+                const updated = [...prevArray];
+                updated[index] = { ...updated[index], ...data.data };
+                console.log('JobList: Updated job at index', index, updated[index]);
+                return updated;
+              }
+              // If job not in current page, reload to get updated list
+              if (data.data.status === 'completed' || data.data.status === 'failed') {
+                loadJobs();
+              }
+              return prevArray;
+            });
+          } else if (data.type === 'job_created' && data.data) {
+            console.log('JobList: New job created:', data.job_id, data.data);
+            // New job created - add to list
+            setJobList(prev => {
+              const prevArray = Array.isArray(prev) ? prev : [];
+              // Check if job already exists (avoid duplicates)
+              if (prevArray.findIndex(j => j.id === data.job_id) >= 0) {
+                return prevArray;
+              }
+              // Add new job at the beginning
+              return [data.data, ...prevArray];
+            });
+          }
+        } else if (data.type === 'connected') {
+          // Connection established
+          console.log('JobList: WebSocket connected');
+        }
+      },
+      error: (error) => {
+        console.error('JobList: Shared WebSocket error:', error);
+      },
+      close: (event) => {
+        console.log('JobList: Shared WebSocket closed:', event);
+      }
+    });
+
+    // Ensure connection is established
+    wsManager.connect();
+
+    return () => {
+      if (listenerIdRef.current) {
+        wsManager.unsubscribe(listenerIdRef.current);
+        listenerIdRef.current = null;
+      }
+    };
  }, []);

-  const loadJobs = async () => {
+  const loadJobs = async (append = false) => {
    try {
-      const data = await jobs.list();
-      setJobList(data);
+      const offset = append ? pagination.offset + pagination.limit : 0;
+      const result = await jobs.listSummary({
+        limit: pagination.limit,
+        offset,
+        sort: 'created_at:desc'
+      });
+
+      // Handle both old format (array) and new format (object with data, total, etc.)
+      const jobsArray = normalizeArrayResponse(result);
+      const total = result.total !== undefined ? result.total : jobsArray.length;
+
+      if (append) {
+        setJobList(prev => {
+          const prevArray = Array.isArray(prev) ? prev : [];
+          return [...prevArray, ...jobsArray];
+        });
+        setPagination(prev => ({ ...prev, offset, total }));
+      } else {
+        setJobList(jobsArray);
+        setPagination({ total, limit: result.limit || pagination.limit, offset: result.offset || 0 });
+      }
+
+      setHasMore(offset + jobsArray.length < total);
    } catch (error) {
      console.error('Failed to load jobs:', error);
+      // Ensure jobList is always an array even on error
+      if (!append) {
+        setJobList([]);
+      }
    } finally {
      setLoading(false);
    }
  };

+  const loadMore = () => {
+    if (!loading && hasMore) {
+      loadJobs(true);
+    }
+  };
+
+  // Keep selectedJob in sync with the job list when it refreshes
+  useEffect(() => {
+    if (selectedJob && jobList.length > 0) {
+      const freshJob = jobList.find(j => j.id === selectedJob.id);
+      if (freshJob) {
+        // Update to the fresh object from the list to keep it in sync
+        setSelectedJob(freshJob);
+      } else {
+        // Job was deleted or no longer exists, clear selection
+        setSelectedJob(null);
+      }
+    }
+    // eslint-disable-next-line react-hooks/exhaustive-deps
+  }, [jobList]); // Only depend on jobList, not selectedJob to avoid infinite loops
+
  const handleCancel = async (jobId) => {
    if (!confirm('Are you sure you want to cancel this job?')) return;
    try {
@@ -37,12 +150,21 @@ export default function JobList() {
  const handleDelete = async (jobId) => {
    if (!confirm('Are you sure you want to permanently delete this job? This action cannot be undone.')) return;
    try {
-      await jobs.delete(jobId);
-      loadJobs();
+      // Optimistically update the list
+      setJobList(prev => {
+        const prevArray = Array.isArray(prev) ? prev : [];
+        return prevArray.filter(j => j.id !== jobId);
+      });
+      if (selectedJob && selectedJob.id === jobId) {
+        setSelectedJob(null);
+      }
+      // Then actually delete
+      await jobs.delete(jobId);
+      // Reload to ensure consistency
+      loadJobs();
    } catch (error) {
+      // On error, reload to restore correct state
+      loadJobs();
      alert('Failed to delete job: ' + error.message);
    }
  };
@@ -58,12 +180,8 @@ export default function JobList() {
    return colors[status] || colors.pending;
  };

-  if (loading) {
-    return (
-      <div className="flex justify-center items-center h-64">
-        <div className="animate-spin rounded-full h-12 w-12 border-b-2 border-orange-500"></div>
-      </div>
-    );
+  if (loading && jobList.length === 0) {
+    return <LoadingSpinner size="md" className="h-64" />;
  }

  if (jobList.length === 0) {
@@ -90,8 +208,10 @@ export default function JobList() {
            </div>

            <div className="space-y-2 text-sm text-gray-400 mb-4">
-              <p>Frames: {job.frame_start} - {job.frame_end}</p>
-              <p>Format: {job.output_format}</p>
+              {job.frame_start !== undefined && job.frame_end !== undefined && (
+                <p>Frames: {job.frame_start} - {job.frame_end}</p>
+              )}
+              {job.output_format && <p>Format: {job.output_format}</p>}
              <p>Created: {new Date(job.created_at).toLocaleString()}</p>
            </div>

@@ -110,7 +230,15 @@ export default function JobList() {

            <div className="flex gap-2">
              <button
-                onClick={() => setSelectedJob(job)}
+                onClick={() => {
+                  // Fetch full job details when viewing
+                  jobs.get(job.id).then(fullJob => {
+                    setSelectedJob(fullJob);
+                  }).catch(err => {
+                    console.error('Failed to load job details:', err);
+                    setSelectedJob(job); // Fallback to summary
+                  });
+                }}
                className="flex-1 px-4 py-2 bg-orange-600 text-white rounded-lg hover:bg-orange-500 transition-colors font-medium"
              >
                View Details
@@ -137,6 +265,18 @@ export default function JobList() {
          ))}
        </div>

+        {hasMore && (
+          <div className="flex justify-center mt-6">
+            <button
+              onClick={loadMore}
+              disabled={loading}
+              className="px-6 py-2 bg-gray-700 text-gray-200 rounded-lg hover:bg-gray-600 transition-colors font-medium disabled:opacity-50"
+            >
+              {loading ? 'Loading...' : 'Load More'}
+            </button>
+          </div>
+        )}
+
        {selectedJob && (
          <JobDetails
            job={selectedJob}
@@ -147,4 +287,3 @@ export default function JobList() {
    </>
  );
}
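The wsManager module itself is not shown in this comparison, but the call sites above imply its contract: subscribe(id, handlers) registers a listener and returns an id, unsubscribe(id) removes it, and connect() is safe to call repeatedly. A hedged consumer sketch following that inferred contract (handler names match the diff; everything else is illustrative):

import { wsManager } from '../utils/websocket';

const listenerId = wsManager.subscribe('my-panel', {
  open: () => console.log('connected'),
  message: (data) => {
    if (data.channel === 'jobs' && data.type === 'job_update') {
      // merge data.data into local state keyed by data.job_id
    }
  },
  error: (err) => console.error(err),
  close: () => console.log('closed'),
});
wsManager.connect(); // idempotent, per the usage above

// Later, e.g. in a React effect cleanup:
wsManager.unsubscribe(listenerId);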
File diff suppressed because it is too large
19
web/src/components/LoadingSpinner.jsx
Normal file
@@ -0,0 +1,19 @@
import React from 'react';

/**
 * Shared LoadingSpinner component with size variants
 */
export default function LoadingSpinner({ size = 'md', className = '', borderColor = 'border-orange-500' }) {
  const sizeClasses = {
    sm: 'h-8 w-8',
    md: 'h-12 w-12',
    lg: 'h-16 w-16',
  };

  return (
    <div className={`flex justify-center items-center ${className}`}>
      <div className={`animate-spin rounded-full border-b-2 ${borderColor} ${sizeClasses[size]}`}></div>
    </div>
  );
}
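Usage, matching the call sites elsewhere in this comparison (note that borderColor, not className, controls the ring color — className lands on the outer wrapper):

import LoadingSpinner from './components/LoadingSpinner';

<LoadingSpinner size="sm" className="py-8" />            // small, padded card spinner
<LoadingSpinner size="md" className="h-64" />            // full-height section spinner
<LoadingSpinner size="lg" borderColor="border-white" />  // white ring over a dark overlay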
@@ -1,5 +1,6 @@
import { useState, useEffect } from 'react';
import { auth } from '../utils/api';
+import ErrorMessage from './ErrorMessage';

export default function Login() {
  const [providers, setProviders] = useState({
@@ -92,11 +93,7 @@ export default function Login() {
      </div>

      <div className="space-y-4">
-        {error && (
-          <div className="p-4 bg-red-400/20 border border-red-400/50 rounded-lg text-red-400 text-sm">
-            {error}
-          </div>
-        )}
+        <ErrorMessage error={error} className="text-sm" />
        {providers.local && (
          <div className="pb-4 border-b border-gray-700">
            <div className="flex gap-2 mb-4">
@@ -1,5 +1,6 @@
import { useState } from 'react';
import { auth } from '../utils/api';
+import ErrorMessage from './ErrorMessage';
import { useAuth } from '../hooks/useAuth';

export default function PasswordChange({ targetUserId = null, targetUserName = null, onSuccess }) {
@@ -64,11 +65,7 @@ export default function PasswordChange({ targetUserId = null, targetUserName = n
        {isChangingOtherUser ? `Change Password for ${targetUserName || 'User'}` : 'Change Password'}
      </h2>

-      {error && (
-        <div className="mb-4 p-3 bg-red-400/20 border border-red-400/50 rounded-lg text-red-400 text-sm">
-          {error}
-        </div>
-      )}
+      <ErrorMessage error={error} className="mb-4 text-sm" />

      {success && (
        <div className="mb-4 p-3 bg-green-400/20 border border-green-400/50 rounded-lg text-green-400 text-sm">
@@ -1,22 +1,71 @@
-import { useState, useEffect } from 'react';
-import { admin } from '../utils/api';
+import { useState, useEffect, useRef } from 'react';
+import { admin, normalizeArrayResponse } from '../utils/api';
+import { wsManager } from '../utils/websocket';
import JobDetails from './JobDetails';
+import LoadingSpinner from './LoadingSpinner';

export default function UserJobs({ userId, userName, onBack }) {
  const [jobList, setJobList] = useState([]);
  const [loading, setLoading] = useState(true);
  const [selectedJob, setSelectedJob] = useState(null);
+  const listenerIdRef = useRef(null);

  useEffect(() => {
-    loadJobs();
-    const interval = setInterval(loadJobs, 5000);
-    return () => clearInterval(interval);
+    // Use shared WebSocket manager for real-time updates instead of polling
+    listenerIdRef.current = wsManager.subscribe(`userjobs_${userId}`, {
+      open: () => {
+        console.log('UserJobs: Shared WebSocket connected');
+        loadJobs();
+      },
+      message: (data) => {
+        // Handle jobs channel messages (always broadcasted)
+        if (data.channel === 'jobs') {
+          if (data.type === 'job_update' && data.data) {
+            // Update job in list if it belongs to this user
+            setJobList(prev => {
+              const prevArray = Array.isArray(prev) ? prev : [];
+              const index = prevArray.findIndex(j => j.id === data.job_id);
+              if (index >= 0) {
+                const updated = [...prevArray];
+                updated[index] = { ...updated[index], ...data.data };
+                return updated;
+              }
+              // If job not in current list, reload to get updated list
+              if (data.data.status === 'completed' || data.data.status === 'failed') {
+                loadJobs();
+              }
+              return prevArray;
+            });
+          } else if (data.type === 'job_created' && data.data) {
+            // New job created - reload to check if it belongs to this user
+            loadJobs();
+          }
+        }
+      },
+      error: (error) => {
+        console.error('UserJobs: Shared WebSocket error:', error);
+      },
+      close: (event) => {
+        console.log('UserJobs: Shared WebSocket closed:', event);
+      }
+    });
+
+    // Ensure connection is established
+    wsManager.connect();
+
+    return () => {
+      if (listenerIdRef.current) {
+        wsManager.unsubscribe(listenerIdRef.current);
+        listenerIdRef.current = null;
+      }
+    };
  }, [userId]);

  const loadJobs = async () => {
    try {
      const data = await admin.getUserJobs(userId);
-      setJobList(Array.isArray(data) ? data : []);
+      setJobList(normalizeArrayResponse(data));
    } catch (error) {
      console.error('Failed to load jobs:', error);
      setJobList([]);
@@ -47,11 +96,7 @@ export default function UserJobs({ userId, userName, onBack }) {
  }

  if (loading) {
-    return (
-      <div className="flex justify-center items-center h-64">
-        <div className="animate-spin rounded-full h-12 w-12 border-b-2 border-orange-500"></div>
-      </div>
-    );
+    return <LoadingSpinner size="md" className="h-64" />;
  }

  return (
@@ -1,4 +1,6 @@
import { useState, useRef, useEffect } from 'react';
+import ErrorMessage from './ErrorMessage';
+import LoadingSpinner from './LoadingSpinner';

export default function VideoPlayer({ videoUrl, onClose }) {
  const videoRef = useRef(null);
@@ -55,10 +57,10 @@ export default function VideoPlayer({ videoUrl, onClose }) {

  if (error) {
    return (
-      <div className="bg-red-50 border border-red-200 rounded-lg p-4 text-red-700">
-        {error}
-        <div className="mt-2 text-sm text-red-600">
-          <a href={videoUrl} download className="underline">Download video instead</a>
+      <div>
+        <ErrorMessage error={error} />
+        <div className="mt-2 text-sm text-gray-400">
+          <a href={videoUrl} download className="text-orange-400 hover:text-orange-300 underline">Download video instead</a>
        </div>
      </div>
    );
@@ -68,7 +70,7 @@ export default function VideoPlayer({ videoUrl, onClose }) {
    <div className="relative bg-black rounded-lg overflow-hidden">
      {loading && (
        <div className="absolute inset-0 flex items-center justify-center bg-black bg-opacity-50 z-10">
-          <div className="animate-spin rounded-full h-12 w-12 border-b-2 border-white"></div>
+          <LoadingSpinner size="lg" borderColor="border-white" />
        </div>
      )}
      <video
@@ -3,12 +3,96 @@ const API_BASE = '/api';
// Global auth error handler - will be set by useAuth hook
let onAuthError = null;

+// Request debouncing and deduplication
+const pendingRequests = new Map(); // key: endpoint+params, value: Promise
+const requestQueue = new Map(); // key: endpoint+params, value: { resolve, reject, timestamp }
+const DEBOUNCE_DELAY = 100; // 100ms debounce delay
+const DEDUPE_WINDOW = 5000; // 5 seconds - same request within this window uses cached promise
+
+// Generate cache key from endpoint and params
+function getCacheKey(endpoint, options = {}) {
+  const params = new URLSearchParams();
+  Object.keys(options).sort().forEach(key => {
+    if (options[key] !== undefined && options[key] !== null) {
+      params.append(key, String(options[key]));
+    }
+  });
+  const query = params.toString();
+  return `${endpoint}${query ? '?' + query : ''}`;
+}
+
+// Utility function to normalize array responses (handles both old and new formats)
+export function normalizeArrayResponse(response) {
+  const data = response?.data || response;
+  return Array.isArray(data) ? data : [];
+}
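+// Illustration (not part of the diff): normalizeArrayResponse accepts either the
+// legacy bare-array shape or the new paginated envelope and always yields an array:
+//   normalizeArrayResponse([{ id: 1 }])                     -> [{ id: 1 }]
+//   normalizeArrayResponse({ data: [{ id: 1 }], total: 9 }) -> [{ id: 1 }]
+//   normalizeArrayResponse(null)                            -> []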
+// Sentinel value to indicate a request was superseded (instead of rejecting)
+// Export it so components can check for it
+export const REQUEST_SUPERSEDED = Symbol('REQUEST_SUPERSEDED');
+
+// Debounced request wrapper
+function debounceRequest(key, requestFn, delay = DEBOUNCE_DELAY) {
+  return new Promise((resolve, reject) => {
+    // Check if there's a pending request for this key
+    if (pendingRequests.has(key)) {
+      const pending = pendingRequests.get(key);
+      // If request is very recent (within dedupe window), reuse it
+      const now = Date.now();
+      if (pending.timestamp && (now - pending.timestamp) < DEDUPE_WINDOW) {
+        pending.promise.then(resolve).catch(reject);
+        return;
+      } else {
+        // Request is older than dedupe window - remove it and create new one
+        pendingRequests.delete(key);
+      }
+    }
+
+    // Clear any existing timeout for this key
+    if (requestQueue.has(key)) {
+      const queued = requestQueue.get(key);
+      clearTimeout(queued.timeout);
+      // Resolve with sentinel value instead of rejecting - this prevents errors from propagating
+      // The new request will handle the actual response
+      queued.resolve(REQUEST_SUPERSEDED);
+    }
+
+    // Queue new request
+    const timeout = setTimeout(() => {
+      requestQueue.delete(key);
+      const promise = requestFn();
+      const timestamp = Date.now();
+      pendingRequests.set(key, { promise, timestamp });
+
+      promise
+        .then(result => {
+          pendingRequests.delete(key);
+          resolve(result);
+        })
+        .catch(error => {
+          pendingRequests.delete(key);
+          reject(error);
+        });
+    }, delay);
+
+    requestQueue.set(key, { resolve, reject, timeout });
+  });
+}
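+// Illustration (not part of the diff): two identical calls inside the debounce
+// window coalesce, and the superseded caller resolves with the sentinel rather
+// than rejecting:
+//   const a = debounceRequest('/jobs?limit=50', () => api.get('/jobs?limit=50'));
+//   const b = debounceRequest('/jobs?limit=50', () => api.get('/jobs?limit=50'));
+//   // `a` resolves with REQUEST_SUPERSEDED; `b` resolves with the response.
+//   if (await a !== REQUEST_SUPERSEDED) { /* use the data */ }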
export const setAuthErrorHandler = (handler) => {
  onAuthError = handler;
};

-const handleAuthError = (response) => {
+// Whitelist of endpoints that should NOT trigger auth error handling
+// These are endpoints that can legitimately return 401/403 without meaning the user is logged out
+const AUTH_CHECK_ENDPOINTS = ['/auth/me', '/auth/logout'];
+
+const handleAuthError = (response, endpoint) => {
  if (response.status === 401 || response.status === 403) {
+    // Don't trigger auth error handler for endpoints that check auth status
+    if (AUTH_CHECK_ENDPOINTS.includes(endpoint)) {
+      return;
+    }
    // Trigger auth error handler if set (this will clear user state)
    if (onAuthError) {
      onAuthError();
@@ -22,60 +106,79 @@ const handleAuthError = (response) => {
  }
};

+// Extract error message from response - centralized to avoid duplication
+async function extractErrorMessage(response) {
+  try {
+    const errorData = await response.json();
+    return errorData?.error || response.statusText;
+  } catch {
+    return response.statusText;
+  }
+}
+
export const api = {
-  async get(endpoint) {
+  async get(endpoint, options = {}) {
+    const abortController = options.signal || new AbortController();
    const response = await fetch(`${API_BASE}${endpoint}`, {
      credentials: 'include', // Include cookies for session
+      signal: abortController.signal,
    });
    if (!response.ok) {
      // Handle auth errors before parsing response
-      // Don't redirect on /auth/me - that's the auth check itself
-      if ((response.status === 401 || response.status === 403) && !endpoint.startsWith('/auth/')) {
-        handleAuthError(response);
-        // Don't redirect - let React handle UI change through state
-      }
-      const errorData = await response.json().catch(() => null);
-      const errorMessage = errorData?.error || response.statusText;
+      handleAuthError(response, endpoint);
+      const errorMessage = await extractErrorMessage(response);
      throw new Error(errorMessage);
    }
    return response.json();
  },

-  async post(endpoint, data) {
+  async post(endpoint, data, options = {}) {
+    const abortController = options.signal || new AbortController();
    const response = await fetch(`${API_BASE}${endpoint}`, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify(data),
      credentials: 'include', // Include cookies for session
+      signal: abortController.signal,
    });
    if (!response.ok) {
      // Handle auth errors before parsing response
-      // Don't redirect on /auth/* endpoints - those are login/logout
-      if ((response.status === 401 || response.status === 403) && !endpoint.startsWith('/auth/')) {
-        handleAuthError(response);
-        // Don't redirect - let React handle UI change through state
-      }
-      const errorData = await response.json().catch(() => null);
-      const errorMessage = errorData?.error || response.statusText;
+      handleAuthError(response, endpoint);
+      const errorMessage = await extractErrorMessage(response);
      throw new Error(errorMessage);
    }
    return response.json();
  },

-  async delete(endpoint) {
+  async patch(endpoint, data, options = {}) {
+    const abortController = options.signal || new AbortController();
    const response = await fetch(`${API_BASE}${endpoint}`, {
-      method: 'DELETE',
+      method: 'PATCH',
+      headers: { 'Content-Type': 'application/json' },
+      body: data ? JSON.stringify(data) : undefined,
      credentials: 'include', // Include cookies for session
+      signal: abortController.signal,
    });
    if (!response.ok) {
      // Handle auth errors before parsing response
-      // Don't redirect on /auth/* endpoints
-      if ((response.status === 401 || response.status === 403) && !endpoint.startsWith('/auth/')) {
-        handleAuthError(response);
-        // Don't redirect - let React handle UI change through state
-      }
-      const errorData = await response.json().catch(() => null);
-      const errorMessage = errorData?.error || response.statusText;
+      handleAuthError(response, endpoint);
+      const errorMessage = await extractErrorMessage(response);
      throw new Error(errorMessage);
    }
    return response.json();
  },

+  async delete(endpoint, options = {}) {
+    const abortController = options.signal || new AbortController();
+    const response = await fetch(`${API_BASE}${endpoint}`, {
+      method: 'DELETE',
+      credentials: 'include', // Include cookies for session
+      signal: abortController.signal,
+    });
+    if (!response.ok) {
+      // Handle auth errors before parsing response
+      handleAuthError(response, endpoint);
+      const errorMessage = await extractErrorMessage(response);
+      throw new Error(errorMessage);
+    }
+    return response.json();
@@ -112,8 +215,7 @@ export const api = {
      } else {
        // Handle auth errors
        if (xhr.status === 401 || xhr.status === 403) {
-          handleAuthError({ status: xhr.status });
-          // Don't redirect - let React handle UI change through state
+          handleAuthError({ status: xhr.status }, endpoint);
        }
        try {
          const errorData = JSON.parse(xhr.responseText);
@@ -174,12 +276,53 @@ export const auth = {
};

export const jobs = {
-  async list() {
-    return api.get('/jobs');
+  async list(options = {}) {
+    const key = getCacheKey('/jobs', options);
+    return debounceRequest(key, () => {
+      const params = new URLSearchParams();
+      if (options.limit) params.append('limit', options.limit.toString());
+      if (options.offset) params.append('offset', options.offset.toString());
+      if (options.status) params.append('status', options.status);
+      if (options.sort) params.append('sort', options.sort);
+      const query = params.toString();
+      return api.get(`/jobs${query ? '?' + query : ''}`);
+    });
  },

-  async get(id) {
-    return api.get(`/jobs/${id}`);
+  async listSummary(options = {}) {
+    const key = getCacheKey('/jobs/summary', options);
+    return debounceRequest(key, () => {
+      const params = new URLSearchParams();
+      if (options.limit) params.append('limit', options.limit.toString());
+      if (options.offset) params.append('offset', options.offset.toString());
+      if (options.status) params.append('status', options.status);
+      if (options.sort) params.append('sort', options.sort);
+      const query = params.toString();
+      return api.get(`/jobs/summary${query ? '?' + query : ''}`, options);
+    });
+  },
+
+  async get(id, options = {}) {
+    const key = getCacheKey(`/jobs/${id}`, options);
+    return debounceRequest(key, async () => {
+      if (options.etag) {
+        // Include ETag in request headers for conditional requests
+        const headers = { 'If-None-Match': options.etag };
+        const response = await fetch(`${API_BASE}/jobs/${id}`, {
+          credentials: 'include',
+          headers,
+        });
+        if (response.status === 304) {
+          return null; // Not modified
+        }
+        if (!response.ok) {
+          const errorData = await response.json().catch(() => null);
+          throw new Error(errorData?.error || response.statusText);
+        }
+        return response.json();
+      }
+      return api.get(`/jobs/${id}`, options);
+    });
+  },
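+// Illustration (not part of the diff): passing the ETag from a previous response
+// turns polling into a conditional request; null signals 304 Not Modified.
+//   const fresh = await jobs.get(42, { etag: lastKnownEtag });
+//   if (fresh !== null) {
+//     setJob(fresh); // body changed on the server
+//   }
+// Note the helper returns only the JSON body, so a caller that wants to keep the
+// cycle going has to capture the ETag header from a full fetch itself.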
async create(jobData) {
|
||||
@@ -198,31 +341,82 @@ export const jobs = {
|
||||
return api.uploadFile(`/jobs/${jobId}/upload`, file, onProgress, mainBlendFile);
|
||||
},
|
||||
|
||||
async getFiles(jobId) {
|
||||
return api.get(`/jobs/${jobId}/files`);
|
||||
async uploadFileForJobCreation(file, onProgress, mainBlendFile) {
|
||||
return api.uploadFile(`/jobs/upload`, file, onProgress, mainBlendFile);
|
||||
},
|
||||
|
||||
async getFiles(jobId, options = {}) {
|
||||
const key = getCacheKey(`/jobs/${jobId}/files`, options);
|
||||
return debounceRequest(key, () => {
|
||||
const params = new URLSearchParams();
|
||||
if (options.limit) params.append('limit', options.limit.toString());
|
||||
if (options.offset) params.append('offset', options.offset.toString());
|
||||
if (options.file_type) params.append('file_type', options.file_type);
|
||||
if (options.extension) params.append('extension', options.extension);
|
||||
const query = params.toString();
|
||||
return api.get(`/jobs/${jobId}/files${query ? '?' + query : ''}`, options);
|
||||
});
|
||||
},
|
||||
|
||||
async getFilesCount(jobId, options = {}) {
|
||||
const key = getCacheKey(`/jobs/${jobId}/files/count`, options);
|
||||
return debounceRequest(key, () => {
|
||||
const params = new URLSearchParams();
|
||||
if (options.file_type) params.append('file_type', options.file_type);
|
||||
const query = params.toString();
|
||||
return api.get(`/jobs/${jobId}/files/count${query ? '?' + query : ''}`);
|
||||
});
|
||||
},
|
||||
|
||||
async getContextArchive(jobId, options = {}) {
|
||||
return api.get(`/jobs/${jobId}/context`, options);
|
||||
},
|
||||
|
||||
downloadFile(jobId, fileId) {
|
||||
return `${API_BASE}/jobs/${jobId}/files/${fileId}/download`;
|
||||
},
|
||||
|
||||
previewEXR(jobId, fileId) {
|
||||
return `${API_BASE}/jobs/${jobId}/files/${fileId}/preview-exr`;
|
||||
},
|
||||
|
||||
getVideoUrl(jobId) {
|
||||
return `${API_BASE}/jobs/${jobId}/video`;
|
||||
},
|
||||
|
||||
async getTaskLogs(jobId, taskId, options = {}) {
|
||||
const params = new URLSearchParams();
|
||||
if (options.stepName) params.append('step_name', options.stepName);
|
||||
if (options.logLevel) params.append('log_level', options.logLevel);
|
||||
if (options.limit) params.append('limit', options.limit.toString());
|
||||
const query = params.toString();
|
||||
return api.get(`/jobs/${jobId}/tasks/${taskId}/logs${query ? '?' + query : ''}`);
|
||||
const key = getCacheKey(`/jobs/${jobId}/tasks/${taskId}/logs`, options);
|
||||
return debounceRequest(key, async () => {
|
||||
const params = new URLSearchParams();
|
||||
if (options.stepName) params.append('step_name', options.stepName);
|
||||
if (options.logLevel) params.append('log_level', options.logLevel);
|
||||
if (options.limit) params.append('limit', options.limit.toString());
|
||||
if (options.sinceId) params.append('since_id', options.sinceId.toString());
|
||||
const query = params.toString();
|
||||
const result = await api.get(`/jobs/${jobId}/tasks/${taskId}/logs${query ? '?' + query : ''}`, options);
|
||||
// Handle both old format (array) and new format (object with logs, last_id, limit)
|
||||
if (Array.isArray(result)) {
|
||||
return { logs: result, last_id: result.length > 0 ? result[result.length - 1].id : 0, limit: options.limit || 100 };
|
||||
}
|
||||
return result;
|
||||
});
|
||||
},
|
||||
|
||||
async getTaskSteps(jobId, taskId) {
|
||||
return api.get(`/jobs/${jobId}/tasks/${taskId}/steps`);
|
||||
async getTaskSteps(jobId, taskId, options = {}) {
|
||||
return api.get(`/jobs/${jobId}/tasks/${taskId}/steps`, options);
|
||||
},
|
||||
|
||||
  // New unified client WebSocket - DEPRECATED: Use wsManager from websocket.js instead
  // This is kept for backwards compatibility but should not be used
  streamClientWebSocket() {
    console.warn('streamClientWebSocket() is deprecated - use wsManager from websocket.js instead');
    const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
    const wsHost = window.location.host;
    const url = `${wsProtocol}//${wsHost}${API_BASE}/ws`;
    return new WebSocket(url);
  },

  // Old WebSocket methods (to be removed after migration)
  streamTaskLogsWebSocket(jobId, taskId, lastId = 0) {
    // Convert HTTP to WebSocket URL
    const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
@@ -231,6 +425,20 @@ export const jobs = {
    return new WebSocket(url);
  },

  streamJobsWebSocket() {
    const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
    const wsHost = window.location.host;
    const url = `${wsProtocol}//${wsHost}${API_BASE}/jobs/ws-old`;
    return new WebSocket(url);
  },

  streamJobWebSocket(jobId) {
    const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
    const wsHost = window.location.host;
    const url = `${wsProtocol}//${wsHost}${API_BASE}/jobs/${jobId}/ws`;
    return new WebSocket(url);
  },

  async retryTask(jobId, taskId) {
    return api.post(`/jobs/${jobId}/tasks/${taskId}/retry`);
  },
@@ -239,8 +447,50 @@ export const jobs = {
    return api.get(`/jobs/${jobId}/metadata`);
  },

- async getTasks(jobId) {
-   return api.get(`/jobs/${jobId}/tasks`);
+ async getTasks(jobId, options = {}) {
+   const key = getCacheKey(`/jobs/${jobId}/tasks`, options);
+   return debounceRequest(key, () => {
+     const params = new URLSearchParams();
+     if (options.limit) params.append('limit', options.limit.toString());
+     if (options.offset) params.append('offset', options.offset.toString());
+     if (options.status) params.append('status', options.status);
+     if (options.frameStart) params.append('frame_start', options.frameStart.toString());
+     if (options.frameEnd) params.append('frame_end', options.frameEnd.toString());
+     if (options.sort) params.append('sort', options.sort);
+     const query = params.toString();
+     return api.get(`/jobs/${jobId}/tasks${query ? '?' + query : ''}`, options);
+   });
  },

  async getTasksSummary(jobId, options = {}) {
    const key = getCacheKey(`/jobs/${jobId}/tasks/summary`, options);
    return debounceRequest(key, () => {
      const params = new URLSearchParams();
      if (options.limit) params.append('limit', options.limit.toString());
      if (options.offset) params.append('offset', options.offset.toString());
      if (options.status) params.append('status', options.status);
      if (options.sort) params.append('sort', options.sort);
      const query = params.toString();
      return api.get(`/jobs/${jobId}/tasks/summary${query ? '?' + query : ''}`, options);
    });
  },

  async batchGetJobs(jobIds) {
    // Sort jobIds for consistent cache key
    const sortedIds = [...jobIds].sort((a, b) => a - b);
    const key = getCacheKey('/jobs/batch', { job_ids: sortedIds.join(',') });
    return debounceRequest(key, () => {
      return api.post('/jobs/batch', { job_ids: jobIds });
    });
  },

  async batchGetTasks(jobId, taskIds) {
    // Sort taskIds for consistent cache key
    const sortedIds = [...taskIds].sort((a, b) => a - b);
    const key = getCacheKey(`/jobs/${jobId}/tasks/batch`, { task_ids: sortedIds.join(',') });
    return debounceRequest(key, () => {
      return api.post(`/jobs/${jobId}/tasks/batch`, { task_ids: taskIds });
    });
  },
};

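A hedged usage sketch of the new task query options and batch endpoints. The job and task IDs are placeholders, and the snippet assumes it runs inside an async function with the jobs export imported:

// Illustrative only: filtered, paginated task listing plus a summary view.
const failed = await jobs.getTasks(42, { status: 'failed', limit: 50, sort: 'frame_start' });
const summary = await jobs.getTasksSummary(42, { limit: 20 });

// Two identical batch lookups issued concurrently share a single POST,
// because the sorted ID list produces the same cache key for both.
const [a, b] = await Promise.all([
  jobs.batchGetTasks(42, [1, 2, 3]),
  jobs.batchGetTasks(42, [3, 2, 1]),
]);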
@@ -249,16 +499,22 @@ export const runners = {
};

export const admin = {
- async generateToken(expiresInHours) {
-   return api.post('/admin/runners/tokens', { expires_in_hours: expiresInHours });
+ async generateAPIKey(name, description, scope) {
+   const data = { name, scope };
+   if (description) data.description = description;
+   return api.post('/admin/runners/api-keys', data);
  },

- async listTokens() {
-   return api.get('/admin/runners/tokens');
+ async listAPIKeys() {
+   return api.get('/admin/runners/api-keys');
  },

- async revokeToken(tokenId) {
-   return api.delete(`/admin/runners/tokens/${tokenId}`);
+ async revokeAPIKey(keyId) {
+   return api.patch(`/admin/runners/api-keys/${keyId}/revoke`);
  },

+ async deleteAPIKey(keyId) {
+   return api.delete(`/admin/runners/api-keys/${keyId}`);
+ },

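Note the split between revokeAPIKey (a PATCH, which presumably deactivates the key while keeping its record listable) and deleteAPIKey (a DELETE, which removes it). A hypothetical lifecycle sketch; the scope value and the id field on the returned object are assumptions, not confirmed by this diff:

// Hypothetical admin flow: create, list, revoke, then delete a runner API key.
const key = await admin.generateAPIKey('ci-runner', 'Key for the CI runner', 'runner'); // scope value assumed
const keys = await admin.listAPIKeys();
await admin.revokeAPIKey(key.id); // PATCH: key record remains but is inert (id field assumed)
await admin.deleteAPIKey(key.id); // DELETE: key record removed entirely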
  async listRunners() {

271
web/src/utils/websocket.js
Normal file
@@ -0,0 +1,271 @@
// Shared WebSocket connection manager
// All components should use this instead of creating their own connections

class WebSocketManager {
  constructor() {
    this.ws = null;
    this.listeners = new Map(); // Map of listener IDs to callback functions
    this.reconnectTimeout = null;
    this.reconnectDelay = 2000;
    this.isConnecting = false;
    this.listenerIdCounter = 0;
    this.verboseLogging = false; // Set to true to enable verbose WebSocket logging

    // Track server-side channel subscriptions for re-subscription on reconnect
    this.serverSubscriptions = new Set(); // Channels we want to be subscribed to
    this.confirmedSubscriptions = new Set(); // Channels confirmed by server
    this.pendingSubscriptions = new Set(); // Channels waiting for confirmation
  }

  connect() {
    // If already connected or connecting, don't create a new connection
    if (this.ws && (this.ws.readyState === WebSocket.CONNECTING || this.ws.readyState === WebSocket.OPEN)) {
      return;
    }

    if (this.isConnecting) {
      return;
    }

    this.isConnecting = true;

    try {
      const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
      const wsHost = window.location.host;
      const API_BASE = '/api';
      const url = `${wsProtocol}//${wsHost}${API_BASE}/jobs/ws`;

      this.ws = new WebSocket(url);

      this.ws.onopen = () => {
        if (this.verboseLogging) {
          console.log('Shared WebSocket connected');
        }
        this.isConnecting = false;

        // Re-subscribe to all channels that were previously subscribed
        this.resubscribeToChannels();

        this.notifyListeners('open', {});
      };

      this.ws.onmessage = (event) => {
        try {
          const data = JSON.parse(event.data);
          if (this.verboseLogging) {
            console.log('WebSocketManager: Message received:', data.type, data.channel || 'no channel', data);
          }
          this.notifyListeners('message', data);
        } catch (error) {
          console.error('WebSocketManager: Failed to parse message:', error, 'Raw data:', event.data);
        }
      };

      this.ws.onerror = (error) => {
        console.error('Shared WebSocket error:', error);
        this.isConnecting = false;
        this.notifyListeners('error', error);
      };

      this.ws.onclose = (event) => {
        if (this.verboseLogging) {
          console.log('Shared WebSocket closed:', {
            code: event.code,
            reason: event.reason,
            wasClean: event.wasClean
          });
        }
        this.ws = null;
        this.isConnecting = false;

        // Clear confirmed/pending but keep serverSubscriptions for re-subscription
        this.confirmedSubscriptions.clear();
        this.pendingSubscriptions.clear();

        this.notifyListeners('close', event);

        // Always retry connection if we have listeners
        if (this.listeners.size > 0) {
          if (this.reconnectTimeout) {
            clearTimeout(this.reconnectTimeout);
          }
          this.reconnectTimeout = setTimeout(() => {
            if (!this.ws || this.ws.readyState === WebSocket.CLOSED) {
              this.connect();
            }
          }, this.reconnectDelay);
        }
      };
    } catch (error) {
      console.error('Failed to create WebSocket:', error);
      this.isConnecting = false;
      // Retry after delay
      this.reconnectTimeout = setTimeout(() => {
        this.connect();
      }, this.reconnectDelay);
    }
  }

  subscribe(listenerId, callbacks) {
    // Generate ID if not provided
    if (!listenerId) {
      listenerId = `listener_${this.listenerIdCounter++}`;
    }

    if (this.verboseLogging) {
      console.log('WebSocketManager: Subscribing listener:', listenerId, 'WebSocket state:', this.ws ? this.ws.readyState : 'no connection');
    }
    this.listeners.set(listenerId, callbacks);

    // Connect if not already connected
    if (!this.ws || this.ws.readyState === WebSocket.CLOSED) {
      if (this.verboseLogging) {
        console.log('WebSocketManager: WebSocket not connected, connecting...');
      }
      this.connect();
    }

    // If already open, notify immediately
    if (this.ws && this.ws.readyState === WebSocket.OPEN && callbacks.open) {
      if (this.verboseLogging) {
        console.log('WebSocketManager: WebSocket already open, calling open callback for listener:', listenerId);
      }
      // Use setTimeout to ensure this happens after the listener is registered
      setTimeout(() => {
        if (callbacks.open) {
          callbacks.open();
        }
      }, 0);
    }

    return listenerId;
  }

  unsubscribe(listenerId) {
    this.listeners.delete(listenerId);

    // If no more listeners, we could close the connection, but let's keep it open
    // in case other components need it
  }

  send(data) {
    if (this.ws && this.ws.readyState === WebSocket.OPEN) {
      if (this.verboseLogging) {
        console.log('WebSocketManager: Sending message:', data);
      }
      this.ws.send(JSON.stringify(data));
    } else {
      console.warn('WebSocketManager: Cannot send message - connection not open. State:', this.ws ? this.ws.readyState : 'no connection', 'Message:', data);
    }
  }

  notifyListeners(eventType, data) {
    this.listeners.forEach((callbacks) => {
      if (callbacks[eventType]) {
        try {
          callbacks[eventType](data);
        } catch (error) {
          console.error('Error in WebSocket listener:', error);
        }
      }
    });
  }

  getReadyState() {
    return this.ws ? this.ws.readyState : WebSocket.CLOSED;
  }

  // Subscribe to a server-side channel (will be re-subscribed on reconnect)
  subscribeToChannel(channel) {
    if (this.serverSubscriptions.has(channel)) {
      // Already subscribed or pending
      return;
    }

    this.serverSubscriptions.add(channel);

    if (this.ws && this.ws.readyState === WebSocket.OPEN) {
      if (!this.confirmedSubscriptions.has(channel) && !this.pendingSubscriptions.has(channel)) {
        this.pendingSubscriptions.add(channel);
        this.send({ type: 'subscribe', channel });
        if (this.verboseLogging) {
          console.log('WebSocketManager: Subscribing to channel:', channel);
        }
      }
    }
  }

  // Unsubscribe from a server-side channel (won't be re-subscribed on reconnect)
  unsubscribeFromChannel(channel) {
    this.serverSubscriptions.delete(channel);
    this.confirmedSubscriptions.delete(channel);
    this.pendingSubscriptions.delete(channel);

    if (this.ws && this.ws.readyState === WebSocket.OPEN) {
      this.send({ type: 'unsubscribe', channel });
      if (this.verboseLogging) {
        console.log('WebSocketManager: Unsubscribing from channel:', channel);
      }
    }
  }

  // Mark a channel subscription as confirmed (call this when server confirms)
  confirmSubscription(channel) {
    this.pendingSubscriptions.delete(channel);
    this.confirmedSubscriptions.add(channel);
    if (this.verboseLogging) {
      console.log('WebSocketManager: Subscription confirmed for channel:', channel);
    }
  }

  // Mark a channel subscription as failed (call this when server rejects)
  failSubscription(channel) {
    this.pendingSubscriptions.delete(channel);
    this.serverSubscriptions.delete(channel);
    if (this.verboseLogging) {
      console.log('WebSocketManager: Subscription failed for channel:', channel);
    }
  }

  // Check if subscribed to a channel
  isSubscribedToChannel(channel) {
    return this.confirmedSubscriptions.has(channel);
  }

  // Re-subscribe to all channels after reconnect
  resubscribeToChannels() {
    if (this.serverSubscriptions.size === 0) {
      return;
    }

    if (this.verboseLogging) {
      console.log('WebSocketManager: Re-subscribing to channels:', Array.from(this.serverSubscriptions));
    }

    for (const channel of this.serverSubscriptions) {
      if (!this.pendingSubscriptions.has(channel)) {
        this.pendingSubscriptions.add(channel);
        this.send({ type: 'subscribe', channel });
      }
    }
  }

  disconnect() {
    if (this.reconnectTimeout) {
      clearTimeout(this.reconnectTimeout);
      this.reconnectTimeout = null;
    }
    if (this.ws) {
      this.ws.close();
      this.ws = null;
    }
    this.listeners.clear();
    this.serverSubscriptions.clear();
    this.confirmedSubscriptions.clear();
    this.pendingSubscriptions.clear();
  }
}

// Export singleton instance
export const wsManager = new WebSocketManager();
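A sketch of how a component might use the singleton. The import path, channel name, and the confirmation/rejection message types ('subscribed' / 'subscribe_error') are assumptions: the server protocol is only implied here by confirmSubscription and failSubscription, not spelled out in this diff:

// Hypothetical component usage; all names below other than the wsManager API are placeholders.
import { wsManager } from './utils/websocket.js'; // path assumed

const listenerId = wsManager.subscribe(null, {
  open: () => wsManager.subscribeToChannel('job:42'), // channel name is a placeholder
  message: (data) => {
    if (data.type === 'subscribed') { // assumed confirmation message type
      wsManager.confirmSubscription(data.channel);
    } else if (data.type === 'subscribe_error') { // assumed rejection message type
      wsManager.failSubscription(data.channel);
    } else {
      // handle job/task update messages here
    }
  },
  close: () => { /* show a reconnecting state; wsManager retries on its own */ },
});

// On component teardown: stop listening but leave the shared socket open for others.
wsManager.unsubscribeFromChannel('job:42');
wsManager.unsubscribe(listenerId);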