Refactor runner and installation scripts for improved functionality
- Removed the `--disable-hiprt` flag from the runner command, simplifying the rendering options for users.
- Updated the `jiggablend-runner` script and README to reflect the removal of the HIPRT control flag, enhancing clarity in usage instructions.
- Enhanced the installation script to provide clearer examples for running the jiggablend manager and runner, improving user experience during setup.
- Implemented a more robust GPU backend detection mechanism, allowing for better compatibility with various hardware configurations.
This commit is contained in:
14
Makefile
14
Makefile
@@ -27,7 +27,19 @@ cleanup: cleanup-manager cleanup-runner
|
||||
run: cleanup build init-test
|
||||
@echo "Starting manager and runner in parallel..."
|
||||
@echo "Press Ctrl+C to stop both..."
|
||||
@trap 'kill $$MANAGER_PID $$RUNNER_PID 2>/dev/null; exit' INT TERM; \
|
||||
@MANAGER_PID=""; RUNNER_PID=""; INTERRUPTED=0; \
|
||||
cleanup() { \
|
||||
exit_code=$$?; \
|
||||
trap - INT TERM EXIT; \
|
||||
if [ -n "$$RUNNER_PID" ]; then kill -TERM "$$RUNNER_PID" 2>/dev/null || true; fi; \
|
||||
if [ -n "$$MANAGER_PID" ]; then kill -TERM "$$MANAGER_PID" 2>/dev/null || true; fi; \
|
||||
if [ -n "$$MANAGER_PID$$RUNNER_PID" ]; then wait $$MANAGER_PID $$RUNNER_PID 2>/dev/null || true; fi; \
|
||||
if [ "$$INTERRUPTED" -eq 1 ]; then exit 0; fi; \
|
||||
exit $$exit_code; \
|
||||
}; \
|
||||
on_interrupt() { INTERRUPTED=1; cleanup; }; \
|
||||
trap on_interrupt INT TERM; \
|
||||
trap cleanup EXIT; \
|
||||
bin/jiggablend manager -l manager.log & \
|
||||
MANAGER_PID=$$!; \
|
||||
sleep 2; \
|
||||
|
||||
@@ -154,8 +154,8 @@ bin/jiggablend runner --api-key <your-api-key>
|
||||
# With custom options
|
||||
bin/jiggablend runner --manager http://localhost:8080 --name my-runner --api-key <key> --log-file runner.log
|
||||
|
||||
# Hardware compatibility flags (force CPU + disable HIPRT)
|
||||
bin/jiggablend runner --api-key <key> --force-cpu-rendering --disable-hiprt
|
||||
# Hardware compatibility flag (force CPU)
|
||||
bin/jiggablend runner --api-key <key> --force-cpu-rendering
|
||||
|
||||
# Using environment variables
|
||||
JIGGABLEND_MANAGER=http://localhost:8080 JIGGABLEND_API_KEY=<key> bin/jiggablend runner
|
||||
|
||||
@@ -38,7 +38,6 @@ func init() {
|
||||
runnerCmd.Flags().BoolP("verbose", "v", false, "Enable verbose logging (same as --log-level=debug)")
|
||||
runnerCmd.Flags().Duration("poll-interval", 5*time.Second, "Job polling interval")
|
||||
runnerCmd.Flags().Bool("force-cpu-rendering", false, "Force CPU rendering for all jobs (disables GPU rendering)")
|
||||
runnerCmd.Flags().Bool("disable-hiprt", false, "Disable HIPRT acceleration in Blender Cycles")
|
||||
|
||||
// Bind flags to viper with JIGGABLEND_ prefix
|
||||
runnerViper.SetEnvPrefix("JIGGABLEND")
|
||||
@@ -54,7 +53,6 @@ func init() {
|
||||
runnerViper.BindPFlag("verbose", runnerCmd.Flags().Lookup("verbose"))
|
||||
runnerViper.BindPFlag("poll_interval", runnerCmd.Flags().Lookup("poll-interval"))
|
||||
runnerViper.BindPFlag("force_cpu_rendering", runnerCmd.Flags().Lookup("force-cpu-rendering"))
|
||||
runnerViper.BindPFlag("disable_hiprt", runnerCmd.Flags().Lookup("disable-hiprt"))
|
||||
}
|
||||
|
||||
func runRunner(cmd *cobra.Command, args []string) {
|
||||
@@ -68,7 +66,6 @@ func runRunner(cmd *cobra.Command, args []string) {
|
||||
verbose := runnerViper.GetBool("verbose")
|
||||
pollInterval := runnerViper.GetDuration("poll_interval")
|
||||
forceCPURendering := runnerViper.GetBool("force_cpu_rendering")
|
||||
disableHIPRT := runnerViper.GetBool("disable_hiprt")
|
||||
|
||||
var r *runner.Runner
|
||||
|
||||
@@ -124,7 +121,7 @@ func runRunner(cmd *cobra.Command, args []string) {
|
||||
}
|
||||
|
||||
// Create runner
|
||||
r = runner.New(managerURL, name, hostname, forceCPURendering, disableHIPRT)
|
||||
r = runner.New(managerURL, name, hostname, forceCPURendering)
|
||||
|
||||
// Check for required tools early to fail fast
|
||||
if err := r.CheckRequiredTools(); err != nil {
|
||||
@@ -167,8 +164,8 @@ func runRunner(cmd *cobra.Command, args []string) {
|
||||
runnerID, err = r.Register(apiKey)
|
||||
if err == nil {
|
||||
logger.Infof("Registered runner with ID: %d", runnerID)
|
||||
// Download latest Blender and detect HIP vs NVIDIA so we only force CPU for Blender < 4.x when using HIP
|
||||
logger.Info("Detecting GPU backends (HIP/NVIDIA) for Blender < 4.x policy...")
|
||||
// Detect GPU vendors/backends from host hardware so we only force CPU for Blender < 4.x when using AMD.
|
||||
logger.Info("Detecting GPU backends (AMD/NVIDIA/Intel) from host hardware for Blender < 4.x policy...")
|
||||
r.DetectAndStoreGPUBackends()
|
||||
break
|
||||
}
|
||||
|
||||
@@ -109,5 +109,5 @@ echo "Binary: jiggablend"
|
||||
echo "Wrappers: jiggablend-manager, jiggablend-runner"
|
||||
echo "Run 'jiggablend-manager' to start the manager with test config."
|
||||
echo "Run 'jiggablend-runner [url] [runner flags...]' to start the runner."
|
||||
echo "Example: jiggablend-runner http://your-manager:8080 --force-cpu-rendering --disable-hiprt"
|
||||
echo "Example: jiggablend-runner http://your-manager:8080 --force-cpu-rendering"
|
||||
echo "Note: Depending on whether you're running the manager or runner, additional dependencies like Blender, ImageMagick, or FFmpeg may be required. See the project README for details."
|
||||
@@ -668,24 +668,42 @@ func (a *Auth) IsProductionModeFromConfig() bool {
|
||||
return a.cfg.IsProductionMode()
|
||||
}
|
||||
|
||||
func (a *Auth) writeUnauthorized(w http.ResponseWriter, r *http.Request) {
|
||||
// Keep API behavior unchanged for programmatic clients.
|
||||
if strings.HasPrefix(r.URL.Path, "/api/") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
// For HTMX UI fragment requests, trigger a full-page redirect to login.
|
||||
if strings.EqualFold(r.Header.Get("HX-Request"), "true") {
|
||||
w.Header().Set("HX-Redirect", "/login")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
// For normal browser page requests, redirect to login page.
|
||||
http.Redirect(w, r, "/login", http.StatusFound)
|
||||
}
|
||||
|
||||
// Middleware creates an authentication middleware
|
||||
func (a *Auth) Middleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("session_id")
|
||||
if err != nil {
|
||||
log.Printf("Authentication failed: missing session cookie for %s %s", r.Method, r.URL.Path)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
|
||||
a.writeUnauthorized(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
session, ok := a.GetSession(cookie.Value)
|
||||
if !ok {
|
||||
log.Printf("Authentication failed: invalid session cookie for %s %s", r.Method, r.URL.Path)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
|
||||
a.writeUnauthorized(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -717,18 +735,14 @@ func (a *Auth) AdminMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
cookie, err := r.Cookie("session_id")
|
||||
if err != nil {
|
||||
log.Printf("Admin authentication failed: missing session cookie for %s %s", r.Method, r.URL.Path)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
|
||||
a.writeUnauthorized(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
session, ok := a.GetSession(cookie.Value)
|
||||
if !ok {
|
||||
log.Printf("Admin authentication failed: invalid session cookie for %s %s", r.Method, r.URL.Path)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
|
||||
a.writeUnauthorized(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -23,6 +23,14 @@ const (
|
||||
KeyProductionMode = "production_mode"
|
||||
KeyAllowedOrigins = "allowed_origins"
|
||||
KeyFramesPerRenderTask = "frames_per_render_task"
|
||||
|
||||
// Operational limits (seconds / bytes / counts)
|
||||
KeyRenderTimeoutSecs = "render_timeout_seconds"
|
||||
KeyEncodeTimeoutSecs = "encode_timeout_seconds"
|
||||
KeyMaxUploadBytes = "max_upload_bytes"
|
||||
KeySessionCookieMaxAge = "session_cookie_max_age"
|
||||
KeyAPIRateLimit = "api_rate_limit"
|
||||
KeyAuthRateLimit = "auth_rate_limit"
|
||||
)
|
||||
|
||||
// Config manages application configuration stored in the database
|
||||
@@ -311,3 +319,34 @@ func (c *Config) GetFramesPerRenderTask() int {
|
||||
return n
|
||||
}
|
||||
|
||||
// RenderTimeoutSeconds returns the per-frame render timeout in seconds (default 3600 = 1 hour).
|
||||
func (c *Config) RenderTimeoutSeconds() int {
|
||||
return c.GetIntWithDefault(KeyRenderTimeoutSecs, 3600)
|
||||
}
|
||||
|
||||
// EncodeTimeoutSeconds returns the video encode timeout in seconds (default 86400 = 24 hours).
|
||||
func (c *Config) EncodeTimeoutSeconds() int {
|
||||
return c.GetIntWithDefault(KeyEncodeTimeoutSecs, 86400)
|
||||
}
|
||||
|
||||
// MaxUploadBytes returns the maximum upload size in bytes (default 50 GB).
|
||||
func (c *Config) MaxUploadBytes() int64 {
|
||||
v := c.GetIntWithDefault(KeyMaxUploadBytes, 50<<30)
|
||||
return int64(v)
|
||||
}
|
||||
|
||||
// SessionCookieMaxAgeSec returns the session cookie max-age in seconds (default 86400 = 24 hours).
|
||||
func (c *Config) SessionCookieMaxAgeSec() int {
|
||||
return c.GetIntWithDefault(KeySessionCookieMaxAge, 86400)
|
||||
}
|
||||
|
||||
// APIRateLimitPerMinute returns the API rate limit (requests per minute per IP, default 100).
|
||||
func (c *Config) APIRateLimitPerMinute() int {
|
||||
return c.GetIntWithDefault(KeyAPIRateLimit, 100)
|
||||
}
|
||||
|
||||
// AuthRateLimitPerMinute returns the auth rate limit (requests per minute per IP, default 10).
|
||||
func (c *Config) AuthRateLimitPerMinute() int {
|
||||
return c.GetIntWithDefault(KeyAuthRateLimit, 10)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
-- Remove last_used_at from runner_api_keys.
-- SQLite releases before 3.35.0 have no ALTER TABLE ... DROP COLUMN, so the
-- table is rebuilt without the column: create a replacement table, copy the
-- rows across, drop the original, rename the replacement, recreate the indexes.
CREATE TABLE runner_api_keys_backup (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    key_prefix TEXT NOT NULL,
    key_hash TEXT NOT NULL,
    name TEXT NOT NULL,
    description TEXT,
    scope TEXT NOT NULL DEFAULT 'user',
    is_active INTEGER NOT NULL DEFAULT 1,
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    created_by INTEGER,
    FOREIGN KEY (created_by) REFERENCES users(id),
    UNIQUE(key_prefix)
);

-- Name the target columns explicitly so the copy does not depend on the
-- declaration order of the replacement table.
INSERT INTO runner_api_keys_backup (
    id, key_prefix, key_hash, name, description, scope, is_active, created_at, created_by
)
SELECT
    id, key_prefix, key_hash, name, description, scope, is_active, created_at, created_by
FROM runner_api_keys;

DROP TABLE runner_api_keys;

ALTER TABLE runner_api_keys_backup RENAME TO runner_api_keys;

-- Indexes on the dropped table went with it; create them on the rebuilt table.
CREATE INDEX idx_runner_api_keys_prefix ON runner_api_keys(key_prefix);
CREATE INDEX idx_runner_api_keys_active ON runner_api_keys(is_active);
CREATE INDEX idx_runner_api_keys_created_by ON runner_api_keys(created_by);
|
||||
@@ -0,0 +1 @@
|
||||
-- Add a nullable last_used_at timestamp column to runner_api_keys.
ALTER TABLE runner_api_keys ADD COLUMN last_used_at TIMESTAMP;
|
||||
@@ -3,7 +3,6 @@ package api
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/bzip2"
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@@ -16,6 +15,8 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"jiggablend/pkg/blendfile"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -439,144 +440,16 @@ func (s *Manager) cleanupExtractedBlenderFolders(blenderDir string, version *Ble
|
||||
}
|
||||
}
|
||||
|
||||
// ParseBlenderVersionFromFile parses the Blender version that a .blend file was saved with
|
||||
// This reads the file header to determine the version
|
||||
// ParseBlenderVersionFromFile parses the Blender version that a .blend file was saved with.
|
||||
// Delegates to the shared pkg/blendfile implementation.
|
||||
func ParseBlenderVersionFromFile(blendPath string) (major, minor int, err error) {
|
||||
file, err := os.Open(blendPath)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to open blend file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
return ParseBlenderVersionFromReader(file)
|
||||
return blendfile.ParseVersionFromFile(blendPath)
|
||||
}
|
||||
|
||||
// ParseBlenderVersionFromReader parses the Blender version from a reader
|
||||
// Useful for reading from uploaded files without saving to disk first
|
||||
// ParseBlenderVersionFromReader parses the Blender version from a reader.
|
||||
// Delegates to the shared pkg/blendfile implementation.
|
||||
func ParseBlenderVersionFromReader(r io.ReadSeeker) (major, minor int, err error) {
|
||||
// Read the first 12 bytes of the blend file header
|
||||
// Format: BLENDER-v<major><minor><patch> or BLENDER_v<major><minor><patch>
|
||||
// The header is: "BLENDER" (7 bytes) + pointer size (1 byte: '-' for 64-bit, '_' for 32-bit)
|
||||
// + endianness (1 byte: 'v' for little-endian, 'V' for big-endian)
|
||||
// + version (3 bytes: e.g., "402" for 4.02)
|
||||
header := make([]byte, 12)
|
||||
n, err := r.Read(header)
|
||||
if err != nil || n < 12 {
|
||||
return 0, 0, fmt.Errorf("failed to read blend file header: %w", err)
|
||||
}
|
||||
|
||||
// Check for BLENDER magic
|
||||
if string(header[:7]) != "BLENDER" {
|
||||
// Might be compressed - try to decompress
|
||||
r.Seek(0, 0)
|
||||
return parseCompressedBlendVersion(r)
|
||||
}
|
||||
|
||||
// Parse version from bytes 9-11 (3 digits)
|
||||
versionStr := string(header[9:12])
|
||||
var vMajor, vMinor int
|
||||
|
||||
// Version format changed in Blender 3.0
|
||||
// Pre-3.0: "279" = 2.79, "280" = 2.80
|
||||
// 3.0+: "300" = 3.0, "402" = 4.02, "410" = 4.10
|
||||
if len(versionStr) == 3 {
|
||||
// First digit is major version
|
||||
fmt.Sscanf(string(versionStr[0]), "%d", &vMajor)
|
||||
// Next two digits are minor version
|
||||
fmt.Sscanf(versionStr[1:3], "%d", &vMinor)
|
||||
}
|
||||
|
||||
return vMajor, vMinor, nil
|
||||
}
|
||||
|
||||
// parseCompressedBlendVersion handles gzip and zstd compressed blend files
|
||||
func parseCompressedBlendVersion(r io.ReadSeeker) (major, minor int, err error) {
|
||||
// Check for compression magic bytes
|
||||
magic := make([]byte, 4)
|
||||
if _, err := r.Read(magic); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
r.Seek(0, 0)
|
||||
|
||||
if magic[0] == 0x1f && magic[1] == 0x8b {
|
||||
// gzip compressed
|
||||
gzReader, err := gzip.NewReader(r)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to create gzip reader: %w", err)
|
||||
}
|
||||
defer gzReader.Close()
|
||||
|
||||
header := make([]byte, 12)
|
||||
n, err := gzReader.Read(header)
|
||||
if err != nil || n < 12 {
|
||||
return 0, 0, fmt.Errorf("failed to read compressed blend header: %w", err)
|
||||
}
|
||||
|
||||
if string(header[:7]) != "BLENDER" {
|
||||
return 0, 0, fmt.Errorf("invalid blend file format")
|
||||
}
|
||||
|
||||
versionStr := string(header[9:12])
|
||||
var vMajor, vMinor int
|
||||
if len(versionStr) == 3 {
|
||||
fmt.Sscanf(string(versionStr[0]), "%d", &vMajor)
|
||||
fmt.Sscanf(versionStr[1:3], "%d", &vMinor)
|
||||
}
|
||||
|
||||
return vMajor, vMinor, nil
|
||||
}
|
||||
|
||||
// Check for zstd magic (Blender 3.0+): 0x28 0xB5 0x2F 0xFD
|
||||
if magic[0] == 0x28 && magic[1] == 0xb5 && magic[2] == 0x2f && magic[3] == 0xfd {
|
||||
return parseZstdBlendVersion(r)
|
||||
}
|
||||
|
||||
return 0, 0, fmt.Errorf("unknown blend file format")
|
||||
}
|
||||
|
||||
// parseZstdBlendVersion handles zstd-compressed blend files (Blender 3.0+)
|
||||
// Uses zstd command line tool since Go doesn't have native zstd support
|
||||
func parseZstdBlendVersion(r io.ReadSeeker) (major, minor int, err error) {
|
||||
r.Seek(0, 0)
|
||||
|
||||
// We need to decompress just enough to read the header
|
||||
// Use zstd command to decompress from stdin
|
||||
cmd := exec.Command("zstd", "-d", "-c")
|
||||
cmd.Stdin = r
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to create zstd stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to start zstd decompression: %w", err)
|
||||
}
|
||||
|
||||
// Read just the header (12 bytes)
|
||||
header := make([]byte, 12)
|
||||
n, readErr := io.ReadFull(stdout, header)
|
||||
|
||||
// Kill the process early - we only need the header
|
||||
cmd.Process.Kill()
|
||||
cmd.Wait()
|
||||
|
||||
if readErr != nil || n < 12 {
|
||||
return 0, 0, fmt.Errorf("failed to read zstd compressed blend header: %v", readErr)
|
||||
}
|
||||
|
||||
if string(header[:7]) != "BLENDER" {
|
||||
return 0, 0, fmt.Errorf("invalid blend file format in zstd archive")
|
||||
}
|
||||
|
||||
versionStr := string(header[9:12])
|
||||
var vMajor, vMinor int
|
||||
if len(versionStr) == 3 {
|
||||
fmt.Sscanf(string(versionStr[0]), "%d", &vMajor)
|
||||
fmt.Sscanf(versionStr[1:3], "%d", &vMinor)
|
||||
}
|
||||
|
||||
return vMajor, vMinor, nil
|
||||
return blendfile.ParseVersionFromReader(r)
|
||||
}
|
||||
|
||||
// handleGetBlenderVersions returns available Blender versions
|
||||
@@ -713,7 +586,7 @@ func (s *Manager) handleDownloadBlender(w http.ResponseWriter, r *http.Request)
|
||||
tarFilename = strings.TrimSuffix(tarFilename, ".bz2")
|
||||
|
||||
w.Header().Set("Content-Type", "application/x-tar")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", tarFilename))
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", tarFilename))
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
|
||||
w.Header().Set("X-Blender-Version", blenderVersion.Full)
|
||||
|
||||
|
||||
@@ -345,9 +345,9 @@ func (s *Manager) handleCreateJob(w http.ResponseWriter, r *http.Request) {
|
||||
// Only create render tasks for render jobs
|
||||
if req.JobType == types.JobTypeRender {
|
||||
// Determine task timeout based on output format
|
||||
taskTimeout := RenderTimeout // 1 hour for render jobs
|
||||
taskTimeout := s.renderTimeout
|
||||
if *req.OutputFormat == "EXR_264_MP4" || *req.OutputFormat == "EXR_AV1_MP4" || *req.OutputFormat == "EXR_VP9_WEBM" {
|
||||
taskTimeout = VideoEncodeTimeout // 24 hours for encoding
|
||||
taskTimeout = s.videoEncodeTimeout
|
||||
}
|
||||
|
||||
// Create tasks for the job (batch INSERT in a single transaction)
|
||||
@@ -390,7 +390,7 @@ func (s *Manager) handleCreateJob(w http.ResponseWriter, r *http.Request) {
|
||||
// Create encode task immediately if output format requires it
|
||||
// The task will have a condition that prevents it from being assigned until all render tasks are completed
|
||||
if *req.OutputFormat == "EXR_264_MP4" || *req.OutputFormat == "EXR_AV1_MP4" || *req.OutputFormat == "EXR_VP9_WEBM" {
|
||||
encodeTaskTimeout := VideoEncodeTimeout // 24 hours for encoding
|
||||
encodeTaskTimeout := s.videoEncodeTimeout
|
||||
conditionJSON := `{"type": "all_render_tasks_completed"}`
|
||||
var encodeTaskID int64
|
||||
err = s.db.With(func(conn *sql.DB) error {
|
||||
@@ -2592,7 +2592,7 @@ func (s *Manager) handleDownloadJobFile(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
// Set headers
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("%s; filename=%s", disposition, fileName))
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("%s; filename=%q", disposition, fileName))
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
|
||||
// Stream file
|
||||
@@ -2710,7 +2710,7 @@ func (s *Manager) handleDownloadEXRZip(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
fileName := fmt.Sprintf("%s-exr.zip", safeJobName)
|
||||
w.Header().Set("Content-Type", "application/zip")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", fileName))
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", fileName))
|
||||
|
||||
zipWriter := zip.NewWriter(w)
|
||||
defer zipWriter.Close()
|
||||
@@ -2881,7 +2881,7 @@ func (s *Manager) handlePreviewEXR(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Set headers
|
||||
pngFileName := strings.TrimSuffix(fileName, filepath.Ext(fileName)) + ".png"
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("inline; filename=%s", pngFileName))
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("inline; filename=%q", pngFileName))
|
||||
w.Header().Set("Content-Type", "image/png")
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(pngData)))
|
||||
|
||||
|
||||
@@ -30,27 +30,22 @@ import (
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// Configuration constants
|
||||
// Configuration constants (non-configurable infrastructure values)
|
||||
const (
|
||||
// WebSocket timeouts
|
||||
WSReadDeadline = 90 * time.Second
|
||||
WSPingInterval = 30 * time.Second
|
||||
WSWriteDeadline = 10 * time.Second
|
||||
|
||||
// Task timeouts
|
||||
RenderTimeout = 60 * 60 // 1 hour for frame rendering
|
||||
VideoEncodeTimeout = 60 * 60 * 24 // 24 hours for encoding
|
||||
|
||||
// Limits
|
||||
MaxUploadSize = 50 << 30 // 50 GB
|
||||
// Infrastructure timers
|
||||
RunnerHeartbeatTimeout = 90 * time.Second
|
||||
TaskDistributionInterval = 10 * time.Second
|
||||
ProgressUpdateThrottle = 2 * time.Second
|
||||
|
||||
// Cookie settings
|
||||
SessionCookieMaxAge = 86400 // 24 hours
|
||||
)
|
||||
|
||||
// Operational limits are loaded from database config at Manager initialization.
|
||||
// Defaults are defined in internal/config/config.go convenience methods.
|
||||
|
||||
// Manager represents the manager server
|
||||
type Manager struct {
|
||||
db *database.DB
|
||||
@@ -109,6 +104,12 @@ type Manager struct {
|
||||
|
||||
// Server start time for health checks
|
||||
startTime time.Time
|
||||
|
||||
// Configurable operational values loaded from config
|
||||
renderTimeout int // seconds
|
||||
videoEncodeTimeout int // seconds
|
||||
maxUploadSize int64 // bytes
|
||||
sessionCookieMaxAge int // seconds
|
||||
}
|
||||
|
||||
// ClientConnection represents a client WebSocket connection with subscriptions
|
||||
@@ -166,6 +167,11 @@ func NewManager(db *database.DB, cfg *config.Config, auth *authpkg.Auth, storage
|
||||
router: chi.NewRouter(),
|
||||
ui: ui,
|
||||
startTime: time.Now(),
|
||||
|
||||
renderTimeout: cfg.RenderTimeoutSeconds(),
|
||||
videoEncodeTimeout: cfg.EncodeTimeoutSeconds(),
|
||||
maxUploadSize: cfg.MaxUploadBytes(),
|
||||
sessionCookieMaxAge: cfg.SessionCookieMaxAgeSec(),
|
||||
wsUpgrader: websocket.Upgrader{
|
||||
CheckOrigin: checkWebSocketOrigin,
|
||||
ReadBufferSize: 1024,
|
||||
@@ -189,6 +195,10 @@ func NewManager(db *database.DB, cfg *config.Config, auth *authpkg.Auth, storage
|
||||
jobStatusUpdateMu: make(map[int64]*sync.Mutex),
|
||||
}
|
||||
|
||||
// Initialize rate limiters from config
|
||||
apiRateLimiter = NewRateLimiter(cfg.APIRateLimitPerMinute(), time.Minute)
|
||||
authRateLimiter = NewRateLimiter(cfg.AuthRateLimitPerMinute(), time.Minute)
|
||||
|
||||
// Check for required external tools
|
||||
if err := s.checkRequiredTools(); err != nil {
|
||||
return nil, err
|
||||
@@ -267,6 +277,7 @@ type RateLimiter struct {
|
||||
mu sync.RWMutex
|
||||
limit int // max requests
|
||||
window time.Duration // time window
|
||||
stopChan chan struct{}
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter
|
||||
@@ -275,12 +286,17 @@ func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
|
||||
requests: make(map[string][]time.Time),
|
||||
limit: limit,
|
||||
window: window,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
// Start cleanup goroutine
|
||||
go rl.cleanup()
|
||||
return rl
|
||||
}
|
||||
|
||||
// Stop shuts down the cleanup goroutine.
|
||||
func (rl *RateLimiter) Stop() {
|
||||
close(rl.stopChan)
|
||||
}
|
||||
|
||||
// Allow checks if a request from the given IP is allowed
|
||||
func (rl *RateLimiter) Allow(ip string) bool {
|
||||
rl.mu.Lock()
|
||||
@@ -313,7 +329,11 @@ func (rl *RateLimiter) Allow(ip string) bool {
|
||||
// cleanup periodically removes old entries
|
||||
func (rl *RateLimiter) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
for range ticker.C {
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
rl.mu.Lock()
|
||||
cutoff := time.Now().Add(-rl.window)
|
||||
for ip, reqs := range rl.requests {
|
||||
@@ -330,15 +350,16 @@ func (rl *RateLimiter) cleanup() {
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
case <-rl.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Global rate limiters for different endpoint types
|
||||
// Rate limiters — initialized per Manager instance in NewManager.
|
||||
var (
|
||||
// General API rate limiter: 100 requests per minute per IP
|
||||
apiRateLimiter = NewRateLimiter(100, time.Minute)
|
||||
// Auth rate limiter: 10 requests per minute per IP (stricter for login attempts)
|
||||
authRateLimiter = NewRateLimiter(10, time.Minute)
|
||||
apiRateLimiter *RateLimiter
|
||||
authRateLimiter *RateLimiter
|
||||
)
|
||||
|
||||
// rateLimitMiddleware applies rate limiting based on client IP
|
||||
@@ -610,17 +631,16 @@ func (s *Manager) respondError(w http.ResponseWriter, status int, message string
|
||||
}
|
||||
|
||||
// createSessionCookie creates a secure session cookie with appropriate flags for the environment
|
||||
func createSessionCookie(sessionID string) *http.Cookie {
|
||||
func (s *Manager) createSessionCookie(sessionID string) *http.Cookie {
|
||||
cookie := &http.Cookie{
|
||||
Name: "session_id",
|
||||
Value: sessionID,
|
||||
Path: "/",
|
||||
MaxAge: SessionCookieMaxAge,
|
||||
MaxAge: s.sessionCookieMaxAge,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
|
||||
// In production mode, set Secure flag to require HTTPS
|
||||
if authpkg.IsProductionMode() {
|
||||
cookie.Secure = true
|
||||
}
|
||||
@@ -712,7 +732,7 @@ func (s *Manager) handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
sessionID := s.auth.CreateSession(session)
|
||||
http.SetCookie(w, createSessionCookie(sessionID))
|
||||
http.SetCookie(w, s.createSessionCookie(sessionID))
|
||||
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
}
|
||||
@@ -745,7 +765,7 @@ func (s *Manager) handleDiscordCallback(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
sessionID := s.auth.CreateSession(session)
|
||||
http.SetCookie(w, createSessionCookie(sessionID))
|
||||
http.SetCookie(w, s.createSessionCookie(sessionID))
|
||||
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
}
|
||||
@@ -838,7 +858,7 @@ func (s *Manager) handleLocalRegister(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
sessionID := s.auth.CreateSession(session)
|
||||
http.SetCookie(w, createSessionCookie(sessionID))
|
||||
http.SetCookie(w, s.createSessionCookie(sessionID))
|
||||
|
||||
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
|
||||
"message": "Registration successful",
|
||||
@@ -875,7 +895,7 @@ func (s *Manager) handleLocalLogin(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
sessionID := s.auth.CreateSession(session)
|
||||
http.SetCookie(w, createSessionCookie(sessionID))
|
||||
http.SetCookie(w, s.createSessionCookie(sessionID))
|
||||
|
||||
s.respondJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"message": "Login successful",
|
||||
|
||||
@@ -3,6 +3,7 @@ package api
|
||||
import (
|
||||
"fmt"
|
||||
"html/template"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -92,13 +93,17 @@ func newUIRenderer() (*uiRenderer, error) {
|
||||
func (r *uiRenderer) render(w http.ResponseWriter, data pageData) {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
if err := r.templates.ExecuteTemplate(w, "base", data); err != nil {
|
||||
log.Printf("Template render error: %v", err)
|
||||
http.Error(w, "template render error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (r *uiRenderer) renderTemplate(w http.ResponseWriter, templateName string, data interface{}) {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
if err := r.templates.ExecuteTemplate(w, templateName, data); err != nil {
|
||||
log.Printf("Template render error for %s: %v", templateName, err)
|
||||
http.Error(w, "template render error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -765,7 +765,7 @@ func (s *Manager) handleDownloadJobContext(w http.ResponseWriter, r *http.Reques
|
||||
|
||||
// Set appropriate headers for tar file
|
||||
w.Header().Set("Content-Type", "application/x-tar")
|
||||
w.Header().Set("Content-Disposition", "attachment; filename=context.tar")
|
||||
w.Header().Set("Content-Disposition", "attachment; filename=\"context.tar\"")
|
||||
|
||||
// Stream the file to the response
|
||||
io.Copy(w, file)
|
||||
@@ -821,7 +821,7 @@ func (s *Manager) handleDownloadJobContextWithToken(w http.ResponseWriter, r *ht
|
||||
|
||||
// Set appropriate headers for tar file
|
||||
w.Header().Set("Content-Type", "application/x-tar")
|
||||
w.Header().Set("Content-Disposition", "attachment; filename=context.tar")
|
||||
w.Header().Set("Content-Disposition", "attachment; filename=\"context.tar\"")
|
||||
|
||||
// Stream the file to the response
|
||||
io.Copy(w, file)
|
||||
@@ -836,7 +836,7 @@ func (s *Manager) handleUploadFileFromRunner(w http.ResponseWriter, r *http.Requ
|
||||
return
|
||||
}
|
||||
|
||||
err = r.ParseMultipartForm(MaxUploadSize) // 50 GB (for large output files)
|
||||
err = r.ParseMultipartForm(s.maxUploadSize)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Failed to parse multipart form: %v", err))
|
||||
return
|
||||
@@ -944,7 +944,7 @@ func (s *Manager) handleUploadFileWithToken(w http.ResponseWriter, r *http.Reque
|
||||
return
|
||||
}
|
||||
|
||||
err = r.ParseMultipartForm(MaxUploadSize) // 50 GB (for large output files)
|
||||
err = r.ParseMultipartForm(s.maxUploadSize)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Failed to parse multipart form: %v", err))
|
||||
return
|
||||
@@ -1228,7 +1228,7 @@ func (s *Manager) handleDownloadFileForRunner(w http.ResponseWriter, r *http.Req
|
||||
|
||||
// Set headers
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", decodedFileName))
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", decodedFileName))
|
||||
|
||||
// Stream file
|
||||
io.Copy(w, file)
|
||||
@@ -1476,40 +1476,33 @@ func (s *Manager) handleRunnerJobWebSocket(w http.ResponseWriter, r *http.Reques
|
||||
}
|
||||
}
|
||||
case "runner_heartbeat":
|
||||
// Lookup runner ID from job's assigned_runner_id
|
||||
s.handleWSRunnerHeartbeat(conn, jobID)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleWSRunnerHeartbeat processes a runner heartbeat received over a job WebSocket.
|
||||
func (s *Manager) handleWSRunnerHeartbeat(conn *websocket.Conn, jobID int64) {
|
||||
var assignedRunnerID sql.NullInt64
|
||||
err := s.db.With(func(db *sql.DB) error {
|
||||
return db.QueryRow(
|
||||
"SELECT assigned_runner_id FROM jobs WHERE id = ?",
|
||||
jobID,
|
||||
"SELECT assigned_runner_id FROM jobs WHERE id = ?", jobID,
|
||||
).Scan(&assignedRunnerID)
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Failed to lookup runner for job %d heartbeat: %v", jobID, err)
|
||||
// Send error response
|
||||
response := map[string]interface{}{
|
||||
"type": "error",
|
||||
"message": "Failed to process heartbeat",
|
||||
}
|
||||
s.sendWebSocketMessage(conn, response)
|
||||
continue
|
||||
s.sendWebSocketMessage(conn, map[string]interface{}{"type": "error", "message": "Failed to process heartbeat"})
|
||||
return
|
||||
}
|
||||
|
||||
if !assignedRunnerID.Valid {
|
||||
log.Printf("Job %d has no assigned runner, skipping heartbeat update", jobID)
|
||||
// Send acknowledgment but no database update
|
||||
response := map[string]interface{}{
|
||||
"type": "heartbeat_ack",
|
||||
"timestamp": time.Now().Unix(),
|
||||
"message": "No assigned runner for this job",
|
||||
}
|
||||
s.sendWebSocketMessage(conn, response)
|
||||
continue
|
||||
s.sendWebSocketMessage(conn, map[string]interface{}{"type": "heartbeat_ack", "timestamp": time.Now().Unix(), "message": "No assigned runner for this job"})
|
||||
return
|
||||
}
|
||||
|
||||
runnerID := assignedRunnerID.Int64
|
||||
|
||||
// Update runner heartbeat
|
||||
err = s.db.With(func(db *sql.DB) error {
|
||||
_, err := db.Exec(
|
||||
"UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?",
|
||||
@@ -1519,25 +1512,11 @@ func (s *Manager) handleRunnerJobWebSocket(w http.ResponseWriter, r *http.Reques
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Failed to update runner %d heartbeat for job %d: %v", runnerID, jobID, err)
|
||||
// Send error response
|
||||
response := map[string]interface{}{
|
||||
"type": "error",
|
||||
"message": "Failed to update heartbeat",
|
||||
}
|
||||
s.sendWebSocketMessage(conn, response)
|
||||
continue
|
||||
s.sendWebSocketMessage(conn, map[string]interface{}{"type": "error", "message": "Failed to update heartbeat"})
|
||||
return
|
||||
}
|
||||
|
||||
// Send acknowledgment
|
||||
response := map[string]interface{}{
|
||||
"type": "heartbeat_ack",
|
||||
"timestamp": time.Now().Unix(),
|
||||
}
|
||||
s.sendWebSocketMessage(conn, response)
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
s.sendWebSocketMessage(conn, map[string]interface{}{"type": "heartbeat_ack", "timestamp": time.Now().Unix()})
|
||||
}
|
||||
|
||||
// handleWebSocketLog handles log entries from WebSocket
|
||||
@@ -1948,162 +1927,164 @@ func (s *Manager) cleanupJobStatusUpdateMutex(jobID int64) {
|
||||
// This function is serialized per jobID to prevent race conditions when multiple tasks
|
||||
// complete concurrently and trigger status updates simultaneously.
|
||||
func (s *Manager) updateJobStatusFromTasks(jobID int64) {
|
||||
// Serialize updates per job to prevent race conditions
|
||||
mu := s.getJobStatusUpdateMutex(jobID)
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// All jobs now use parallel runners (one task per frame), so we always use task-based progress
|
||||
|
||||
// Get current job status to detect changes
|
||||
var currentStatus string
|
||||
err := s.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow(`SELECT status FROM jobs WHERE id = ?`, jobID).Scan(¤tStatus)
|
||||
})
|
||||
currentStatus, err := s.getJobStatus(jobID)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get current job status for job %d: %v", jobID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Cancellation is terminal from the user's perspective.
|
||||
// Do not allow asynchronous task updates to revive cancelled jobs.
|
||||
if currentStatus == string(types.JobStatusCancelled) {
|
||||
return
|
||||
}
|
||||
|
||||
// Count total tasks and completed tasks
|
||||
var totalTasks, completedTasks int
|
||||
err = s.db.With(func(conn *sql.DB) error {
|
||||
err := conn.QueryRow(
|
||||
counts, err := s.getJobTaskCounts(jobID)
|
||||
if err != nil {
|
||||
log.Printf("Failed to count tasks for job %d: %v", jobID, err)
|
||||
return
|
||||
}
|
||||
|
||||
progress := counts.progress()
|
||||
|
||||
if counts.pendingOrRunning == 0 && counts.total > 0 {
|
||||
s.handleAllTasksFinished(jobID, currentStatus, counts, progress)
|
||||
} else {
|
||||
s.handleTasksInProgress(jobID, currentStatus, counts, progress)
|
||||
}
|
||||
}
|
||||
|
||||
// jobTaskCounts is a snapshot of how many tasks of a job are in each state.
type jobTaskCounts struct {
	// total counts every task that is pending, running, completed or failed.
	total int
	// completed counts tasks that finished successfully.
	completed int
	// pendingOrRunning counts tasks that still have work outstanding.
	pendingOrRunning int
	// failed counts tasks that ended in failure.
	failed int
	// running counts tasks that are currently executing.
	running int
}

// progress returns the completed share of the job as a percentage (0-100).
// A job without any counted tasks reports 0 rather than dividing by zero.
func (c *jobTaskCounts) progress() float64 {
	if c.total != 0 {
		done := float64(c.completed)
		all := float64(c.total)
		return done / all * 100.0
	}
	return 0.0
}
|
||||
|
||||
func (s *Manager) getJobStatus(jobID int64) (string, error) {
|
||||
var status string
|
||||
err := s.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow(`SELECT status FROM jobs WHERE id = ?`, jobID).Scan(&status)
|
||||
})
|
||||
return status, err
|
||||
}
|
||||
|
||||
func (s *Manager) getJobTaskCounts(jobID int64) (*jobTaskCounts, error) {
|
||||
c := &jobTaskCounts{}
|
||||
err := s.db.With(func(conn *sql.DB) error {
|
||||
if err := conn.QueryRow(
|
||||
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status IN (?, ?, ?, ?)`,
|
||||
jobID, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted, types.TaskStatusFailed,
|
||||
).Scan(&totalTasks)
|
||||
if err != nil {
|
||||
).Scan(&c.total); err != nil {
|
||||
return err
|
||||
}
|
||||
return conn.QueryRow(
|
||||
if err := conn.QueryRow(
|
||||
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
|
||||
jobID, types.TaskStatusCompleted,
|
||||
).Scan(&completedTasks)
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Failed to count completed tasks for job %d: %v", jobID, err)
|
||||
return
|
||||
).Scan(&c.completed); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Calculate progress
|
||||
var progress float64
|
||||
if totalTasks == 0 {
|
||||
// All tasks cancelled or no tasks, set progress to 0
|
||||
progress = 0.0
|
||||
} else {
|
||||
// Standard task-based progress
|
||||
progress = float64(completedTasks) / float64(totalTasks) * 100.0
|
||||
}
|
||||
|
||||
var jobStatus string
|
||||
|
||||
// Check if all non-cancelled tasks are completed
|
||||
var pendingOrRunningTasks int
|
||||
err = s.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow(
|
||||
`SELECT COUNT(*) FROM tasks
|
||||
WHERE job_id = ? AND status IN (?, ?)`,
|
||||
if err := conn.QueryRow(
|
||||
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status IN (?, ?)`,
|
||||
jobID, types.TaskStatusPending, types.TaskStatusRunning,
|
||||
).Scan(&pendingOrRunningTasks)
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Failed to count pending/running tasks for job %d: %v", jobID, err)
|
||||
return
|
||||
).Scan(&c.pendingOrRunning); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if pendingOrRunningTasks == 0 && totalTasks > 0 {
|
||||
// All tasks are either completed or failed/cancelled
|
||||
// Check if any tasks failed
|
||||
var failedTasks int
|
||||
s.db.With(func(conn *sql.DB) error {
|
||||
conn.QueryRow(
|
||||
if err := conn.QueryRow(
|
||||
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
|
||||
jobID, types.TaskStatusFailed,
|
||||
).Scan(&failedTasks)
|
||||
).Scan(&c.failed); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := conn.QueryRow(
|
||||
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
|
||||
jobID, types.TaskStatusRunning,
|
||||
).Scan(&c.running); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return c, err
|
||||
}
|
||||
|
||||
if failedTasks > 0 {
|
||||
// Some tasks failed - check if job has retries left
|
||||
// handleAllTasksFinished handles the case where no pending/running tasks remain.
|
||||
func (s *Manager) handleAllTasksFinished(jobID int64, currentStatus string, counts *jobTaskCounts, progress float64) {
|
||||
now := time.Now()
|
||||
var jobStatus string
|
||||
|
||||
if counts.failed > 0 {
|
||||
jobStatus = s.handleFailedTasks(jobID, currentStatus, &progress)
|
||||
if jobStatus == "" {
|
||||
return // retry handled; early exit
|
||||
}
|
||||
} else {
|
||||
jobStatus = string(types.JobStatusCompleted)
|
||||
progress = 100.0
|
||||
}
|
||||
|
||||
s.setJobFinalStatus(jobID, currentStatus, jobStatus, progress, now, counts)
|
||||
}
|
||||
|
||||
// handleFailedTasks decides whether to retry or mark the job failed.
|
||||
// Returns "" if a retry was triggered (caller should return early),
|
||||
// or the final status string.
|
||||
func (s *Manager) handleFailedTasks(jobID int64, currentStatus string, progress *float64) string {
|
||||
var retryCount, maxRetries int
|
||||
err := s.db.With(func(conn *sql.DB) error {
|
||||
return conn.QueryRow(
|
||||
`SELECT retry_count, max_retries FROM jobs WHERE id = ?`,
|
||||
jobID,
|
||||
`SELECT retry_count, max_retries FROM jobs WHERE id = ?`, jobID,
|
||||
).Scan(&retryCount, &maxRetries)
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Failed to get retry info for job %d: %v", jobID, err)
|
||||
// Fall back to marking job as failed
|
||||
jobStatus = string(types.JobStatusFailed)
|
||||
} else if retryCount < maxRetries {
|
||||
// Job has retries left - reset failed tasks and redistribute
|
||||
return string(types.JobStatusFailed)
|
||||
}
|
||||
|
||||
if retryCount < maxRetries {
|
||||
if err := s.resetFailedTasksAndRedistribute(jobID); err != nil {
|
||||
log.Printf("Failed to reset failed tasks for job %d: %v", jobID, err)
|
||||
// If reset fails, mark job as failed
|
||||
jobStatus = string(types.JobStatusFailed)
|
||||
} else {
|
||||
// Tasks reset successfully - job remains in running/pending state
|
||||
// Don't update job status, just update progress
|
||||
jobStatus = currentStatus // Keep current status
|
||||
// Recalculate progress after reset (failed tasks are now pending again)
|
||||
var newTotalTasks, newCompletedTasks int
|
||||
s.db.With(func(conn *sql.DB) error {
|
||||
conn.QueryRow(
|
||||
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status IN (?, ?, ?, ?)`,
|
||||
jobID, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted, types.TaskStatusFailed,
|
||||
).Scan(&newTotalTasks)
|
||||
conn.QueryRow(
|
||||
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
|
||||
jobID, types.TaskStatusCompleted,
|
||||
).Scan(&newCompletedTasks)
|
||||
return nil
|
||||
})
|
||||
if newTotalTasks > 0 {
|
||||
progress = float64(newCompletedTasks) / float64(newTotalTasks) * 100.0
|
||||
return string(types.JobStatusFailed)
|
||||
}
|
||||
// Update progress only
|
||||
err := s.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec(
|
||||
`UPDATE jobs SET progress = ? WHERE id = ?`,
|
||||
progress, jobID,
|
||||
)
|
||||
// Recalculate progress after reset
|
||||
counts, err := s.getJobTaskCounts(jobID)
|
||||
if err == nil && counts.total > 0 {
|
||||
*progress = counts.progress()
|
||||
}
|
||||
err = s.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec(`UPDATE jobs SET progress = ? WHERE id = ?`, *progress, jobID)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Failed to update job %d progress: %v", jobID, err)
|
||||
} else {
|
||||
// Broadcast job update via WebSocket
|
||||
s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{
|
||||
"status": jobStatus,
|
||||
"progress": progress,
|
||||
"status": currentStatus,
|
||||
"progress": *progress,
|
||||
})
|
||||
}
|
||||
return // Exit early since we've handled the retry
|
||||
return "" // retry handled
|
||||
}
|
||||
} else {
|
||||
// No retries left - mark job as failed and cancel active tasks
|
||||
jobStatus = string(types.JobStatusFailed)
|
||||
|
||||
// No retries left
|
||||
if err := s.cancelActiveTasksForJob(jobID); err != nil {
|
||||
log.Printf("Failed to cancel active tasks for job %d: %v", jobID, err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// All tasks completed successfully
|
||||
jobStatus = string(types.JobStatusCompleted)
|
||||
progress = 100.0 // Ensure progress is 100% when all tasks complete
|
||||
}
|
||||
return string(types.JobStatusFailed)
|
||||
}
|
||||
|
||||
// Update job status (if we didn't return early from retry logic)
|
||||
if jobStatus != "" {
|
||||
// setJobFinalStatus persists the terminal job status and broadcasts the update.
|
||||
func (s *Manager) setJobFinalStatus(jobID int64, currentStatus, jobStatus string, progress float64, now time.Time, counts *jobTaskCounts) {
|
||||
err := s.db.With(func(conn *sql.DB) error {
|
||||
_, err := conn.Exec(
|
||||
`UPDATE jobs SET status = ?, progress = ?, completed_at = ? WHERE id = ?`,
|
||||
@@ -2113,44 +2094,30 @@ func (s *Manager) updateJobStatusFromTasks(jobID int64) {
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Failed to update job %d status to %s: %v", jobID, jobStatus, err)
|
||||
} else {
|
||||
// Only log if status actually changed
|
||||
if currentStatus != jobStatus {
|
||||
log.Printf("Updated job %d status from %s to %s (progress: %.1f%%, completed tasks: %d/%d)", jobID, currentStatus, jobStatus, progress, completedTasks, totalTasks)
|
||||
return
|
||||
}
|
||||
if currentStatus != jobStatus {
|
||||
log.Printf("Updated job %d status from %s to %s (progress: %.1f%%, completed tasks: %d/%d)", jobID, currentStatus, jobStatus, progress, counts.completed, counts.total)
|
||||
}
|
||||
// Broadcast job update via WebSocket
|
||||
s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{
|
||||
"status": jobStatus,
|
||||
"progress": progress,
|
||||
"completed_at": now,
|
||||
})
|
||||
// Clean up mutex for jobs in final states (completed or failed)
|
||||
// No more status updates will occur for these jobs
|
||||
if jobStatus == string(types.JobStatusCompleted) || jobStatus == string(types.JobStatusFailed) {
|
||||
s.cleanupJobStatusUpdateMutex(jobID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Encode tasks are now created immediately when the job is created
|
||||
// with a condition that prevents assignment until all render tasks are completed.
|
||||
// No need to create them here anymore.
|
||||
} else {
|
||||
// Job has pending or running tasks - determine if it's running or still pending
|
||||
var runningTasks int
|
||||
s.db.With(func(conn *sql.DB) error {
|
||||
conn.QueryRow(
|
||||
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
|
||||
jobID, types.TaskStatusRunning,
|
||||
).Scan(&runningTasks)
|
||||
return nil
|
||||
})
|
||||
// handleTasksInProgress handles the case where tasks are still pending or running.
|
||||
func (s *Manager) handleTasksInProgress(jobID int64, currentStatus string, counts *jobTaskCounts, progress float64) {
|
||||
now := time.Now()
|
||||
var jobStatus string
|
||||
|
||||
if runningTasks > 0 {
|
||||
// Has running tasks - job is running
|
||||
if counts.running > 0 {
|
||||
jobStatus = string(types.JobStatusRunning)
|
||||
var startedAt sql.NullTime
|
||||
s.db.With(func(conn *sql.DB) error {
|
||||
var startedAt sql.NullTime
|
||||
conn.QueryRow(`SELECT started_at FROM jobs WHERE id = ?`, jobID).Scan(&startedAt)
|
||||
if !startedAt.Valid {
|
||||
conn.Exec(`UPDATE jobs SET started_at = ? WHERE id = ?`, now, jobID)
|
||||
@@ -2158,7 +2125,6 @@ func (s *Manager) updateJobStatusFromTasks(jobID int64) {
|
||||
return nil
|
||||
})
|
||||
} else {
|
||||
// All tasks are pending - job is pending
|
||||
jobStatus = string(types.JobStatusPending)
|
||||
}
|
||||
|
||||
@@ -2171,18 +2137,16 @@ func (s *Manager) updateJobStatusFromTasks(jobID int64) {
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Failed to update job %d status to %s: %v", jobID, jobStatus, err)
|
||||
} else {
|
||||
// Only log if status actually changed
|
||||
if currentStatus != jobStatus {
|
||||
log.Printf("Updated job %d status from %s to %s (progress: %.1f%%, completed: %d/%d, pending: %d, running: %d)", jobID, currentStatus, jobStatus, progress, completedTasks, totalTasks, pendingOrRunningTasks-runningTasks, runningTasks)
|
||||
return
|
||||
}
|
||||
if currentStatus != jobStatus {
|
||||
pending := counts.pendingOrRunning - counts.running
|
||||
log.Printf("Updated job %d status from %s to %s (progress: %.1f%%, completed: %d/%d, pending: %d, running: %d)", jobID, currentStatus, jobStatus, progress, counts.completed, counts.total, pending, counts.running)
|
||||
}
|
||||
// Broadcast job update during execution (not just on completion)
|
||||
s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{
|
||||
"status": jobStatus,
|
||||
"progress": progress,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// broadcastLogToFrontend broadcasts log to connected frontend clients
|
||||
|
||||
@@ -241,8 +241,8 @@ func (m *ManagerClient) DownloadContext(contextPath, jobToken string) (io.ReadCl
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
return nil, fmt.Errorf("context download failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
@@ -435,8 +435,8 @@ func (m *ManagerClient) DownloadBlender(version string) (io.ReadCloser, error) {
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
return nil, fmt.Errorf("failed to download blender: status %d, body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
|
||||
@@ -1,45 +1,116 @@
|
||||
// Package blender: GPU backend detection for HIP vs NVIDIA.
|
||||
// Package blender: host GPU backend detection for AMD/NVIDIA/Intel.
|
||||
package blender
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"jiggablend/pkg/scripts"
|
||||
)
|
||||
|
||||
// DetectGPUBackends runs a minimal Blender script to detect whether HIP (AMD) and/or
|
||||
// NVIDIA (CUDA/OptiX) devices are available. Use this to decide whether to force CPU
|
||||
// for Blender < 4.x (only force when HIP is present, since HIP has no official support pre-4).
|
||||
func DetectGPUBackends(blenderBinary, scriptDir string) (hasHIP, hasNVIDIA bool, err error) {
|
||||
scriptPath := filepath.Join(scriptDir, "detect_gpu_backends.py")
|
||||
if err := os.WriteFile(scriptPath, []byte(scripts.DetectGPUBackends), 0644); err != nil {
|
||||
return false, false, fmt.Errorf("write detection script: %w", err)
|
||||
}
|
||||
defer os.Remove(scriptPath)
|
||||
// DetectGPUBackends detects whether AMD, NVIDIA, and/or Intel GPUs are available
|
||||
// using host-level hardware probing only.
|
||||
func DetectGPUBackends() (hasAMD, hasNVIDIA, hasIntel bool, ok bool) {
|
||||
return detectGPUBackendsFromHost()
|
||||
}
|
||||
|
||||
env := TarballEnv(blenderBinary, os.Environ())
|
||||
cmd := exec.Command(blenderBinary, "-b", "--python", scriptPath)
|
||||
cmd.Env = env
|
||||
cmd.Dir = scriptDir
|
||||
out, err := cmd.CombinedOutput()
|
||||
func detectGPUBackendsFromHost() (hasAMD, hasNVIDIA, hasIntel bool, ok bool) {
|
||||
if amd, nvidia, intel, found := detectGPUBackendsFromDRM(); found {
|
||||
return amd, nvidia, intel, true
|
||||
}
|
||||
if amd, nvidia, intel, found := detectGPUBackendsFromLSPCI(); found {
|
||||
return amd, nvidia, intel, true
|
||||
}
|
||||
return false, false, false, false
|
||||
}
|
||||
|
||||
func detectGPUBackendsFromDRM() (hasAMD, hasNVIDIA, hasIntel bool, ok bool) {
|
||||
entries, err := os.ReadDir("/sys/class/drm")
|
||||
if err != nil {
|
||||
return false, false, fmt.Errorf("run blender detection: %w (output: %s)", err, string(out))
|
||||
return false, false, false, false
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if !isDRMCardNode(name) {
|
||||
continue
|
||||
}
|
||||
|
||||
vendorPath := filepath.Join("/sys/class/drm", name, "device", "vendor")
|
||||
vendorRaw, err := os.ReadFile(vendorPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
vendor := strings.TrimSpace(strings.ToLower(string(vendorRaw)))
|
||||
switch vendor {
|
||||
case "0x1002":
|
||||
hasAMD = true
|
||||
ok = true
|
||||
case "0x10de":
|
||||
hasNVIDIA = true
|
||||
ok = true
|
||||
case "0x8086":
|
||||
hasIntel = true
|
||||
ok = true
|
||||
}
|
||||
}
|
||||
|
||||
return hasAMD, hasNVIDIA, hasIntel, ok
|
||||
}
|
||||
|
||||
// isDRMCardNode reports whether name is a DRM GPU device node of the form
// "card<N>" (for example "card0"), where N is an unsigned decimal number.
//
// Connector entries such as "card0-DP-1", the bare prefix "card", and any
// suffix carrying a sign or other non-digit characters (for example "card+1")
// are rejected.
func isDRMCardNode(name string) bool {
	const prefix = "card"
	if !strings.HasPrefix(name, prefix) {
		return false
	}
	// ParseUint rejects empty input, sign prefixes and non-digit characters,
	// so connector names and malformed suffixes all fail here.
	_, err := strconv.ParseUint(strings.TrimPrefix(name, prefix), 10, 64)
	return err == nil
}
|
||||
|
||||
// detectGPUBackendsFromLSPCI runs `lspci` and classifies the GPU controllers
// it lists by vendor. It returns ok=false when lspci is not installed, fails
// to run, or lists no GPU from a recognised vendor.
func detectGPUBackendsFromLSPCI() (hasAMD, hasNVIDIA, hasIntel bool, ok bool) {
	if _, err := exec.LookPath("lspci"); err != nil {
		return false, false, false, false
	}

	out, err := exec.Command("lspci").CombinedOutput()
	if err != nil {
		return false, false, false, false
	}

	return parseLSPCIOutput(string(out))
}

// parseLSPCIOutput scans lspci output line by line and reports which GPU
// vendors appear on GPU controller lines. ok is true when at least one
// recognised vendor was found.
func parseLSPCIOutput(output string) (hasAMD, hasNVIDIA, hasIntel bool, ok bool) {
	scanner := bufio.NewScanner(strings.NewReader(output))
	for scanner.Scan() {
		line := strings.ToLower(strings.TrimSpace(scanner.Text()))
		if !isGPUControllerLine(line) {
			continue
		}

		// Match vendors only in the device description that follows
		// "<slot> <class>: ". The class text itself must not be searched:
		// "compatible" contains "ati" and would flag every VGA device as AMD.
		device := line
		if i := strings.Index(line, ": "); i >= 0 {
			device = line[i+2:]
		}

		if strings.Contains(device, "nvidia") {
			hasNVIDIA = true
			ok = true
		}
		if isAMDDeviceDescription(device) {
			hasAMD = true
			ok = true
		}
		if strings.Contains(device, "intel") {
			hasIntel = true
			ok = true
		}
	}
	// The scanner reads from memory, so the only possible error is an
	// over-long line; whatever was detected before it is still returned.
	return hasAMD, hasNVIDIA, hasIntel, ok
}

// isAMDDeviceDescription reports whether a lower-cased lspci device
// description names an AMD/ATI GPU. A bare "ati" substring is deliberately not
// used because it also occurs inside "corporation".
func isAMDDeviceDescription(device string) bool {
	return strings.Contains(device, "amd") ||
		strings.Contains(device, "advanced micro devices") ||
		strings.Contains(device, "ati technologies") ||
		strings.Contains(device, "radeon")
}

// isGPUControllerLine reports whether a lower-cased lspci line describes a
// graphics device (VGA, 3D or display controller class).
func isGPUControllerLine(line string) bool {
	return strings.Contains(line, "vga compatible controller") ||
		strings.Contains(line, "3d controller") ||
		strings.Contains(line, "display controller")
}
|
||||
|
||||
@@ -1,143 +1,19 @@
|
||||
package blender
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
|
||||
"jiggablend/pkg/blendfile"
|
||||
)
|
||||
|
||||
// ParseVersionFromFile parses the Blender version that a .blend file was saved with.
|
||||
// Returns major and minor version numbers.
|
||||
// Delegates to the shared pkg/blendfile implementation.
|
||||
func ParseVersionFromFile(blendPath string) (major, minor int, err error) {
|
||||
file, err := os.Open(blendPath)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to open blend file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Read the first 12 bytes of the blend file header
|
||||
// Format: BLENDER-v<major><minor><patch> or BLENDER_v<major><minor><patch>
|
||||
// The header is: "BLENDER" (7 bytes) + pointer size (1 byte: '-' for 64-bit, '_' for 32-bit)
|
||||
// + endianness (1 byte: 'v' for little-endian, 'V' for big-endian)
|
||||
// + version (3 bytes: e.g., "402" for 4.02)
|
||||
header := make([]byte, 12)
|
||||
n, err := file.Read(header)
|
||||
if err != nil || n < 12 {
|
||||
return 0, 0, fmt.Errorf("failed to read blend file header: %w", err)
|
||||
}
|
||||
|
||||
// Check for BLENDER magic
|
||||
if string(header[:7]) != "BLENDER" {
|
||||
// Might be compressed - try to decompress
|
||||
file.Seek(0, 0)
|
||||
return parseCompressedVersion(file)
|
||||
}
|
||||
|
||||
// Parse version from bytes 9-11 (3 digits)
|
||||
versionStr := string(header[9:12])
|
||||
|
||||
// Version format changed in Blender 3.0
|
||||
// Pre-3.0: "279" = 2.79, "280" = 2.80
|
||||
// 3.0+: "300" = 3.0, "402" = 4.02, "410" = 4.10
|
||||
if len(versionStr) == 3 {
|
||||
// First digit is major version
|
||||
fmt.Sscanf(string(versionStr[0]), "%d", &major)
|
||||
// Next two digits are minor version
|
||||
fmt.Sscanf(versionStr[1:3], "%d", &minor)
|
||||
}
|
||||
|
||||
return major, minor, nil
|
||||
}
|
||||
|
||||
// parseCompressedVersion handles gzip and zstd compressed blend files.
|
||||
func parseCompressedVersion(file *os.File) (major, minor int, err error) {
|
||||
magic := make([]byte, 4)
|
||||
if _, err := file.Read(magic); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
file.Seek(0, 0)
|
||||
|
||||
if magic[0] == 0x1f && magic[1] == 0x8b {
|
||||
// gzip compressed
|
||||
gzReader, err := gzip.NewReader(file)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to create gzip reader: %w", err)
|
||||
}
|
||||
defer gzReader.Close()
|
||||
|
||||
header := make([]byte, 12)
|
||||
n, err := gzReader.Read(header)
|
||||
if err != nil || n < 12 {
|
||||
return 0, 0, fmt.Errorf("failed to read compressed blend header: %w", err)
|
||||
}
|
||||
|
||||
if string(header[:7]) != "BLENDER" {
|
||||
return 0, 0, fmt.Errorf("invalid blend file format")
|
||||
}
|
||||
|
||||
versionStr := string(header[9:12])
|
||||
if len(versionStr) == 3 {
|
||||
fmt.Sscanf(string(versionStr[0]), "%d", &major)
|
||||
fmt.Sscanf(versionStr[1:3], "%d", &minor)
|
||||
}
|
||||
|
||||
return major, minor, nil
|
||||
}
|
||||
|
||||
// Check for zstd magic (Blender 3.0+): 0x28 0xB5 0x2F 0xFD
|
||||
if magic[0] == 0x28 && magic[1] == 0xb5 && magic[2] == 0x2f && magic[3] == 0xfd {
|
||||
return parseZstdVersion(file)
|
||||
}
|
||||
|
||||
return 0, 0, fmt.Errorf("unknown blend file format")
|
||||
}
|
||||
|
||||
// parseZstdVersion handles zstd-compressed blend files (Blender 3.0+).
|
||||
// Uses zstd command line tool since Go doesn't have native zstd support.
|
||||
func parseZstdVersion(file *os.File) (major, minor int, err error) {
|
||||
file.Seek(0, 0)
|
||||
|
||||
cmd := exec.Command("zstd", "-d", "-c")
|
||||
cmd.Stdin = file
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to create zstd stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to start zstd decompression: %w", err)
|
||||
}
|
||||
|
||||
// Read just the header (12 bytes)
|
||||
header := make([]byte, 12)
|
||||
n, readErr := io.ReadFull(stdout, header)
|
||||
|
||||
// Kill the process early - we only need the header
|
||||
cmd.Process.Kill()
|
||||
cmd.Wait()
|
||||
|
||||
if readErr != nil || n < 12 {
|
||||
return 0, 0, fmt.Errorf("failed to read zstd compressed blend header: %v", readErr)
|
||||
}
|
||||
|
||||
if string(header[:7]) != "BLENDER" {
|
||||
return 0, 0, fmt.Errorf("invalid blend file format in zstd archive")
|
||||
}
|
||||
|
||||
versionStr := string(header[9:12])
|
||||
if len(versionStr) == 3 {
|
||||
fmt.Sscanf(string(versionStr[0]), "%d", &major)
|
||||
fmt.Sscanf(versionStr[1:3], "%d", &minor)
|
||||
}
|
||||
|
||||
return major, minor, nil
|
||||
return blendfile.ParseVersionFromFile(blendPath)
|
||||
}
|
||||
|
||||
// VersionString renders a major/minor pair as "<major>.<minor>", e.g. "4.2".
// The minor component is not zero-padded.
func VersionString(major, minor int) string {
	const layout = "%d.%d"
	formatted := fmt.Sprintf(layout, major, minor)
	return formatted
}
|
||||
|
||||
|
||||
@@ -46,23 +46,22 @@ type Runner struct {
|
||||
gpuLockedOut bool
|
||||
gpuLockedOutMu sync.RWMutex
|
||||
|
||||
// hasHIP/hasNVIDIA are set at startup by running latest Blender to detect GPU backends.
|
||||
// Used to force CPU only for Blender < 4.x when HIP is present (no official HIP support pre-4).
|
||||
// gpuDetectionFailed is true when detection could not run; we then force CPU for all versions (we could not determine HIP vs NVIDIA).
|
||||
// hasAMD/hasNVIDIA/hasIntel are set at startup by hardware/Blender GPU backend detection.
|
||||
// Used to force CPU only for Blender < 4.x when AMD is present (no official HIP support pre-4).
|
||||
// gpuDetectionFailed is true when detection could not run; we then force CPU for all versions.
|
||||
gpuBackendMu sync.RWMutex
|
||||
hasHIP bool
|
||||
hasAMD bool
|
||||
hasNVIDIA bool
|
||||
hasIntel bool
|
||||
gpuBackendProbed bool
|
||||
gpuDetectionFailed bool
|
||||
|
||||
// forceCPURendering forces CPU rendering for all jobs regardless of metadata/backend detection.
|
||||
forceCPURendering bool
|
||||
// disableHIPRT disables HIPRT acceleration when configuring Cycles HIP devices.
|
||||
disableHIPRT bool
|
||||
}
|
||||
|
||||
// New creates a new runner.
|
||||
func New(managerURL, name, hostname string, forceCPURendering, disableHIPRT bool) *Runner {
|
||||
func New(managerURL, name, hostname string, forceCPURendering bool) *Runner {
|
||||
manager := api.NewManagerClient(managerURL)
|
||||
|
||||
r := &Runner{
|
||||
@@ -74,7 +73,6 @@ func New(managerURL, name, hostname string, forceCPURendering, disableHIPRT bool
|
||||
processors: make(map[string]tasks.Processor),
|
||||
|
||||
forceCPURendering: forceCPURendering,
|
||||
disableHIPRT: disableHIPRT,
|
||||
}
|
||||
|
||||
// Generate fingerprint
|
||||
@@ -93,17 +91,16 @@ func (r *Runner) CheckRequiredTools() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var cachedCapabilities map[string]interface{} = nil
|
||||
var (
|
||||
cachedCapabilities map[string]interface{}
|
||||
capabilitiesOnce sync.Once
|
||||
)
|
||||
|
||||
// ProbeCapabilities detects hardware capabilities.
|
||||
func (r *Runner) ProbeCapabilities() map[string]interface{} {
|
||||
if cachedCapabilities != nil {
|
||||
return cachedCapabilities
|
||||
}
|
||||
|
||||
capabilitiesOnce.Do(func() {
|
||||
caps := make(map[string]interface{})
|
||||
|
||||
// Check for ffmpeg and probe encoding capabilities
|
||||
if err := exec.Command("ffmpeg", "-version").Run(); err == nil {
|
||||
caps["ffmpeg"] = true
|
||||
} else {
|
||||
@@ -111,7 +108,8 @@ func (r *Runner) ProbeCapabilities() map[string]interface{} {
|
||||
}
|
||||
|
||||
cachedCapabilities = caps
|
||||
return caps
|
||||
})
|
||||
return cachedCapabilities
|
||||
}
|
||||
|
||||
// Register registers the runner with the manager.
|
||||
@@ -141,52 +139,66 @@ func (r *Runner) Register(apiKey string) (int64, error) {
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// DetectAndStoreGPUBackends downloads the latest Blender from the manager (if needed),
|
||||
// runs a detection script to see if HIP (AMD) and/or NVIDIA devices are available,
|
||||
// and stores the result. Call after Register. Used so we only force CPU for Blender < 4.x
|
||||
// when the runner has HIP (no official HIP support pre-4); NVIDIA is allowed.
|
||||
// DetectAndStoreGPUBackends runs host-level backend detection and stores AMD/NVIDIA/Intel results.
|
||||
// Call after Register. Used so we only force CPU for Blender < 4.x when AMD is present.
|
||||
func (r *Runner) DetectAndStoreGPUBackends() {
|
||||
r.gpuBackendMu.Lock()
|
||||
defer r.gpuBackendMu.Unlock()
|
||||
if r.gpuBackendProbed {
|
||||
return
|
||||
}
|
||||
latestVer, err := r.manager.GetLatestBlenderVersion()
|
||||
if err != nil {
|
||||
log.Printf("GPU backend detection failed (could not get latest Blender version: %v). All jobs will use CPU because we could not determine HIP vs NVIDIA.", err)
|
||||
hasAMD, hasNVIDIA, hasIntel, ok := blender.DetectGPUBackends()
|
||||
if !ok {
|
||||
log.Printf("GPU backend detection failed (host probe unavailable). All jobs will use CPU because backend availability is unknown.")
|
||||
r.gpuBackendProbed = true
|
||||
r.gpuDetectionFailed = true
|
||||
return
|
||||
}
|
||||
binaryPath, err := r.blender.GetBinaryPath(latestVer)
|
||||
if err != nil {
|
||||
log.Printf("GPU backend detection failed (could not get Blender binary: %v). All jobs will use CPU because we could not determine HIP vs NVIDIA.", err)
|
||||
r.gpuBackendProbed = true
|
||||
r.gpuDetectionFailed = true
|
||||
return
|
||||
|
||||
detectedTypes := 0
|
||||
if hasAMD {
|
||||
detectedTypes++
|
||||
}
|
||||
hasHIP, hasNVIDIA, err := blender.DetectGPUBackends(binaryPath, r.workspace.BaseDir())
|
||||
if err != nil {
|
||||
log.Printf("GPU backend detection failed (script error: %v). All jobs will use CPU because we could not determine HIP vs NVIDIA.", err)
|
||||
r.gpuBackendProbed = true
|
||||
r.gpuDetectionFailed = true
|
||||
return
|
||||
if hasNVIDIA {
|
||||
detectedTypes++
|
||||
}
|
||||
r.hasHIP = hasHIP
|
||||
if hasIntel {
|
||||
detectedTypes++
|
||||
}
|
||||
if detectedTypes > 1 {
|
||||
log.Printf("mixed GPU vendors detected (AMD=%v NVIDIA=%v INTEL=%v): multi-vendor setups may not work reliably, but runner will continue with GPU enabled", hasAMD, hasNVIDIA, hasIntel)
|
||||
}
|
||||
|
||||
r.hasAMD = hasAMD
|
||||
r.hasNVIDIA = hasNVIDIA
|
||||
r.hasIntel = hasIntel
|
||||
r.gpuBackendProbed = true
|
||||
r.gpuDetectionFailed = false
|
||||
log.Printf("GPU backend detection: HIP=%v NVIDIA=%v (Blender < 4.x will force CPU only when HIP is present)", hasHIP, hasNVIDIA)
|
||||
log.Printf("GPU backend detection: AMD=%v NVIDIA=%v INTEL=%v (Blender < 4.x will force CPU only when AMD is present)", hasAMD, hasNVIDIA, hasIntel)
|
||||
}
|
||||
|
||||
// HasHIP returns whether the runner detected HIP (AMD) devices. Used to force CPU for Blender < 4.x only when HIP is present.
|
||||
func (r *Runner) HasHIP() bool {
|
||||
// HasAMD returns whether the runner detected AMD devices. Used to force CPU for Blender < 4.x only when AMD is present.
|
||||
func (r *Runner) HasAMD() bool {
|
||||
r.gpuBackendMu.RLock()
|
||||
defer r.gpuBackendMu.RUnlock()
|
||||
return r.hasHIP
|
||||
return r.hasAMD
|
||||
}
|
||||
|
||||
// GPUDetectionFailed returns true when startup GPU backend detection could not run or failed. When true, all jobs use CPU because we could not determine HIP vs NVIDIA.
|
||||
// HasNVIDIA returns whether the runner detected NVIDIA GPUs.
|
||||
func (r *Runner) HasNVIDIA() bool {
|
||||
r.gpuBackendMu.RLock()
|
||||
defer r.gpuBackendMu.RUnlock()
|
||||
return r.hasNVIDIA
|
||||
}
|
||||
|
||||
// HasIntel returns whether the runner detected Intel GPUs (e.g. Arc).
|
||||
func (r *Runner) HasIntel() bool {
|
||||
r.gpuBackendMu.RLock()
|
||||
defer r.gpuBackendMu.RUnlock()
|
||||
return r.hasIntel
|
||||
}
|
||||
|
||||
// GPUDetectionFailed returns true when startup GPU backend detection could not run or failed. When true, all jobs use CPU because backend availability is unknown.
|
||||
func (r *Runner) GPUDetectionFailed() bool {
|
||||
r.gpuBackendMu.RLock()
|
||||
defer r.gpuBackendMu.RUnlock()
|
||||
@@ -313,10 +325,11 @@ func (r *Runner) executeJob(job *api.NextJobResponse) (err error) {
|
||||
r.encoder,
|
||||
r.processes,
|
||||
r.IsGPULockedOut(),
|
||||
r.HasHIP(),
|
||||
r.HasAMD(),
|
||||
r.HasNVIDIA(),
|
||||
r.HasIntel(),
|
||||
r.GPUDetectionFailed(),
|
||||
r.forceCPURendering,
|
||||
r.disableHIPRT,
|
||||
func() { r.SetGPULockedOut(true) },
|
||||
)
|
||||
|
||||
|
||||
@@ -298,6 +298,9 @@ func (p *EncodeProcessor) Process(ctx *Context) error {
|
||||
ctx.Info(line)
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
log.Printf("Error reading encode stdout: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Stream stderr
|
||||
@@ -311,6 +314,9 @@ func (p *EncodeProcessor) Process(ctx *Context) error {
|
||||
ctx.Warn(line)
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
log.Printf("Error reading encode stderr: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
err = cmd.Wait()
|
||||
|
||||
@@ -11,8 +11,6 @@ import (
|
||||
"jiggablend/pkg/executils"
|
||||
"jiggablend/pkg/types"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -43,23 +41,25 @@ type Context struct {
|
||||
|
||||
// GPULockedOut is set when the runner has detected a GPU error (e.g. HIP) and disables GPU for all jobs.
|
||||
GPULockedOut bool
|
||||
// HasHIP is true when the runner detected HIP (AMD) devices at startup. Used to force CPU for Blender < 4.x only when HIP is present.
|
||||
HasHIP bool
|
||||
// GPUDetectionFailed is true when startup GPU backend detection could not run; we force CPU for all versions (could not determine HIP vs NVIDIA).
|
||||
// HasAMD is true when the runner detected AMD devices at startup.
|
||||
HasAMD bool
|
||||
// HasNVIDIA is true when the runner detected NVIDIA GPUs at startup.
|
||||
HasNVIDIA bool
|
||||
// HasIntel is true when the runner detected Intel GPUs (e.g. Arc) at startup.
|
||||
HasIntel bool
|
||||
// GPUDetectionFailed is true when startup GPU backend detection could not run; we force CPU for all versions (backend availability unknown).
|
||||
GPUDetectionFailed bool
|
||||
// OnGPUError is called when a GPU error line is seen in render logs; typically sets runner GPU lockout.
|
||||
OnGPUError func()
|
||||
// ForceCPURendering is a runner-level override that forces CPU rendering for all jobs.
|
||||
ForceCPURendering bool
|
||||
// DisableHIPRT is a runner-level override that disables HIPRT acceleration in Blender.
|
||||
DisableHIPRT bool
|
||||
}
|
||||
|
||||
// ErrJobCancelled indicates the manager-side job was cancelled during execution.
|
||||
var ErrJobCancelled = errors.New("job cancelled")
|
||||
|
||||
// NewContext creates a new task context. frameEnd should be >= frame; if 0 or less than frame, it is treated as single-frame (frameEnd = frame).
|
||||
// gpuLockedOut is the runner's current GPU lockout state; hasHIP means the runner has HIP (AMD) devices (force CPU for Blender < 4.x only when true); gpuDetectionFailed means detection failed at startup (force CPU for all versions—could not determine HIP vs NVIDIA); onGPUError is called when a GPU error is detected in logs (may be nil).
|
||||
// gpuLockedOut is the runner's current GPU lockout state; gpuDetectionFailed means detection failed at startup (force CPU for all versions); onGPUError is called when a GPU error is detected in logs (may be nil).
|
||||
func NewContext(
|
||||
taskID, jobID int64,
|
||||
jobName string,
|
||||
@@ -75,10 +75,11 @@ func NewContext(
|
||||
encoder *encoding.Selector,
|
||||
processes *executils.ProcessTracker,
|
||||
gpuLockedOut bool,
|
||||
hasHIP bool,
|
||||
hasAMD bool,
|
||||
hasNVIDIA bool,
|
||||
hasIntel bool,
|
||||
gpuDetectionFailed bool,
|
||||
forceCPURendering bool,
|
||||
disableHIPRT bool,
|
||||
onGPUError func(),
|
||||
) *Context {
|
||||
if frameEnd < frameStart {
|
||||
@@ -101,10 +102,11 @@ func NewContext(
|
||||
Encoder: encoder,
|
||||
Processes: processes,
|
||||
GPULockedOut: gpuLockedOut,
|
||||
HasHIP: hasHIP,
|
||||
HasAMD: hasAMD,
|
||||
HasNVIDIA: hasNVIDIA,
|
||||
HasIntel: hasIntel,
|
||||
GPUDetectionFailed: gpuDetectionFailed,
|
||||
ForceCPURendering: forceCPURendering,
|
||||
DisableHIPRT: disableHIPRT,
|
||||
OnGPUError: onGPUError,
|
||||
}
|
||||
}
|
||||
@@ -187,8 +189,7 @@ func (c *Context) ShouldEnableExecution() bool {
|
||||
}
|
||||
|
||||
// ShouldForceCPU returns true if GPU should be disabled and CPU rendering forced
|
||||
// (runner GPU lockout, GPU detection failed at startup for any version, metadata force_cpu,
|
||||
// or Blender < 4.x when the runner has HIP).
|
||||
// (runner GPU lockout, GPU detection failed at startup, or metadata force_cpu).
|
||||
func (c *Context) ShouldForceCPU() bool {
|
||||
if c.ForceCPURendering {
|
||||
return true
|
||||
@@ -196,17 +197,10 @@ func (c *Context) ShouldForceCPU() bool {
|
||||
if c.GPULockedOut {
|
||||
return true
|
||||
}
|
||||
// Detection failed at startup: we could not determine HIP vs NVIDIA, so force CPU for all versions.
|
||||
// Detection failed at startup: backend availability unknown, so force CPU for all versions.
|
||||
if c.GPUDetectionFailed {
|
||||
return true
|
||||
}
|
||||
v := c.GetBlenderVersion()
|
||||
major := parseBlenderMajor(v)
|
||||
isPre4 := v != "" && major >= 0 && major < 4
|
||||
// Blender < 4.x: force CPU when runner has HIP (no official HIP support).
|
||||
if isPre4 && c.HasHIP {
|
||||
return true
|
||||
}
|
||||
if c.Metadata != nil && c.Metadata.RenderSettings.EngineSettings != nil {
|
||||
if v, ok := c.Metadata.RenderSettings.EngineSettings["force_cpu"]; ok {
|
||||
if b, ok := v.(bool); ok && b {
|
||||
@@ -217,21 +211,6 @@ func (c *Context) ShouldForceCPU() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// parseBlenderMajor extracts the leading major component of a dotted version
// string such as "4.2.3" or "3.6". Surrounding whitespace is ignored.
// It returns -1 when the input is empty or the major component is not an integer.
func parseBlenderMajor(version string) int {
	trimmed := strings.TrimSpace(version)
	if len(trimmed) == 0 {
		return -1
	}

	// Everything before the first dot (or the whole string when there is none).
	head := trimmed
	if dot := strings.Index(trimmed, "."); dot >= 0 {
		head = trimmed[:dot]
	}

	if n, err := strconv.Atoi(head); err == nil {
		return n
	}
	return -1
}
|
||||
|
||||
// IsJobCancelled checks whether the manager marked this job as cancelled.
|
||||
func (c *Context) IsJobCancelled() (bool, error) {
|
||||
if c.Manager == nil {
|
||||
|
||||
@@ -104,15 +104,10 @@ func (p *RenderProcessor) Process(ctx *Context) error {
|
||||
renderFormat := "EXR"
|
||||
|
||||
if ctx.ShouldForceCPU() {
|
||||
v := ctx.GetBlenderVersion()
|
||||
major := parseBlenderMajor(v)
|
||||
isPre4 := v != "" && major >= 0 && major < 4
|
||||
if ctx.ForceCPURendering {
|
||||
ctx.Info("Runner compatibility flag is enabled: forcing CPU rendering for this job")
|
||||
} else if ctx.GPUDetectionFailed {
|
||||
ctx.Info("GPU backend detection failed at startup—we could not determine whether this machine has HIP (AMD) or NVIDIA GPUs, so rendering will use CPU to avoid compatibility issues")
|
||||
} else if isPre4 && ctx.HasHIP {
|
||||
ctx.Info("Blender < 4.x has no official HIP support: using CPU rendering only")
|
||||
ctx.Info("GPU backend detection failed at startup—we could not determine available GPU backends, so rendering will use CPU to avoid compatibility issues")
|
||||
} else {
|
||||
ctx.Info("GPU lockout active: using CPU rendering only")
|
||||
}
|
||||
@@ -195,7 +190,6 @@ func (p *RenderProcessor) createRenderScript(ctx *Context, renderFormat string)
|
||||
settingsMap = make(map[string]interface{})
|
||||
}
|
||||
settingsMap["force_cpu"] = ctx.ShouldForceCPU()
|
||||
settingsMap["disable_hiprt"] = ctx.DisableHIPRT
|
||||
settingsJSON, err := json.Marshal(settingsMap)
|
||||
if err == nil {
|
||||
if err := os.WriteFile(renderSettingsFilePath, settingsJSON, 0644); err != nil {
|
||||
@@ -277,6 +271,9 @@ func (p *RenderProcessor) runBlender(ctx *Context, blenderBinary, blendFile, out
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
log.Printf("Error reading stdout: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Stream stderr and watch for GPU error lines
|
||||
@@ -297,6 +294,9 @@ func (p *RenderProcessor) runBlender(ctx *Context, blenderBinary, blendFile, out
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
log.Printf("Error reading stderr: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for completion
|
||||
|
||||
@@ -99,6 +99,11 @@ func ExtractTarStripPrefix(reader io.Reader, destDir string) error {
|
||||
|
||||
targetPath := filepath.Join(destDir, name)
|
||||
|
||||
// Sanitize path to prevent directory traversal
|
||||
if !strings.HasPrefix(filepath.Clean(targetPath), filepath.Clean(destDir)+string(os.PathSeparator)) {
|
||||
return fmt.Errorf("invalid file path in tar: %s", header.Name)
|
||||
}
|
||||
|
||||
switch header.Typeflag {
|
||||
case tar.TypeDir:
|
||||
if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil {
|
||||
|
||||
101
internal/runner/workspace/archive_test.go
Normal file
101
internal/runner/workspace/archive_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package workspace
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bytes"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// createTarBuffer builds an in-memory tar archive containing one entry per
// element of files (entry name -> entry content) and returns the buffer.
//
// Entries are written in map iteration order, which Go leaves unspecified.
// Every entry uses mode 0644. Any failure while writing the archive panics:
// this helper has no *testing.T, and a silently truncated fixture would
// otherwise show up later as a misleading extraction failure.
func createTarBuffer(files map[string]string) *bytes.Buffer {
	var buf bytes.Buffer
	tw := tar.NewWriter(&buf)
	for name, content := range files {
		hdr := &tar.Header{
			Name: name,
			Mode: 0644,
			Size: int64(len(content)),
		}
		if err := tw.WriteHeader(hdr); err != nil {
			panic("createTarBuffer: write header for " + name + ": " + err.Error())
		}
		if _, err := tw.Write([]byte(content)); err != nil {
			panic("createTarBuffer: write content for " + name + ": " + err.Error())
		}
	}
	// Close flushes padding and writes the end-of-archive trailer.
	if err := tw.Close(); err != nil {
		panic("createTarBuffer: close archive: " + err.Error())
	}
	return &buf
}
|
||||
|
||||
func TestExtractTar(t *testing.T) {
|
||||
destDir := t.TempDir()
|
||||
|
||||
buf := createTarBuffer(map[string]string{
|
||||
"hello.txt": "world",
|
||||
"sub/a.txt": "nested",
|
||||
})
|
||||
|
||||
if err := ExtractTar(buf, destDir); err != nil {
|
||||
t.Fatalf("ExtractTar: %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(destDir, "hello.txt"))
|
||||
if err != nil {
|
||||
t.Fatalf("read hello.txt: %v", err)
|
||||
}
|
||||
if string(data) != "world" {
|
||||
t.Errorf("hello.txt = %q, want %q", data, "world")
|
||||
}
|
||||
|
||||
data, err = os.ReadFile(filepath.Join(destDir, "sub", "a.txt"))
|
||||
if err != nil {
|
||||
t.Fatalf("read sub/a.txt: %v", err)
|
||||
}
|
||||
if string(data) != "nested" {
|
||||
t.Errorf("sub/a.txt = %q, want %q", data, "nested")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTarStripPrefix(t *testing.T) {
|
||||
destDir := t.TempDir()
|
||||
|
||||
buf := createTarBuffer(map[string]string{
|
||||
"toplevel/": "",
|
||||
"toplevel/foo.txt": "bar",
|
||||
})
|
||||
|
||||
if err := ExtractTarStripPrefix(buf, destDir); err != nil {
|
||||
t.Fatalf("ExtractTarStripPrefix: %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(destDir, "foo.txt"))
|
||||
if err != nil {
|
||||
t.Fatalf("read foo.txt: %v", err)
|
||||
}
|
||||
if string(data) != "bar" {
|
||||
t.Errorf("foo.txt = %q, want %q", data, "bar")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTarStripPrefix_PathTraversal(t *testing.T) {
|
||||
destDir := t.TempDir()
|
||||
|
||||
buf := createTarBuffer(map[string]string{
|
||||
"prefix/../../../etc/passwd": "pwned",
|
||||
})
|
||||
|
||||
err := ExtractTarStripPrefix(buf, destDir)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for path traversal, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTar_PathTraversal(t *testing.T) {
|
||||
destDir := t.TempDir()
|
||||
|
||||
buf := createTarBuffer(map[string]string{
|
||||
"../../../etc/passwd": "pwned",
|
||||
})
|
||||
|
||||
err := ExtractTar(buf, destDir)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for path traversal, got nil")
|
||||
}
|
||||
}
|
||||
@@ -82,6 +82,9 @@ func (s *Storage) JobPath(jobID int64) string {
|
||||
|
||||
// SaveUpload saves an uploaded file
|
||||
func (s *Storage) SaveUpload(jobID int64, filename string, reader io.Reader) (string, error) {
|
||||
// Sanitize filename to prevent path traversal
|
||||
filename = filepath.Base(filename)
|
||||
|
||||
jobPath := s.JobPath(jobID)
|
||||
if err := os.MkdirAll(jobPath, 0755); err != nil {
|
||||
return "", fmt.Errorf("failed to create job directory: %w", err)
|
||||
|
||||
95
internal/storage/storage_test.go
Normal file
95
internal/storage/storage_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func setupStorage(t *testing.T) *Storage {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
s, err := NewStorage(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewStorage: %v", err)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func TestSaveUpload(t *testing.T) {
|
||||
s := setupStorage(t)
|
||||
path, err := s.SaveUpload(1, "test.blend", strings.NewReader("data"))
|
||||
if err != nil {
|
||||
t.Fatalf("SaveUpload: %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("read saved file: %v", err)
|
||||
}
|
||||
if string(data) != "data" {
|
||||
t.Errorf("got %q, want %q", data, "data")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveUpload_PathTraversal(t *testing.T) {
|
||||
s := setupStorage(t)
|
||||
path, err := s.SaveUpload(1, "../../etc/passwd", strings.NewReader("evil"))
|
||||
if err != nil {
|
||||
t.Fatalf("SaveUpload: %v", err)
|
||||
}
|
||||
|
||||
// filepath.Base strips traversal, so the file should be inside the job dir
|
||||
if !strings.HasPrefix(path, s.JobPath(1)) {
|
||||
t.Errorf("saved file %q escaped job directory %q", path, s.JobPath(1))
|
||||
}
|
||||
|
||||
if filepath.Base(path) != "passwd" {
|
||||
t.Errorf("expected basename 'passwd', got %q", filepath.Base(path))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveOutput(t *testing.T) {
|
||||
s := setupStorage(t)
|
||||
path, err := s.SaveOutput(42, "output.png", strings.NewReader("img"))
|
||||
if err != nil {
|
||||
t.Fatalf("SaveOutput: %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("read saved output: %v", err)
|
||||
}
|
||||
if string(data) != "img" {
|
||||
t.Errorf("got %q, want %q", data, "img")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFile(t *testing.T) {
|
||||
s := setupStorage(t)
|
||||
savedPath, err := s.SaveUpload(1, "readme.txt", strings.NewReader("hello"))
|
||||
if err != nil {
|
||||
t.Fatalf("SaveUpload: %v", err)
|
||||
}
|
||||
|
||||
f, err := s.GetFile(savedPath)
|
||||
if err != nil {
|
||||
t.Fatalf("GetFile: %v", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
buf := make([]byte, 64)
|
||||
n, _ := f.Read(buf)
|
||||
if string(buf[:n]) != "hello" {
|
||||
t.Errorf("got %q, want %q", string(buf[:n]), "hello")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJobPath(t *testing.T) {
|
||||
s := setupStorage(t)
|
||||
path := s.JobPath(99)
|
||||
if !strings.Contains(path, "99") {
|
||||
t.Errorf("JobPath(99) = %q, expected to contain '99'", path)
|
||||
}
|
||||
}
|
||||
123
pkg/blendfile/version.go
Normal file
123
pkg/blendfile/version.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package blendfile
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
// ParseVersionFromReader parses the Blender version from a reader.
|
||||
// Returns major and minor version numbers.
|
||||
//
|
||||
// Blend file header layout (12 bytes):
|
||||
//
|
||||
// "BLENDER" (7) + pointer-size (1: '-'=64, '_'=32) + endian (1: 'v'=LE, 'V'=BE)
|
||||
// + version (3 digits, e.g. "402" = 4.02)
|
||||
//
|
||||
// Supports uncompressed, gzip-compressed, and zstd-compressed blend files.
|
||||
func ParseVersionFromReader(r io.ReadSeeker) (major, minor int, err error) {
|
||||
header := make([]byte, 12)
|
||||
n, err := r.Read(header)
|
||||
if err != nil || n < 12 {
|
||||
return 0, 0, fmt.Errorf("failed to read blend file header: %w", err)
|
||||
}
|
||||
|
||||
if string(header[:7]) != "BLENDER" {
|
||||
r.Seek(0, 0)
|
||||
return parseCompressedVersion(r)
|
||||
}
|
||||
|
||||
return parseVersionDigits(header[9:12])
|
||||
}
|
||||
|
||||
// ParseVersionFromFile opens a blend file and parses the Blender version.
|
||||
func ParseVersionFromFile(blendPath string) (major, minor int, err error) {
|
||||
file, err := os.Open(blendPath)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to open blend file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
return ParseVersionFromReader(file)
|
||||
}
|
||||
|
||||
// VersionString joins major and minor with a dot, e.g. (4, 2) -> "4.2".
// The minor component is not zero-padded.
func VersionString(major, minor int) string {
	return fmt.Sprint(major) + "." + fmt.Sprint(minor)
}
|
||||
|
||||
// parseVersionDigits decodes the three ASCII version digits of a blend file
// header: the first digit is the major version, the remaining two are the
// minor version (e.g. "402" -> 4, 2 and "279" -> 2, 79).
//
// It returns an error when versionBytes is not exactly three bytes long or
// contains anything other than ASCII digits, instead of silently reporting
// a zero version for malformed headers.
func parseVersionDigits(versionBytes []byte) (major, minor int, err error) {
	if len(versionBytes) != 3 {
		return 0, 0, fmt.Errorf("expected 3 version digits, got %d", len(versionBytes))
	}
	for _, b := range versionBytes {
		if b < '0' || b > '9' {
			return 0, 0, fmt.Errorf("invalid version digits %q", versionBytes)
		}
	}
	major = int(versionBytes[0] - '0')
	minor = int(versionBytes[1]-'0')*10 + int(versionBytes[2]-'0')
	return major, minor, nil
}
|
||||
|
||||
func parseCompressedVersion(r io.ReadSeeker) (major, minor int, err error) {
|
||||
magic := make([]byte, 4)
|
||||
if _, err := r.Read(magic); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
r.Seek(0, 0)
|
||||
|
||||
// gzip: 0x1f 0x8b
|
||||
if magic[0] == 0x1f && magic[1] == 0x8b {
|
||||
gzReader, err := gzip.NewReader(r)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to create gzip reader: %w", err)
|
||||
}
|
||||
defer gzReader.Close()
|
||||
|
||||
header := make([]byte, 12)
|
||||
n, err := gzReader.Read(header)
|
||||
if err != nil || n < 12 {
|
||||
return 0, 0, fmt.Errorf("failed to read compressed blend header: %w", err)
|
||||
}
|
||||
if string(header[:7]) != "BLENDER" {
|
||||
return 0, 0, fmt.Errorf("invalid blend file format")
|
||||
}
|
||||
return parseVersionDigits(header[9:12])
|
||||
}
|
||||
|
||||
// zstd: 0x28 0xB5 0x2F 0xFD
|
||||
if magic[0] == 0x28 && magic[1] == 0xb5 && magic[2] == 0x2f && magic[3] == 0xfd {
|
||||
return parseZstdVersion(r)
|
||||
}
|
||||
|
||||
return 0, 0, fmt.Errorf("unknown blend file format")
|
||||
}
|
||||
|
||||
func parseZstdVersion(r io.ReadSeeker) (major, minor int, err error) {
|
||||
r.Seek(0, 0)
|
||||
|
||||
cmd := exec.Command("zstd", "-d", "-c")
|
||||
cmd.Stdin = r
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to create zstd stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to start zstd decompression: %w", err)
|
||||
}
|
||||
|
||||
header := make([]byte, 12)
|
||||
n, readErr := io.ReadFull(stdout, header)
|
||||
|
||||
cmd.Process.Kill()
|
||||
cmd.Wait()
|
||||
|
||||
if readErr != nil || n < 12 {
|
||||
return 0, 0, fmt.Errorf("failed to read zstd compressed blend header: %v", readErr)
|
||||
}
|
||||
if string(header[:7]) != "BLENDER" {
|
||||
return 0, 0, fmt.Errorf("invalid blend file format in zstd archive")
|
||||
}
|
||||
|
||||
return parseVersionDigits(header[9:12])
|
||||
}
|
||||
96
pkg/blendfile/version_test.go
Normal file
96
pkg/blendfile/version_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package blendfile
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// makeBlendHeader builds a 12-byte uncompressed blend file header for tests:
// "BLENDER", '-' (pointer size), 'v' (endianness), then one major digit and
// two minor digits. major is expected to be 0-9 and minor 0-99.
func makeBlendHeader(major, minor int) []byte {
	header := []byte("BLENDER-v000")
	header[9] = byte('0' + major)
	header[10] = byte('0' + minor/10)
	header[11] = byte('0' + minor%10)
	return header
}
|
||||
|
||||
func TestParseVersionFromReader_Uncompressed(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
major int
|
||||
minor int
|
||||
wantMajor int
|
||||
wantMinor int
|
||||
}{
|
||||
{"Blender 4.02", 4, 2, 4, 2},
|
||||
{"Blender 3.06", 3, 6, 3, 6},
|
||||
{"Blender 2.79", 2, 79, 2, 79},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
header := makeBlendHeader(tt.major, tt.minor)
|
||||
r := bytes.NewReader(header)
|
||||
|
||||
major, minor, err := ParseVersionFromReader(r)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseVersionFromReader: %v", err)
|
||||
}
|
||||
if major != tt.wantMajor || minor != tt.wantMinor {
|
||||
t.Errorf("got %d.%d, want %d.%d", major, minor, tt.wantMajor, tt.wantMinor)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseVersionFromReader_GzipCompressed(t *testing.T) {
|
||||
header := makeBlendHeader(4, 2)
|
||||
// Pad to ensure gzip has enough data for a full read
|
||||
data := make([]byte, 128)
|
||||
copy(data, header)
|
||||
|
||||
var buf bytes.Buffer
|
||||
gz := gzip.NewWriter(&buf)
|
||||
gz.Write(data)
|
||||
gz.Close()
|
||||
|
||||
r := bytes.NewReader(buf.Bytes())
|
||||
major, minor, err := ParseVersionFromReader(r)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseVersionFromReader (gzip): %v", err)
|
||||
}
|
||||
if major != 4 || minor != 2 {
|
||||
t.Errorf("got %d.%d, want 4.2", major, minor)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseVersionFromReader_InvalidMagic(t *testing.T) {
|
||||
data := []byte("NOT_BLEND_DATA_HERE")
|
||||
r := bytes.NewReader(data)
|
||||
|
||||
_, _, err := ParseVersionFromReader(r)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid magic, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseVersionFromReader_TooShort(t *testing.T) {
|
||||
data := []byte("SHORT")
|
||||
r := bytes.NewReader(data)
|
||||
|
||||
_, _, err := ParseVersionFromReader(r)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for short data, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVersionString(t *testing.T) {
|
||||
got := VersionString(4, 2)
|
||||
want := "4.2"
|
||||
if got != want {
|
||||
t.Errorf("VersionString(4, 2) = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
@@ -361,6 +361,9 @@ func RunCommandWithStreaming(
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil && !isBenignPipeReadError(err) {
|
||||
logSender(taskID, types.LogLevelWarn, fmt.Sprintf("stdout read error: %v", err), stepName)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
@@ -375,6 +378,9 @@ func RunCommandWithStreaming(
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil && !isBenignPipeReadError(err) {
|
||||
logSender(taskID, types.LogLevelWarn, fmt.Sprintf("stderr read error: %v", err), stepName)
|
||||
}
|
||||
}()
|
||||
|
||||
err = cmd.Wait()
|
||||
|
||||
@@ -11,6 +11,3 @@ var UnhideObjects string
|
||||
//go:embed scripts/render_blender.py.template
|
||||
var RenderBlenderTemplate string
|
||||
|
||||
//go:embed scripts/detect_gpu_backends.py
|
||||
var DetectGPUBackends string
|
||||
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
# Minimal script to detect HIP (AMD) and NVIDIA (CUDA/OptiX) backends for Cycles.
|
||||
# Run with: blender -b --python detect_gpu_backends.py
|
||||
# Prints HAS_HIP and/or HAS_NVIDIA to stdout, one per line.
|
||||
import sys
|
||||
|
||||
def main():
    """Probe Cycles for HIP and NVIDIA (CUDA/OptiX) devices.

    Switches the Cycles compute device type to HIP, CUDA and OPTIX in turn,
    refreshes the device list, and prints ``HAS_HIP`` and/or ``HAS_NVIDIA``
    to stdout (one per line) when the corresponding backend reports devices.
    Returns silently when the Cycles addon is not present. Failures while
    probing a single backend are ignored; any other failure prints
    ``ERROR <message>`` to stderr and exits with status 1.
    """
    try:
        prefs = bpy.context.preferences
        if not hasattr(prefs, 'addons') or 'cycles' not in prefs.addons:
            return
        cprefs = prefs.addons['cycles'].preferences

        detected = set()
        for device_type in ('HIP', 'CUDA', 'OPTIX'):
            try:
                cprefs.compute_device_type = device_type
                cprefs.refresh_devices()
                if hasattr(cprefs, 'get_devices'):
                    devs = cprefs.get_devices()
                elif hasattr(cprefs, 'devices') and cprefs.devices:
                    devs = list(cprefs.devices) if hasattr(cprefs.devices, '__iter__') else [cprefs.devices]
                else:
                    devs = []
                if devs:
                    # CUDA and OPTIX both indicate an NVIDIA GPU.
                    detected.add('HAS_HIP' if device_type == 'HIP' else 'HAS_NVIDIA')
            except Exception:
                # Best effort: a backend that cannot be selected is treated as absent.
                continue

        # Fixed output order: HIP first, then NVIDIA.
        for marker in ('HAS_HIP', 'HAS_NVIDIA'):
            if marker in detected:
                print(marker, flush=True)
    except Exception as e:
        print('ERROR', str(e), file=sys.stderr, flush=True)
        sys.exit(1)
|
||||
|
||||
import bpy
|
||||
main()
|
||||
@@ -175,13 +175,9 @@ if render_settings_override:
|
||||
if current_engine == 'CYCLES':
|
||||
# Check if CPU rendering is forced
|
||||
force_cpu = False
|
||||
disable_hiprt = False
|
||||
if render_settings_override and render_settings_override.get('force_cpu'):
|
||||
force_cpu = render_settings_override.get('force_cpu', False)
|
||||
print("Force CPU rendering is enabled - skipping GPU detection")
|
||||
if render_settings_override and render_settings_override.get('disable_hiprt'):
|
||||
disable_hiprt = render_settings_override.get('disable_hiprt', False)
|
||||
print("Disable HIPRT flag is enabled")
|
||||
|
||||
# Ensure Cycles addon is enabled
|
||||
try:
|
||||
@@ -213,9 +209,10 @@ if current_engine == 'CYCLES':
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
# Check all devices and choose the best GPU type
|
||||
# Device type preference order (most performant first)
|
||||
device_type_preference = ['OPTIX', 'CUDA', 'HIP', 'ONEAPI', 'METAL']
|
||||
# Check all devices and choose the best GPU type.
|
||||
# Explicit fallback policy: NVIDIA -> Intel -> AMD -> CPU.
|
||||
# (OPTIX/CUDA are NVIDIA, ONEAPI is Intel, HIP/OPENCL are AMD)
|
||||
device_type_preference = ['OPTIX', 'CUDA', 'ONEAPI', 'HIP', 'OPENCL']
|
||||
gpu_available = False
|
||||
best_device_type = None
|
||||
best_gpu_devices = []
|
||||
@@ -325,16 +322,7 @@ if current_engine == 'CYCLES':
|
||||
try:
|
||||
if best_device_type == 'HIP':
|
||||
# HIPRT (HIP Ray Tracing) for AMD GPUs
|
||||
if disable_hiprt:
|
||||
if hasattr(cycles_prefs, 'use_hiprt'):
|
||||
cycles_prefs.use_hiprt = False
|
||||
print(f" Disabled HIPRT (HIP Ray Tracing) via runner compatibility flag")
|
||||
elif hasattr(scene.cycles, 'use_hiprt'):
|
||||
scene.cycles.use_hiprt = False
|
||||
print(f" Disabled HIPRT (HIP Ray Tracing) via runner compatibility flag")
|
||||
else:
|
||||
print(f" HIPRT toggle not available on this Blender version")
|
||||
elif hasattr(cycles_prefs, 'use_hiprt'):
|
||||
cycles_prefs.use_hiprt = True
|
||||
print(f" Enabled HIPRT (HIP Ray Tracing) for faster rendering")
|
||||
elif hasattr(scene.cycles, 'use_hiprt'):
|
||||
@@ -356,16 +344,6 @@ if current_engine == 'CYCLES':
|
||||
scene.cycles.use_optix_denoising = True
|
||||
print(f" Enabled OptiX denoising (if OptiX available)")
|
||||
print(f" CUDA ray tracing active")
|
||||
elif best_device_type == 'METAL':
|
||||
# MetalRT for Apple Silicon (if available)
|
||||
if hasattr(scene.cycles, 'use_metalrt'):
|
||||
scene.cycles.use_metalrt = True
|
||||
print(f" Enabled MetalRT (Metal Ray Tracing) for faster rendering")
|
||||
elif hasattr(cycles_prefs, 'use_metalrt'):
|
||||
cycles_prefs.use_metalrt = True
|
||||
print(f" Enabled MetalRT (Metal Ray Tracing) for faster rendering")
|
||||
else:
|
||||
print(f" MetalRT not available")
|
||||
elif best_device_type == 'ONEAPI':
|
||||
# Intel oneAPI - Embree might be available
|
||||
if hasattr(scene.cycles, 'use_embree'):
|
||||
|
||||
Reference in New Issue
Block a user