Compare commits
4 Commits
5303f01f7c
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| a3defe5cf6 | |||
| 16d6a95058 | |||
| 28cb50492c | |||
| dc525fbaa4 |
14
Makefile
14
Makefile
@@ -27,7 +27,19 @@ cleanup: cleanup-manager cleanup-runner
|
|||||||
run: cleanup build init-test
|
run: cleanup build init-test
|
||||||
@echo "Starting manager and runner in parallel..."
|
@echo "Starting manager and runner in parallel..."
|
||||||
@echo "Press Ctrl+C to stop both..."
|
@echo "Press Ctrl+C to stop both..."
|
||||||
@trap 'kill $$MANAGER_PID $$RUNNER_PID 2>/dev/null; exit' INT TERM; \
|
@MANAGER_PID=""; RUNNER_PID=""; INTERRUPTED=0; \
|
||||||
|
cleanup() { \
|
||||||
|
exit_code=$$?; \
|
||||||
|
trap - INT TERM EXIT; \
|
||||||
|
if [ -n "$$RUNNER_PID" ]; then kill -TERM "$$RUNNER_PID" 2>/dev/null || true; fi; \
|
||||||
|
if [ -n "$$MANAGER_PID" ]; then kill -TERM "$$MANAGER_PID" 2>/dev/null || true; fi; \
|
||||||
|
if [ -n "$$MANAGER_PID$$RUNNER_PID" ]; then wait $$MANAGER_PID $$RUNNER_PID 2>/dev/null || true; fi; \
|
||||||
|
if [ "$$INTERRUPTED" -eq 1 ]; then exit 0; fi; \
|
||||||
|
exit $$exit_code; \
|
||||||
|
}; \
|
||||||
|
on_interrupt() { INTERRUPTED=1; cleanup; }; \
|
||||||
|
trap on_interrupt INT TERM; \
|
||||||
|
trap cleanup EXIT; \
|
||||||
bin/jiggablend manager -l manager.log & \
|
bin/jiggablend manager -l manager.log & \
|
||||||
MANAGER_PID=$$!; \
|
MANAGER_PID=$$!; \
|
||||||
sleep 2; \
|
sleep 2; \
|
||||||
|
|||||||
@@ -154,6 +154,9 @@ bin/jiggablend runner --api-key <your-api-key>
|
|||||||
# With custom options
|
# With custom options
|
||||||
bin/jiggablend runner --manager http://localhost:8080 --name my-runner --api-key <key> --log-file runner.log
|
bin/jiggablend runner --manager http://localhost:8080 --name my-runner --api-key <key> --log-file runner.log
|
||||||
|
|
||||||
|
# Hardware compatibility flag (force CPU)
|
||||||
|
bin/jiggablend runner --api-key <key> --force-cpu-rendering
|
||||||
|
|
||||||
# Using environment variables
|
# Using environment variables
|
||||||
JIGGABLEND_MANAGER=http://localhost:8080 JIGGABLEND_API_KEY=<key> bin/jiggablend runner
|
JIGGABLEND_MANAGER=http://localhost:8080 JIGGABLEND_API_KEY=<key> bin/jiggablend runner
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"jiggablend/internal/auth"
|
"jiggablend/internal/auth"
|
||||||
@@ -151,7 +152,15 @@ func runManager(cmd *cobra.Command, args []string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func checkBlenderAvailable() error {
|
func checkBlenderAvailable() error {
|
||||||
cmd := exec.Command("blender", "--version")
|
blenderPath, err := exec.LookPath("blender")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to locate blender in PATH: %w", err)
|
||||||
|
}
|
||||||
|
blenderPath, err = filepath.Abs(blenderPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to resolve blender path %q: %w", blenderPath, err)
|
||||||
|
}
|
||||||
|
cmd := exec.Command(blenderPath, "--version")
|
||||||
output, err := cmd.CombinedOutput()
|
output, err := cmd.CombinedOutput()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to run 'blender --version': %w (output: %s)", err, string(output))
|
return fmt.Errorf("failed to run 'blender --version': %w (output: %s)", err, string(output))
|
||||||
|
|||||||
23
cmd/jiggablend/cmd/managerconfig_test.go
Normal file
23
cmd/jiggablend/cmd/managerconfig_test.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateAPIKey_Format(t *testing.T) {
|
||||||
|
key, prefix, hash, err := generateAPIKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generateAPIKey failed: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(prefix, "jk_r") {
|
||||||
|
t.Fatalf("unexpected prefix: %q", prefix)
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(key, prefix+"_") {
|
||||||
|
t.Fatalf("key does not include prefix: %q", key)
|
||||||
|
}
|
||||||
|
if len(hash) != 64 {
|
||||||
|
t.Fatalf("expected sha256 hex hash length, got %d", len(hash))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
16
cmd/jiggablend/cmd/root_test.go
Normal file
16
cmd/jiggablend/cmd/root_test.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestRootCommand_HasKeySubcommands(t *testing.T) {
|
||||||
|
names := map[string]bool{}
|
||||||
|
for _, c := range rootCmd.Commands() {
|
||||||
|
names[c.Name()] = true
|
||||||
|
}
|
||||||
|
for _, required := range []string{"manager", "runner", "version"} {
|
||||||
|
if !names[required] {
|
||||||
|
t.Fatalf("expected subcommand %q to be registered", required)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -37,6 +37,7 @@ func init() {
|
|||||||
runnerCmd.Flags().String("log-level", "info", "Log level (debug, info, warn, error)")
|
runnerCmd.Flags().String("log-level", "info", "Log level (debug, info, warn, error)")
|
||||||
runnerCmd.Flags().BoolP("verbose", "v", false, "Enable verbose logging (same as --log-level=debug)")
|
runnerCmd.Flags().BoolP("verbose", "v", false, "Enable verbose logging (same as --log-level=debug)")
|
||||||
runnerCmd.Flags().Duration("poll-interval", 5*time.Second, "Job polling interval")
|
runnerCmd.Flags().Duration("poll-interval", 5*time.Second, "Job polling interval")
|
||||||
|
runnerCmd.Flags().Bool("force-cpu-rendering", false, "Force CPU rendering for all jobs (disables GPU rendering)")
|
||||||
|
|
||||||
// Bind flags to viper with JIGGABLEND_ prefix
|
// Bind flags to viper with JIGGABLEND_ prefix
|
||||||
runnerViper.SetEnvPrefix("JIGGABLEND")
|
runnerViper.SetEnvPrefix("JIGGABLEND")
|
||||||
@@ -51,6 +52,7 @@ func init() {
|
|||||||
runnerViper.BindPFlag("log_level", runnerCmd.Flags().Lookup("log-level"))
|
runnerViper.BindPFlag("log_level", runnerCmd.Flags().Lookup("log-level"))
|
||||||
runnerViper.BindPFlag("verbose", runnerCmd.Flags().Lookup("verbose"))
|
runnerViper.BindPFlag("verbose", runnerCmd.Flags().Lookup("verbose"))
|
||||||
runnerViper.BindPFlag("poll_interval", runnerCmd.Flags().Lookup("poll-interval"))
|
runnerViper.BindPFlag("poll_interval", runnerCmd.Flags().Lookup("poll-interval"))
|
||||||
|
runnerViper.BindPFlag("force_cpu_rendering", runnerCmd.Flags().Lookup("force-cpu-rendering"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func runRunner(cmd *cobra.Command, args []string) {
|
func runRunner(cmd *cobra.Command, args []string) {
|
||||||
@@ -63,6 +65,7 @@ func runRunner(cmd *cobra.Command, args []string) {
|
|||||||
logLevel := runnerViper.GetString("log_level")
|
logLevel := runnerViper.GetString("log_level")
|
||||||
verbose := runnerViper.GetBool("verbose")
|
verbose := runnerViper.GetBool("verbose")
|
||||||
pollInterval := runnerViper.GetDuration("poll_interval")
|
pollInterval := runnerViper.GetDuration("poll_interval")
|
||||||
|
forceCPURendering := runnerViper.GetBool("force_cpu_rendering")
|
||||||
|
|
||||||
var r *runner.Runner
|
var r *runner.Runner
|
||||||
|
|
||||||
@@ -118,7 +121,7 @@ func runRunner(cmd *cobra.Command, args []string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create runner
|
// Create runner
|
||||||
r = runner.New(managerURL, name, hostname)
|
r = runner.New(managerURL, name, hostname, forceCPURendering)
|
||||||
|
|
||||||
// Check for required tools early to fail fast
|
// Check for required tools early to fail fast
|
||||||
if err := r.CheckRequiredTools(); err != nil {
|
if err := r.CheckRequiredTools(); err != nil {
|
||||||
@@ -161,8 +164,8 @@ func runRunner(cmd *cobra.Command, args []string) {
|
|||||||
runnerID, err = r.Register(apiKey)
|
runnerID, err = r.Register(apiKey)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
logger.Infof("Registered runner with ID: %d", runnerID)
|
logger.Infof("Registered runner with ID: %d", runnerID)
|
||||||
// Download latest Blender and detect HIP vs NVIDIA so we only force CPU for Blender < 4.x when using HIP
|
// Detect GPU vendors/backends from host hardware so we only force CPU for Blender < 4.x when using AMD.
|
||||||
logger.Info("Detecting GPU backends (HIP/NVIDIA) for Blender < 4.x policy...")
|
logger.Info("Detecting GPU backends (AMD/NVIDIA/Intel) from host hardware for Blender < 4.x policy...")
|
||||||
r.DetectAndStoreGPUBackends()
|
r.DetectAndStoreGPUBackends()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|||||||
17
cmd/jiggablend/cmd/runner_test.go
Normal file
17
cmd/jiggablend/cmd/runner_test.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateShortID_IsHex8Bytes(t *testing.T) {
|
||||||
|
id := generateShortID()
|
||||||
|
if len(id) != 8 {
|
||||||
|
t.Fatalf("expected 8 hex chars, got %q", id)
|
||||||
|
}
|
||||||
|
if _, err := hex.DecodeString(id); err != nil {
|
||||||
|
t.Fatalf("id should be hex: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
13
cmd/jiggablend/cmd/version_test.go
Normal file
13
cmd/jiggablend/cmd/version_test.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestVersionCommand_Metadata(t *testing.T) {
|
||||||
|
if versionCmd.Use != "version" {
|
||||||
|
t.Fatalf("unexpected command use: %q", versionCmd.Use)
|
||||||
|
}
|
||||||
|
if versionCmd.Run == nil {
|
||||||
|
t.Fatal("version command run function should be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
8
cmd/jiggablend/main_test.go
Normal file
8
cmd/jiggablend/main_test.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestMainPackage_Builds(t *testing.T) {
|
||||||
|
// Smoke test placeholder to keep package main under test compilation.
|
||||||
|
}
|
||||||
|
|
||||||
15
installer.sh
15
installer.sh
@@ -79,17 +79,23 @@ cat << 'EOF' > jiggablend-runner.sh
|
|||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
|
|
||||||
# Wrapper to run jiggablend runner with test setup
|
# Wrapper to run jiggablend runner with test setup
|
||||||
# Usage: jiggablend-runner [MANAGER_URL]
|
# Usage: jiggablend-runner [MANAGER_URL] [RUNNER_FLAGS...]
|
||||||
# Default MANAGER_URL: http://localhost:8080
|
# Default MANAGER_URL: http://localhost:8080
|
||||||
# Run this in a directory where you want the logs
|
# Run this in a directory where you want the logs
|
||||||
|
|
||||||
MANAGER_URL="${1:-http://localhost:8080}"
|
MANAGER_URL="http://localhost:8080"
|
||||||
|
if [[ $# -gt 0 && "$1" != -* ]]; then
|
||||||
|
MANAGER_URL="$1"
|
||||||
|
shift
|
||||||
|
fi
|
||||||
|
|
||||||
|
EXTRA_ARGS=("$@")
|
||||||
|
|
||||||
mkdir -p logs
|
mkdir -p logs
|
||||||
rm -f logs/runner.log
|
rm -f logs/runner.log
|
||||||
|
|
||||||
# Run runner
|
# Run runner
|
||||||
jiggablend runner -l logs/runner.log --api-key=jk_r0_test_key_123456789012345678901234567890 --manager "$MANAGER_URL"
|
jiggablend runner -l logs/runner.log --api-key=jk_r0_test_key_123456789012345678901234567890 --manager "$MANAGER_URL" "${EXTRA_ARGS[@]}"
|
||||||
EOF
|
EOF
|
||||||
chmod +x jiggablend-runner.sh
|
chmod +x jiggablend-runner.sh
|
||||||
sudo install -m 0755 jiggablend-runner.sh /usr/local/bin/jiggablend-runner
|
sudo install -m 0755 jiggablend-runner.sh /usr/local/bin/jiggablend-runner
|
||||||
@@ -102,5 +108,6 @@ echo "Installation complete!"
|
|||||||
echo "Binary: jiggablend"
|
echo "Binary: jiggablend"
|
||||||
echo "Wrappers: jiggablend-manager, jiggablend-runner"
|
echo "Wrappers: jiggablend-manager, jiggablend-runner"
|
||||||
echo "Run 'jiggablend-manager' to start the manager with test config."
|
echo "Run 'jiggablend-manager' to start the manager with test config."
|
||||||
echo "Run 'jiggablend-runner [url]' to start the runner, e.g., jiggablend-runner http://your-manager:8080"
|
echo "Run 'jiggablend-runner [url] [runner flags...]' to start the runner."
|
||||||
|
echo "Example: jiggablend-runner http://your-manager:8080 --force-cpu-rendering"
|
||||||
echo "Note: Depending on whether you're running the manager or runner, additional dependencies like Blender, ImageMagick, or FFmpeg may be required. See the project README for details."
|
echo "Note: Depending on whether you're running the manager or runner, additional dependencies like Blender, ImageMagick, or FFmpeg may be required. See the project README for details."
|
||||||
@@ -668,24 +668,42 @@ func (a *Auth) IsProductionModeFromConfig() bool {
|
|||||||
return a.cfg.IsProductionMode()
|
return a.cfg.IsProductionMode()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Auth) writeUnauthorized(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Keep API behavior unchanged for programmatic clients.
|
||||||
|
if strings.HasPrefix(r.URL.Path, "/api/") {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// For HTMX UI fragment requests, trigger a full-page redirect to login.
|
||||||
|
if strings.EqualFold(r.Header.Get("HX-Request"), "true") {
|
||||||
|
w.Header().Set("HX-Redirect", "/login")
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// For normal browser page requests, redirect to login page.
|
||||||
|
http.Redirect(w, r, "/login", http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
// Middleware creates an authentication middleware
|
// Middleware creates an authentication middleware
|
||||||
func (a *Auth) Middleware(next http.HandlerFunc) http.HandlerFunc {
|
func (a *Auth) Middleware(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
cookie, err := r.Cookie("session_id")
|
cookie, err := r.Cookie("session_id")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Authentication failed: missing session cookie for %s %s", r.Method, r.URL.Path)
|
log.Printf("Authentication failed: missing session cookie for %s %s", r.Method, r.URL.Path)
|
||||||
w.Header().Set("Content-Type", "application/json")
|
a.writeUnauthorized(w, r)
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
session, ok := a.GetSession(cookie.Value)
|
session, ok := a.GetSession(cookie.Value)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Printf("Authentication failed: invalid session cookie for %s %s", r.Method, r.URL.Path)
|
log.Printf("Authentication failed: invalid session cookie for %s %s", r.Method, r.URL.Path)
|
||||||
w.Header().Set("Content-Type", "application/json")
|
a.writeUnauthorized(w, r)
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -717,18 +735,14 @@ func (a *Auth) AdminMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
|||||||
cookie, err := r.Cookie("session_id")
|
cookie, err := r.Cookie("session_id")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Admin authentication failed: missing session cookie for %s %s", r.Method, r.URL.Path)
|
log.Printf("Admin authentication failed: missing session cookie for %s %s", r.Method, r.URL.Path)
|
||||||
w.Header().Set("Content-Type", "application/json")
|
a.writeUnauthorized(w, r)
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
session, ok := a.GetSession(cookie.Value)
|
session, ok := a.GetSession(cookie.Value)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Printf("Admin authentication failed: invalid session cookie for %s %s", r.Method, r.URL.Path)
|
log.Printf("Admin authentication failed: invalid session cookie for %s %s", r.Method, r.URL.Path)
|
||||||
w.Header().Set("Content-Type", "application/json")
|
a.writeUnauthorized(w, r)
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
56
internal/auth/auth_test.go
Normal file
56
internal/auth/auth_test.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestContextHelpers(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
ctx = context.WithValue(ctx, contextKeyUserID, int64(123))
|
||||||
|
ctx = context.WithValue(ctx, contextKeyIsAdmin, true)
|
||||||
|
|
||||||
|
id, ok := GetUserID(ctx)
|
||||||
|
if !ok || id != 123 {
|
||||||
|
t.Fatalf("GetUserID() = (%d,%v), want (123,true)", id, ok)
|
||||||
|
}
|
||||||
|
if !IsAdmin(ctx) {
|
||||||
|
t.Fatal("expected IsAdmin to be true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsProductionMode_UsesEnv(t *testing.T) {
|
||||||
|
t.Setenv("PRODUCTION", "true")
|
||||||
|
if !IsProductionMode() {
|
||||||
|
t.Fatal("expected production mode true when env is set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteUnauthorized_BehaviorByRequestType(t *testing.T) {
|
||||||
|
a := &Auth{}
|
||||||
|
|
||||||
|
reqAPI := httptest.NewRequest(http.MethodGet, "/api/jobs", nil)
|
||||||
|
rrAPI := httptest.NewRecorder()
|
||||||
|
a.writeUnauthorized(rrAPI, reqAPI)
|
||||||
|
if rrAPI.Code != http.StatusUnauthorized {
|
||||||
|
t.Fatalf("api code = %d", rrAPI.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
reqPage := httptest.NewRequest(http.MethodGet, "/dashboard", nil)
|
||||||
|
rrPage := httptest.NewRecorder()
|
||||||
|
a.writeUnauthorized(rrPage, reqPage)
|
||||||
|
if rrPage.Code != http.StatusFound {
|
||||||
|
t.Fatalf("page code = %d", rrPage.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsProductionMode_DefaultFalse(t *testing.T) {
|
||||||
|
_ = os.Unsetenv("PRODUCTION")
|
||||||
|
if IsProductionMode() {
|
||||||
|
t.Fatal("expected false when PRODUCTION is unset")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
84
internal/auth/jobtoken_test.go
Normal file
84
internal/auth/jobtoken_test.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateAndValidateJobToken_RoundTrip(t *testing.T) {
|
||||||
|
token, err := GenerateJobToken(10, 20, 30)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateJobToken failed: %v", err)
|
||||||
|
}
|
||||||
|
claims, err := ValidateJobToken(token)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ValidateJobToken failed: %v", err)
|
||||||
|
}
|
||||||
|
if claims.JobID != 10 || claims.RunnerID != 20 || claims.TaskID != 30 {
|
||||||
|
t.Fatalf("unexpected claims: %+v", claims)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateJobToken_RejectsTampering(t *testing.T) {
|
||||||
|
token, err := GenerateJobToken(1, 2, 3)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateJobToken failed: %v", err)
|
||||||
|
}
|
||||||
|
parts := strings.Split(token, ".")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
t.Fatalf("unexpected token format: %q", token)
|
||||||
|
}
|
||||||
|
|
||||||
|
rawClaims, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode claims failed: %v", err)
|
||||||
|
}
|
||||||
|
var claims JobTokenClaims
|
||||||
|
if err := json.Unmarshal(rawClaims, &claims); err != nil {
|
||||||
|
t.Fatalf("unmarshal claims failed: %v", err)
|
||||||
|
}
|
||||||
|
claims.JobID = 999
|
||||||
|
tamperedClaims, _ := json.Marshal(claims)
|
||||||
|
tampered := base64.RawURLEncoding.EncodeToString(tamperedClaims) + "." + parts[1]
|
||||||
|
|
||||||
|
if _, err := ValidateJobToken(tampered); err == nil {
|
||||||
|
t.Fatal("expected signature validation error for tampered token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateJobToken_RejectsExpired(t *testing.T) {
|
||||||
|
expiredClaims := JobTokenClaims{
|
||||||
|
JobID: 1,
|
||||||
|
RunnerID: 2,
|
||||||
|
TaskID: 3,
|
||||||
|
Exp: time.Now().Add(-time.Minute).Unix(),
|
||||||
|
}
|
||||||
|
claimsJSON, _ := json.Marshal(expiredClaims)
|
||||||
|
sigToken, err := GenerateJobToken(1, 2, 3)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateJobToken failed: %v", err)
|
||||||
|
}
|
||||||
|
parts := strings.Split(sigToken, ".")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
t.Fatalf("unexpected token format: %q", sigToken)
|
||||||
|
}
|
||||||
|
// Re-sign expired payload with package secret.
|
||||||
|
h := signClaimsForTest(claimsJSON)
|
||||||
|
expiredToken := base64.RawURLEncoding.EncodeToString(claimsJSON) + "." + base64.RawURLEncoding.EncodeToString(h)
|
||||||
|
|
||||||
|
if _, err := ValidateJobToken(expiredToken); err == nil {
|
||||||
|
t.Fatal("expected token expiration error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func signClaimsForTest(claims []byte) []byte {
|
||||||
|
h := hmac.New(sha256.New, jobTokenSecret)
|
||||||
|
_, _ = h.Write(claims)
|
||||||
|
return h.Sum(nil)
|
||||||
|
}
|
||||||
|
|
||||||
32
internal/auth/secrets_test.go
Normal file
32
internal/auth/secrets_test.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateSecret_Length(t *testing.T) {
|
||||||
|
secret, err := generateSecret(8)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generateSecret failed: %v", err)
|
||||||
|
}
|
||||||
|
// hex encoding doubles length
|
||||||
|
if len(secret) != 16 {
|
||||||
|
t.Fatalf("unexpected secret length: %d", len(secret))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateAPIKey_Format(t *testing.T) {
|
||||||
|
s := &Secrets{}
|
||||||
|
key, err := s.generateAPIKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generateAPIKey failed: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(key, "jk_r") {
|
||||||
|
t.Fatalf("unexpected key prefix: %q", key)
|
||||||
|
}
|
||||||
|
if !strings.Contains(key, "_") {
|
||||||
|
t.Fatalf("unexpected key format: %q", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -23,6 +23,14 @@ const (
|
|||||||
KeyProductionMode = "production_mode"
|
KeyProductionMode = "production_mode"
|
||||||
KeyAllowedOrigins = "allowed_origins"
|
KeyAllowedOrigins = "allowed_origins"
|
||||||
KeyFramesPerRenderTask = "frames_per_render_task"
|
KeyFramesPerRenderTask = "frames_per_render_task"
|
||||||
|
|
||||||
|
// Operational limits (seconds / bytes / counts)
|
||||||
|
KeyRenderTimeoutSecs = "render_timeout_seconds"
|
||||||
|
KeyEncodeTimeoutSecs = "encode_timeout_seconds"
|
||||||
|
KeyMaxUploadBytes = "max_upload_bytes"
|
||||||
|
KeySessionCookieMaxAge = "session_cookie_max_age"
|
||||||
|
KeyAPIRateLimit = "api_rate_limit"
|
||||||
|
KeyAuthRateLimit = "auth_rate_limit"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config manages application configuration stored in the database
|
// Config manages application configuration stored in the database
|
||||||
@@ -311,3 +319,34 @@ func (c *Config) GetFramesPerRenderTask() int {
|
|||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RenderTimeoutSeconds returns the per-frame render timeout in seconds (default 3600 = 1 hour).
|
||||||
|
func (c *Config) RenderTimeoutSeconds() int {
|
||||||
|
return c.GetIntWithDefault(KeyRenderTimeoutSecs, 3600)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeTimeoutSeconds returns the video encode timeout in seconds (default 86400 = 24 hours).
|
||||||
|
func (c *Config) EncodeTimeoutSeconds() int {
|
||||||
|
return c.GetIntWithDefault(KeyEncodeTimeoutSecs, 86400)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaxUploadBytes returns the maximum upload size in bytes (default 50 GB).
|
||||||
|
func (c *Config) MaxUploadBytes() int64 {
|
||||||
|
v := c.GetIntWithDefault(KeyMaxUploadBytes, 50<<30)
|
||||||
|
return int64(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionCookieMaxAgeSec returns the session cookie max-age in seconds (default 86400 = 24 hours).
|
||||||
|
func (c *Config) SessionCookieMaxAgeSec() int {
|
||||||
|
return c.GetIntWithDefault(KeySessionCookieMaxAge, 86400)
|
||||||
|
}
|
||||||
|
|
||||||
|
// APIRateLimitPerMinute returns the API rate limit (requests per minute per IP, default 100).
|
||||||
|
func (c *Config) APIRateLimitPerMinute() int {
|
||||||
|
return c.GetIntWithDefault(KeyAPIRateLimit, 100)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthRateLimitPerMinute returns the auth rate limit (requests per minute per IP, default 10).
|
||||||
|
func (c *Config) AuthRateLimitPerMinute() int {
|
||||||
|
return c.GetIntWithDefault(KeyAuthRateLimit, 10)
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
66
internal/config/config_test.go
Normal file
66
internal/config/config_test.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"jiggablend/internal/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestConfig(t *testing.T) *Config {
|
||||||
|
t.Helper()
|
||||||
|
db, err := database.NewDB(filepath.Join(t.TempDir(), "cfg.db"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewDB failed: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = db.Close() })
|
||||||
|
return NewConfig(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetGetExistsDelete(t *testing.T) {
|
||||||
|
cfg := newTestConfig(t)
|
||||||
|
|
||||||
|
if err := cfg.Set("alpha", "1"); err != nil {
|
||||||
|
t.Fatalf("Set failed: %v", err)
|
||||||
|
}
|
||||||
|
v, err := cfg.Get("alpha")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get failed: %v", err)
|
||||||
|
}
|
||||||
|
if v != "1" {
|
||||||
|
t.Fatalf("unexpected value: %q", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
exists, err := cfg.Exists("alpha")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Exists failed: %v", err)
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
t.Fatal("expected key to exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cfg.Delete("alpha"); err != nil {
|
||||||
|
t.Fatalf("Delete failed: %v", err)
|
||||||
|
}
|
||||||
|
exists, err = cfg.Exists("alpha")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Exists after delete failed: %v", err)
|
||||||
|
}
|
||||||
|
if exists {
|
||||||
|
t.Fatal("expected key to be deleted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetIntWithDefault_AndMinimumFrameTask(t *testing.T) {
|
||||||
|
cfg := newTestConfig(t)
|
||||||
|
if got := cfg.GetIntWithDefault("missing", 17); got != 17 {
|
||||||
|
t.Fatalf("expected default value, got %d", got)
|
||||||
|
}
|
||||||
|
if err := cfg.SetInt(KeyFramesPerRenderTask, 0); err != nil {
|
||||||
|
t.Fatalf("SetInt failed: %v", err)
|
||||||
|
}
|
||||||
|
if got := cfg.GetFramesPerRenderTask(); got != 1 {
|
||||||
|
t.Fatalf("expected clamped value 1, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
-- SQLite does not support DROP COLUMN directly; recreate the table without last_used_at.
|
||||||
|
CREATE TABLE runner_api_keys_backup (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
key_prefix TEXT NOT NULL,
|
||||||
|
key_hash TEXT NOT NULL,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
description TEXT,
|
||||||
|
scope TEXT NOT NULL DEFAULT 'user',
|
||||||
|
is_active INTEGER NOT NULL DEFAULT 1,
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
created_by INTEGER,
|
||||||
|
FOREIGN KEY (created_by) REFERENCES users(id),
|
||||||
|
UNIQUE(key_prefix)
|
||||||
|
);
|
||||||
|
|
||||||
|
INSERT INTO runner_api_keys_backup SELECT id, key_prefix, key_hash, name, description, scope, is_active, created_at, created_by FROM runner_api_keys;
|
||||||
|
|
||||||
|
DROP TABLE runner_api_keys;
|
||||||
|
|
||||||
|
ALTER TABLE runner_api_keys_backup RENAME TO runner_api_keys;
|
||||||
|
|
||||||
|
CREATE INDEX idx_runner_api_keys_prefix ON runner_api_keys(key_prefix);
|
||||||
|
CREATE INDEX idx_runner_api_keys_active ON runner_api_keys(is_active);
|
||||||
|
CREATE INDEX idx_runner_api_keys_created_by ON runner_api_keys(created_by);
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
ALTER TABLE runner_api_keys ADD COLUMN last_used_at TIMESTAMP;
|
||||||
58
internal/database/schema_test.go
Normal file
58
internal/database/schema_test.go
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewDB_RunsMigrationsAndSupportsQueries(t *testing.T) {
|
||||||
|
dbPath := filepath.Join(t.TempDir(), "test.db")
|
||||||
|
db, err := NewDB(dbPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewDB failed: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
if err := db.Ping(); err != nil {
|
||||||
|
t.Fatalf("Ping failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var exists bool
|
||||||
|
err = db.With(func(conn *sql.DB) error {
|
||||||
|
return conn.QueryRow("SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND name='settings')").Scan(&exists)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("query failed: %v", err)
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
t.Fatal("expected settings table after migrations")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithTx_RollbackOnError(t *testing.T) {
|
||||||
|
dbPath := filepath.Join(t.TempDir(), "tx.db")
|
||||||
|
db, err := NewDB(dbPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewDB failed: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
_ = db.WithTx(func(tx *sql.Tx) error {
|
||||||
|
if _, err := tx.Exec("INSERT INTO settings (key, value) VALUES (?, ?)", "rollback_key", "x"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return sql.ErrTxDone
|
||||||
|
})
|
||||||
|
|
||||||
|
var count int
|
||||||
|
if err := db.With(func(conn *sql.DB) error {
|
||||||
|
return conn.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "rollback_key").Scan(&count)
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("count query failed: %v", err)
|
||||||
|
}
|
||||||
|
if count != 0 {
|
||||||
|
t.Fatalf("expected rollback, found %d rows", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
35
internal/logger/logger_test.go
Normal file
35
internal/logger/logger_test.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package logger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseLevel(t *testing.T) {
|
||||||
|
if ParseLevel("debug") != LevelDebug {
|
||||||
|
t.Fatal("debug should map to LevelDebug")
|
||||||
|
}
|
||||||
|
if ParseLevel("warning") != LevelWarn {
|
||||||
|
t.Fatal("warning should map to LevelWarn")
|
||||||
|
}
|
||||||
|
if ParseLevel("unknown") != LevelInfo {
|
||||||
|
t.Fatal("unknown should default to LevelInfo")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetAndGetLevel(t *testing.T) {
|
||||||
|
SetLevel(LevelError)
|
||||||
|
if GetLevel() != LevelError {
|
||||||
|
t.Fatalf("GetLevel() = %v, want %v", GetLevel(), LevelError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewWithFile_CreatesFile(t *testing.T) {
|
||||||
|
logPath := filepath.Join(t.TempDir(), "runner.log")
|
||||||
|
l, err := NewWithFile(logPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewWithFile failed: %v", err)
|
||||||
|
}
|
||||||
|
defer l.Close()
|
||||||
|
}
|
||||||
|
|
||||||
35
internal/manager/admin_test.go
Normal file
35
internal/manager/admin_test.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHandleGenerateRunnerAPIKey_UnauthorizedWithoutContext(t *testing.T) {
|
||||||
|
s := &Manager{}
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/admin/runner-api-keys", bytes.NewBufferString(`{"name":"k"}`))
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
s.handleGenerateRunnerAPIKey(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusUnauthorized {
|
||||||
|
t.Fatalf("status = %d, want %d", rr.Code, http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleGenerateRunnerAPIKey_RejectsBadJSON(t *testing.T) {
|
||||||
|
s := &Manager{}
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/admin/runner-api-keys", bytes.NewBufferString(`{`))
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
s.handleGenerateRunnerAPIKey(rr, req)
|
||||||
|
|
||||||
|
// No auth context means unauthorized happens first; this still validates safe
|
||||||
|
// failure handling for malformed requests in this handler path.
|
||||||
|
if rr.Code != http.StatusUnauthorized {
|
||||||
|
t.Fatalf("status = %d, want %d", rr.Code, http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -3,7 +3,6 @@ package api
|
|||||||
import (
|
import (
|
||||||
"archive/tar"
|
"archive/tar"
|
||||||
"compress/bzip2"
|
"compress/bzip2"
|
||||||
"compress/gzip"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
@@ -16,6 +15,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"jiggablend/pkg/blendfile"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -439,144 +440,16 @@ func (s *Manager) cleanupExtractedBlenderFolders(blenderDir string, version *Ble
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseBlenderVersionFromFile parses the Blender version that a .blend file was saved with
|
// ParseBlenderVersionFromFile parses the Blender version that a .blend file was saved with.
|
||||||
// This reads the file header to determine the version
|
// Delegates to the shared pkg/blendfile implementation.
|
||||||
func ParseBlenderVersionFromFile(blendPath string) (major, minor int, err error) {
|
func ParseBlenderVersionFromFile(blendPath string) (major, minor int, err error) {
|
||||||
file, err := os.Open(blendPath)
|
return blendfile.ParseVersionFromFile(blendPath)
|
||||||
if err != nil {
|
|
||||||
return 0, 0, fmt.Errorf("failed to open blend file: %w", err)
|
|
||||||
}
|
|
||||||
defer file.Close()
|
|
||||||
|
|
||||||
return ParseBlenderVersionFromReader(file)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseBlenderVersionFromReader parses the Blender version from a reader
|
// ParseBlenderVersionFromReader parses the Blender version from a reader.
|
||||||
// Useful for reading from uploaded files without saving to disk first
|
// Delegates to the shared pkg/blendfile implementation.
|
||||||
func ParseBlenderVersionFromReader(r io.ReadSeeker) (major, minor int, err error) {
|
func ParseBlenderVersionFromReader(r io.ReadSeeker) (major, minor int, err error) {
|
||||||
// Read the first 12 bytes of the blend file header
|
return blendfile.ParseVersionFromReader(r)
|
||||||
// Format: BLENDER-v<major><minor><patch> or BLENDER_v<major><minor><patch>
|
|
||||||
// The header is: "BLENDER" (7 bytes) + pointer size (1 byte: '-' for 64-bit, '_' for 32-bit)
|
|
||||||
// + endianness (1 byte: 'v' for little-endian, 'V' for big-endian)
|
|
||||||
// + version (3 bytes: e.g., "402" for 4.02)
|
|
||||||
header := make([]byte, 12)
|
|
||||||
n, err := r.Read(header)
|
|
||||||
if err != nil || n < 12 {
|
|
||||||
return 0, 0, fmt.Errorf("failed to read blend file header: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for BLENDER magic
|
|
||||||
if string(header[:7]) != "BLENDER" {
|
|
||||||
// Might be compressed - try to decompress
|
|
||||||
r.Seek(0, 0)
|
|
||||||
return parseCompressedBlendVersion(r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse version from bytes 9-11 (3 digits)
|
|
||||||
versionStr := string(header[9:12])
|
|
||||||
var vMajor, vMinor int
|
|
||||||
|
|
||||||
// Version format changed in Blender 3.0
|
|
||||||
// Pre-3.0: "279" = 2.79, "280" = 2.80
|
|
||||||
// 3.0+: "300" = 3.0, "402" = 4.02, "410" = 4.10
|
|
||||||
if len(versionStr) == 3 {
|
|
||||||
// First digit is major version
|
|
||||||
fmt.Sscanf(string(versionStr[0]), "%d", &vMajor)
|
|
||||||
// Next two digits are minor version
|
|
||||||
fmt.Sscanf(versionStr[1:3], "%d", &vMinor)
|
|
||||||
}
|
|
||||||
|
|
||||||
return vMajor, vMinor, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseCompressedBlendVersion handles gzip and zstd compressed blend files
|
|
||||||
func parseCompressedBlendVersion(r io.ReadSeeker) (major, minor int, err error) {
|
|
||||||
// Check for compression magic bytes
|
|
||||||
magic := make([]byte, 4)
|
|
||||||
if _, err := r.Read(magic); err != nil {
|
|
||||||
return 0, 0, err
|
|
||||||
}
|
|
||||||
r.Seek(0, 0)
|
|
||||||
|
|
||||||
if magic[0] == 0x1f && magic[1] == 0x8b {
|
|
||||||
// gzip compressed
|
|
||||||
gzReader, err := gzip.NewReader(r)
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0, fmt.Errorf("failed to create gzip reader: %w", err)
|
|
||||||
}
|
|
||||||
defer gzReader.Close()
|
|
||||||
|
|
||||||
header := make([]byte, 12)
|
|
||||||
n, err := gzReader.Read(header)
|
|
||||||
if err != nil || n < 12 {
|
|
||||||
return 0, 0, fmt.Errorf("failed to read compressed blend header: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if string(header[:7]) != "BLENDER" {
|
|
||||||
return 0, 0, fmt.Errorf("invalid blend file format")
|
|
||||||
}
|
|
||||||
|
|
||||||
versionStr := string(header[9:12])
|
|
||||||
var vMajor, vMinor int
|
|
||||||
if len(versionStr) == 3 {
|
|
||||||
fmt.Sscanf(string(versionStr[0]), "%d", &vMajor)
|
|
||||||
fmt.Sscanf(versionStr[1:3], "%d", &vMinor)
|
|
||||||
}
|
|
||||||
|
|
||||||
return vMajor, vMinor, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for zstd magic (Blender 3.0+): 0x28 0xB5 0x2F 0xFD
|
|
||||||
if magic[0] == 0x28 && magic[1] == 0xb5 && magic[2] == 0x2f && magic[3] == 0xfd {
|
|
||||||
return parseZstdBlendVersion(r)
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0, 0, fmt.Errorf("unknown blend file format")
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseZstdBlendVersion handles zstd-compressed blend files (Blender 3.0+)
|
|
||||||
// Uses zstd command line tool since Go doesn't have native zstd support
|
|
||||||
func parseZstdBlendVersion(r io.ReadSeeker) (major, minor int, err error) {
|
|
||||||
r.Seek(0, 0)
|
|
||||||
|
|
||||||
// We need to decompress just enough to read the header
|
|
||||||
// Use zstd command to decompress from stdin
|
|
||||||
cmd := exec.Command("zstd", "-d", "-c")
|
|
||||||
cmd.Stdin = r
|
|
||||||
|
|
||||||
stdout, err := cmd.StdoutPipe()
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0, fmt.Errorf("failed to create zstd stdout pipe: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := cmd.Start(); err != nil {
|
|
||||||
return 0, 0, fmt.Errorf("failed to start zstd decompression: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read just the header (12 bytes)
|
|
||||||
header := make([]byte, 12)
|
|
||||||
n, readErr := io.ReadFull(stdout, header)
|
|
||||||
|
|
||||||
// Kill the process early - we only need the header
|
|
||||||
cmd.Process.Kill()
|
|
||||||
cmd.Wait()
|
|
||||||
|
|
||||||
if readErr != nil || n < 12 {
|
|
||||||
return 0, 0, fmt.Errorf("failed to read zstd compressed blend header: %v", readErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if string(header[:7]) != "BLENDER" {
|
|
||||||
return 0, 0, fmt.Errorf("invalid blend file format in zstd archive")
|
|
||||||
}
|
|
||||||
|
|
||||||
versionStr := string(header[9:12])
|
|
||||||
var vMajor, vMinor int
|
|
||||||
if len(versionStr) == 3 {
|
|
||||||
fmt.Sscanf(string(versionStr[0]), "%d", &vMajor)
|
|
||||||
fmt.Sscanf(versionStr[1:3], "%d", &vMinor)
|
|
||||||
}
|
|
||||||
|
|
||||||
return vMajor, vMinor, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleGetBlenderVersions returns available Blender versions
|
// handleGetBlenderVersions returns available Blender versions
|
||||||
@@ -713,7 +586,7 @@ func (s *Manager) handleDownloadBlender(w http.ResponseWriter, r *http.Request)
|
|||||||
tarFilename = strings.TrimSuffix(tarFilename, ".bz2")
|
tarFilename = strings.TrimSuffix(tarFilename, ".bz2")
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/x-tar")
|
w.Header().Set("Content-Type", "application/x-tar")
|
||||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", tarFilename))
|
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", tarFilename))
|
||||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
|
||||||
w.Header().Set("X-Blender-Version", blenderVersion.Full)
|
w.Header().Set("X-Blender-Version", blenderVersion.Full)
|
||||||
|
|
||||||
|
|||||||
35
internal/manager/blender_path.go
Normal file
35
internal/manager/blender_path.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// resolveBlenderBinaryPath resolves a Blender executable to an absolute path.
|
||||||
|
func resolveBlenderBinaryPath(blenderBinary string) (string, error) {
|
||||||
|
if blenderBinary == "" {
|
||||||
|
return "", fmt.Errorf("blender binary path is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Already contains a path component; normalize it.
|
||||||
|
if strings.Contains(blenderBinary, string(filepath.Separator)) {
|
||||||
|
absPath, err := filepath.Abs(blenderBinary)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to resolve blender binary path %q: %w", blenderBinary, err)
|
||||||
|
}
|
||||||
|
return absPath, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bare executable name, resolve via PATH.
|
||||||
|
resolvedPath, err := exec.LookPath(blenderBinary)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to locate blender binary %q in PATH: %w", blenderBinary, err)
|
||||||
|
}
|
||||||
|
absPath, err := filepath.Abs(resolvedPath)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to resolve blender binary path %q: %w", resolvedPath, err)
|
||||||
|
}
|
||||||
|
return absPath, nil
|
||||||
|
}
|
||||||
23
internal/manager/blender_path_test.go
Normal file
23
internal/manager/blender_path_test.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestResolveBlenderBinaryPath_WithPathComponent(t *testing.T) {
|
||||||
|
got, err := resolveBlenderBinaryPath("./blender")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("resolveBlenderBinaryPath failed: %v", err)
|
||||||
|
}
|
||||||
|
if !filepath.IsAbs(got) {
|
||||||
|
t.Fatalf("expected absolute path, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveBlenderBinaryPath_Empty(t *testing.T) {
|
||||||
|
if _, err := resolveBlenderBinaryPath(""); err == nil {
|
||||||
|
t.Fatal("expected error for empty path")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
27
internal/manager/blender_test.go
Normal file
27
internal/manager/blender_test.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetLatestBlenderForMajorMinor_UsesCachedVersions(t *testing.T) {
|
||||||
|
blenderVersionCache.mu.Lock()
|
||||||
|
blenderVersionCache.versions = []BlenderVersion{
|
||||||
|
{Major: 4, Minor: 2, Patch: 1, Full: "4.2.1"},
|
||||||
|
{Major: 4, Minor: 2, Patch: 3, Full: "4.2.3"},
|
||||||
|
{Major: 4, Minor: 1, Patch: 9, Full: "4.1.9"},
|
||||||
|
}
|
||||||
|
blenderVersionCache.fetchedAt = time.Now()
|
||||||
|
blenderVersionCache.mu.Unlock()
|
||||||
|
|
||||||
|
m := &Manager{}
|
||||||
|
v, err := m.GetLatestBlenderForMajorMinor(4, 2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetLatestBlenderForMajorMinor failed: %v", err)
|
||||||
|
}
|
||||||
|
if v.Full != "4.2.3" {
|
||||||
|
t.Fatalf("expected highest patch, got %+v", *v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -97,6 +97,58 @@ func (s *Manager) failUploadSession(sessionID, errorMessage string) (int64, bool
|
|||||||
return userID, true
|
return userID, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
uploadSessionExpiredCode = "UPLOAD_SESSION_EXPIRED"
|
||||||
|
uploadSessionNotReadyCode = "UPLOAD_SESSION_NOT_READY"
|
||||||
|
)
|
||||||
|
|
||||||
|
type uploadSessionValidationError struct {
|
||||||
|
Code string
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *uploadSessionValidationError) Error() string {
|
||||||
|
return e.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateUploadSessionForJobCreation validates that an upload session can be used for job creation.
|
||||||
|
// Returns the session and its context tar path when valid.
|
||||||
|
func (s *Manager) validateUploadSessionForJobCreation(sessionID string, userID int64) (*UploadSession, string, error) {
|
||||||
|
s.uploadSessionsMu.RLock()
|
||||||
|
uploadSession := s.uploadSessions[sessionID]
|
||||||
|
s.uploadSessionsMu.RUnlock()
|
||||||
|
|
||||||
|
if uploadSession == nil || uploadSession.UserID != userID {
|
||||||
|
return nil, "", &uploadSessionValidationError{
|
||||||
|
Code: uploadSessionExpiredCode,
|
||||||
|
Message: "Upload session expired or not found. Please upload the file again.",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if uploadSession.Status != "completed" {
|
||||||
|
return nil, "", &uploadSessionValidationError{
|
||||||
|
Code: uploadSessionNotReadyCode,
|
||||||
|
Message: "Upload session is not ready yet. Wait for processing to complete.",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if uploadSession.TempDir == "" {
|
||||||
|
return nil, "", &uploadSessionValidationError{
|
||||||
|
Code: uploadSessionExpiredCode,
|
||||||
|
Message: "Upload session context data is missing. Please upload the file again.",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tempContextPath := filepath.Join(uploadSession.TempDir, "context.tar")
|
||||||
|
if _, statErr := os.Stat(tempContextPath); statErr != nil {
|
||||||
|
log.Printf("ERROR: Context archive not found at %s for session %s: %v", tempContextPath, sessionID, statErr)
|
||||||
|
return nil, "", &uploadSessionValidationError{
|
||||||
|
Code: uploadSessionExpiredCode,
|
||||||
|
Message: "Upload session context archive was not found (possibly after manager restart). Please upload the file again.",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return uploadSession, tempContextPath, nil
|
||||||
|
}
|
||||||
|
|
||||||
// handleCreateJob creates a new job
|
// handleCreateJob creates a new job
|
||||||
func (s *Manager) handleCreateJob(w http.ResponseWriter, r *http.Request) {
|
func (s *Manager) handleCreateJob(w http.ResponseWriter, r *http.Request) {
|
||||||
userID, err := getUserID(r)
|
userID, err := getUserID(r)
|
||||||
@@ -178,6 +230,22 @@ func (s *Manager) handleCreateJob(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var uploadSession *UploadSession
|
||||||
|
var tempContextPath string
|
||||||
|
if req.UploadSessionID != nil && *req.UploadSessionID != "" {
|
||||||
|
var validateErr error
|
||||||
|
uploadSession, tempContextPath, validateErr = s.validateUploadSessionForJobCreation(*req.UploadSessionID, userID)
|
||||||
|
if validateErr != nil {
|
||||||
|
var sessionErr *uploadSessionValidationError
|
||||||
|
if errors.As(validateErr, &sessionErr) {
|
||||||
|
s.respondErrorWithCode(w, http.StatusBadRequest, sessionErr.Code, sessionErr.Message)
|
||||||
|
} else {
|
||||||
|
s.respondError(w, http.StatusBadRequest, validateErr.Error())
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Store render settings, unhide_objects, enable_execution, and blender_version in blend_metadata if provided.
|
// Store render settings, unhide_objects, enable_execution, and blender_version in blend_metadata if provided.
|
||||||
var blendMetadataJSON *string
|
var blendMetadataJSON *string
|
||||||
if req.RenderSettings != nil || req.UnhideObjects != nil || req.EnableExecution != nil || req.BlenderVersion != nil || req.OutputFormat != nil {
|
if req.RenderSettings != nil || req.UnhideObjects != nil || req.EnableExecution != nil || req.BlenderVersion != nil || req.OutputFormat != nil {
|
||||||
@@ -226,39 +294,29 @@ func (s *Manager) handleCreateJob(w http.ResponseWriter, r *http.Request) {
|
|||||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create job: %v", err))
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create job: %v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
cleanupCreatedJob := func(reason string) {
|
||||||
|
log.Printf("Cleaning up partially created job %d: %s", jobID, reason)
|
||||||
|
_ = s.db.With(func(conn *sql.DB) error {
|
||||||
|
// Be defensive in case foreign key cascade is disabled.
|
||||||
|
_, _ = conn.Exec(`DELETE FROM task_logs WHERE task_id IN (SELECT id FROM tasks WHERE job_id = ?)`, jobID)
|
||||||
|
_, _ = conn.Exec(`DELETE FROM task_steps WHERE task_id IN (SELECT id FROM tasks WHERE job_id = ?)`, jobID)
|
||||||
|
_, _ = conn.Exec(`DELETE FROM tasks WHERE job_id = ?`, jobID)
|
||||||
|
_, _ = conn.Exec(`DELETE FROM job_files WHERE job_id = ?`, jobID)
|
||||||
|
_, _ = conn.Exec(`DELETE FROM jobs WHERE id = ?`, jobID)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
_ = os.RemoveAll(s.storage.JobPath(jobID))
|
||||||
|
}
|
||||||
|
|
||||||
// If upload session ID is provided, move the context archive from temp to job directory
|
// If upload session ID is provided, move the context archive from temp to job directory
|
||||||
if req.UploadSessionID != nil && *req.UploadSessionID != "" {
|
if uploadSession != nil {
|
||||||
log.Printf("Processing upload session for job %d: %s", jobID, *req.UploadSessionID)
|
log.Printf("Processing upload session for job %d: %s", jobID, *req.UploadSessionID)
|
||||||
var uploadSession *UploadSession
|
|
||||||
s.uploadSessionsMu.RLock()
|
|
||||||
uploadSession = s.uploadSessions[*req.UploadSessionID]
|
|
||||||
s.uploadSessionsMu.RUnlock()
|
|
||||||
|
|
||||||
if uploadSession == nil || uploadSession.UserID != userID {
|
|
||||||
s.respondError(w, http.StatusBadRequest, "Invalid upload session. Please upload the file again.")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if uploadSession.Status != "completed" {
|
|
||||||
s.respondError(w, http.StatusBadRequest, "Upload session is not ready yet. Wait for processing to complete.")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if uploadSession.TempDir == "" {
|
|
||||||
s.respondError(w, http.StatusBadRequest, "Upload session is missing context data. Please upload again.")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
tempContextPath := filepath.Join(uploadSession.TempDir, "context.tar")
|
|
||||||
if _, statErr := os.Stat(tempContextPath); statErr != nil {
|
|
||||||
log.Printf("ERROR: Context archive not found at %s for session %s: %v", tempContextPath, *req.UploadSessionID, statErr)
|
|
||||||
s.respondError(w, http.StatusBadRequest, "Context archive not found for upload session. Please upload the file again.")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("Found context archive at %s, moving to job %d directory", tempContextPath, jobID)
|
log.Printf("Found context archive at %s, moving to job %d directory", tempContextPath, jobID)
|
||||||
jobPath := s.storage.JobPath(jobID)
|
jobPath := s.storage.JobPath(jobID)
|
||||||
if err := os.MkdirAll(jobPath, 0755); err != nil {
|
if err := os.MkdirAll(jobPath, 0755); err != nil {
|
||||||
log.Printf("ERROR: Failed to create job directory for job %d: %v", jobID, err)
|
log.Printf("ERROR: Failed to create job directory for job %d: %v", jobID, err)
|
||||||
|
cleanupCreatedJob("failed to create job directory")
|
||||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create job directory: %v", err))
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create job directory: %v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -267,6 +325,7 @@ func (s *Manager) handleCreateJob(w http.ResponseWriter, r *http.Request) {
|
|||||||
srcFile, err := os.Open(tempContextPath)
|
srcFile, err := os.Open(tempContextPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("ERROR: Failed to open source context archive %s: %v", tempContextPath, err)
|
log.Printf("ERROR: Failed to open source context archive %s: %v", tempContextPath, err)
|
||||||
|
cleanupCreatedJob("failed to open source context archive")
|
||||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to open context archive: %v", err))
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to open context archive: %v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -275,6 +334,7 @@ func (s *Manager) handleCreateJob(w http.ResponseWriter, r *http.Request) {
|
|||||||
dstFile, err := os.Create(jobContextPath)
|
dstFile, err := os.Create(jobContextPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("ERROR: Failed to create destination context archive %s: %v", jobContextPath, err)
|
log.Printf("ERROR: Failed to create destination context archive %s: %v", jobContextPath, err)
|
||||||
|
cleanupCreatedJob("failed to create destination context archive")
|
||||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create context archive: %v", err))
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create context archive: %v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -284,6 +344,7 @@ func (s *Manager) handleCreateJob(w http.ResponseWriter, r *http.Request) {
|
|||||||
dstFile.Close()
|
dstFile.Close()
|
||||||
os.Remove(jobContextPath)
|
os.Remove(jobContextPath)
|
||||||
log.Printf("ERROR: Failed to copy context archive from %s to %s: %v", tempContextPath, jobContextPath, err)
|
log.Printf("ERROR: Failed to copy context archive from %s to %s: %v", tempContextPath, jobContextPath, err)
|
||||||
|
cleanupCreatedJob("failed to copy context archive")
|
||||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to copy context archive: %v", err))
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to copy context archive: %v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -291,6 +352,7 @@ func (s *Manager) handleCreateJob(w http.ResponseWriter, r *http.Request) {
|
|||||||
srcFile.Close()
|
srcFile.Close()
|
||||||
if err := dstFile.Close(); err != nil {
|
if err := dstFile.Close(); err != nil {
|
||||||
log.Printf("ERROR: Failed to close destination file: %v", err)
|
log.Printf("ERROR: Failed to close destination file: %v", err)
|
||||||
|
cleanupCreatedJob("failed to finalize destination context archive")
|
||||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to finalize context archive: %v", err))
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to finalize context archive: %v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -301,6 +363,7 @@ func (s *Manager) handleCreateJob(w http.ResponseWriter, r *http.Request) {
|
|||||||
contextInfo, err := os.Stat(jobContextPath)
|
contextInfo, err := os.Stat(jobContextPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("ERROR: Failed to stat context archive after move: %v", err)
|
log.Printf("ERROR: Failed to stat context archive after move: %v", err)
|
||||||
|
cleanupCreatedJob("failed to stat copied context archive")
|
||||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify context archive: %v", err))
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify context archive: %v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -320,6 +383,7 @@ func (s *Manager) handleCreateJob(w http.ResponseWriter, r *http.Request) {
|
|||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("ERROR: Failed to record context archive in database for job %d: %v", jobID, err)
|
log.Printf("ERROR: Failed to record context archive in database for job %d: %v", jobID, err)
|
||||||
|
cleanupCreatedJob("failed to record context archive in database")
|
||||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to record context archive: %v", err))
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to record context archive: %v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -345,9 +409,9 @@ func (s *Manager) handleCreateJob(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Only create render tasks for render jobs
|
// Only create render tasks for render jobs
|
||||||
if req.JobType == types.JobTypeRender {
|
if req.JobType == types.JobTypeRender {
|
||||||
// Determine task timeout based on output format
|
// Determine task timeout based on output format
|
||||||
taskTimeout := RenderTimeout // 1 hour for render jobs
|
taskTimeout := s.renderTimeout
|
||||||
if *req.OutputFormat == "EXR_264_MP4" || *req.OutputFormat == "EXR_AV1_MP4" || *req.OutputFormat == "EXR_VP9_WEBM" {
|
if *req.OutputFormat == "EXR_264_MP4" || *req.OutputFormat == "EXR_AV1_MP4" || *req.OutputFormat == "EXR_VP9_WEBM" {
|
||||||
taskTimeout = VideoEncodeTimeout // 24 hours for encoding
|
taskTimeout = s.videoEncodeTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create tasks for the job (batch INSERT in a single transaction)
|
// Create tasks for the job (batch INSERT in a single transaction)
|
||||||
@@ -382,6 +446,7 @@ func (s *Manager) handleCreateJob(w http.ResponseWriter, r *http.Request) {
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
cleanupCreatedJob("failed to create render tasks")
|
||||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create tasks: %v", err))
|
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create tasks: %v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -390,7 +455,7 @@ func (s *Manager) handleCreateJob(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Create encode task immediately if output format requires it
|
// Create encode task immediately if output format requires it
|
||||||
// The task will have a condition that prevents it from being assigned until all render tasks are completed
|
// The task will have a condition that prevents it from being assigned until all render tasks are completed
|
||||||
if *req.OutputFormat == "EXR_264_MP4" || *req.OutputFormat == "EXR_AV1_MP4" || *req.OutputFormat == "EXR_VP9_WEBM" {
|
if *req.OutputFormat == "EXR_264_MP4" || *req.OutputFormat == "EXR_AV1_MP4" || *req.OutputFormat == "EXR_VP9_WEBM" {
|
||||||
encodeTaskTimeout := VideoEncodeTimeout // 24 hours for encoding
|
encodeTaskTimeout := s.videoEncodeTimeout
|
||||||
conditionJSON := `{"type": "all_render_tasks_completed"}`
|
conditionJSON := `{"type": "all_render_tasks_completed"}`
|
||||||
var encodeTaskID int64
|
var encodeTaskID int64
|
||||||
err = s.db.With(func(conn *sql.DB) error {
|
err = s.db.With(func(conn *sql.DB) error {
|
||||||
@@ -1984,10 +2049,14 @@ func (s *Manager) runBlenderMetadataExtraction(blendFile, workDir, blenderVersio
|
|||||||
return nil, fmt.Errorf("failed to create extraction script: %w", err)
|
return nil, fmt.Errorf("failed to create extraction script: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make blend file path relative to workDir to avoid path resolution issues
|
// Use absolute paths to avoid path normalization issues with relative traversal.
|
||||||
blendFileRel, err := filepath.Rel(workDir, blendFile)
|
blendFileAbs, err := filepath.Abs(blendFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get relative path for blend file: %w", err)
|
return nil, fmt.Errorf("failed to get absolute path for blend file: %w", err)
|
||||||
|
}
|
||||||
|
scriptPathAbs, err := filepath.Abs(scriptPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get absolute path for extraction script: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine which blender binary to use
|
// Determine which blender binary to use
|
||||||
@@ -2037,11 +2106,17 @@ func (s *Manager) runBlenderMetadataExtraction(blendFile, workDir, blenderVersio
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ensure Blender binary is always an absolute path.
|
||||||
|
blenderBinary, err = resolveBlenderBinaryPath(blenderBinary)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// Execute Blender using executils (set LD_LIBRARY_PATH for tarball installs)
|
// Execute Blender using executils (set LD_LIBRARY_PATH for tarball installs)
|
||||||
runEnv := blender.TarballEnv(blenderBinary, os.Environ())
|
runEnv := blender.TarballEnv(blenderBinary, os.Environ())
|
||||||
result, err := executils.RunCommand(
|
result, err := executils.RunCommand(
|
||||||
blenderBinary,
|
blenderBinary,
|
||||||
[]string{"-b", blendFileRel, "--python", "extract_metadata.py"},
|
[]string{"-b", blendFileAbs, "--python", scriptPathAbs},
|
||||||
workDir,
|
workDir,
|
||||||
runEnv,
|
runEnv,
|
||||||
0, // no task ID for metadata extraction
|
0, // no task ID for metadata extraction
|
||||||
@@ -2592,7 +2667,7 @@ func (s *Manager) handleDownloadJobFile(w http.ResponseWriter, r *http.Request)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Set headers
|
// Set headers
|
||||||
w.Header().Set("Content-Disposition", fmt.Sprintf("%s; filename=%s", disposition, fileName))
|
w.Header().Set("Content-Disposition", fmt.Sprintf("%s; filename=%q", disposition, fileName))
|
||||||
w.Header().Set("Content-Type", contentType)
|
w.Header().Set("Content-Type", contentType)
|
||||||
|
|
||||||
// Stream file
|
// Stream file
|
||||||
@@ -2710,7 +2785,7 @@ func (s *Manager) handleDownloadEXRZip(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
fileName := fmt.Sprintf("%s-exr.zip", safeJobName)
|
fileName := fmt.Sprintf("%s-exr.zip", safeJobName)
|
||||||
w.Header().Set("Content-Type", "application/zip")
|
w.Header().Set("Content-Type", "application/zip")
|
||||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", fileName))
|
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", fileName))
|
||||||
|
|
||||||
zipWriter := zip.NewWriter(w)
|
zipWriter := zip.NewWriter(w)
|
||||||
defer zipWriter.Close()
|
defer zipWriter.Close()
|
||||||
@@ -2881,7 +2956,7 @@ func (s *Manager) handlePreviewEXR(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// Set headers
|
// Set headers
|
||||||
pngFileName := strings.TrimSuffix(fileName, filepath.Ext(fileName)) + ".png"
|
pngFileName := strings.TrimSuffix(fileName, filepath.Ext(fileName)) + ".png"
|
||||||
w.Header().Set("Content-Disposition", fmt.Sprintf("inline; filename=%s", pngFileName))
|
w.Header().Set("Content-Disposition", fmt.Sprintf("inline; filename=%q", pngFileName))
|
||||||
w.Header().Set("Content-Type", "image/png")
|
w.Header().Set("Content-Type", "image/png")
|
||||||
w.Header().Set("Content-Length", strconv.Itoa(len(pngData)))
|
w.Header().Set("Content-Length", strconv.Itoa(len(pngData)))
|
||||||
|
|
||||||
|
|||||||
145
internal/manager/jobs_test.go
Normal file
145
internal/manager/jobs_test.go
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"archive/tar"
|
||||||
|
"bytes"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateAndCheckETag(t *testing.T) {
|
||||||
|
etag := generateETag(map[string]interface{}{"a": 1})
|
||||||
|
if etag == "" {
|
||||||
|
t.Fatal("expected non-empty etag")
|
||||||
|
}
|
||||||
|
req := httptest.NewRequest("GET", "/x", nil)
|
||||||
|
req.Header.Set("If-None-Match", etag)
|
||||||
|
if !checkETag(req, etag) {
|
||||||
|
t.Fatal("expected etag match")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadSessionPhase(t *testing.T) {
|
||||||
|
if got := uploadSessionPhase("uploading"); got != "upload" {
|
||||||
|
t.Fatalf("unexpected phase: %q", got)
|
||||||
|
}
|
||||||
|
if got := uploadSessionPhase("select_blend"); got != "action_required" {
|
||||||
|
t.Fatalf("unexpected phase: %q", got)
|
||||||
|
}
|
||||||
|
if got := uploadSessionPhase("something_else"); got != "processing" {
|
||||||
|
t.Fatalf("unexpected fallback phase: %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTarHeader_AndTruncateString(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
tw := tar.NewWriter(&buf)
|
||||||
|
_ = tw.WriteHeader(&tar.Header{Name: "a.txt", Mode: 0644, Size: 3, Typeflag: tar.TypeReg})
|
||||||
|
_, _ = tw.Write([]byte("abc"))
|
||||||
|
_ = tw.Close()
|
||||||
|
|
||||||
|
raw := buf.Bytes()
|
||||||
|
if len(raw) < 512 {
|
||||||
|
t.Fatal("tar buffer unexpectedly small")
|
||||||
|
}
|
||||||
|
var h tar.Header
|
||||||
|
if err := parseTarHeader(raw[:512], &h); err != nil {
|
||||||
|
t.Fatalf("parseTarHeader failed: %v", err)
|
||||||
|
}
|
||||||
|
if h.Name != "a.txt" {
|
||||||
|
t.Fatalf("unexpected parsed name: %q", h.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := truncateString("abcdef", 5); got != "ab..." {
|
||||||
|
t.Fatalf("truncateString = %q, want %q", got, "ab...")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateUploadSessionForJobCreation_MissingSession(t *testing.T) {
|
||||||
|
s := &Manager{
|
||||||
|
uploadSessions: map[string]*UploadSession{},
|
||||||
|
}
|
||||||
|
_, _, err := s.validateUploadSessionForJobCreation("missing", 1)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected validation error for missing session")
|
||||||
|
}
|
||||||
|
sessionErr, ok := err.(*uploadSessionValidationError)
|
||||||
|
if !ok || sessionErr.Code != uploadSessionExpiredCode {
|
||||||
|
t.Fatalf("expected %s validation error, got %#v", uploadSessionExpiredCode, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateUploadSessionForJobCreation_ContextMissing(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
s := &Manager{
|
||||||
|
uploadSessions: map[string]*UploadSession{
|
||||||
|
"s1": {
|
||||||
|
SessionID: "s1",
|
||||||
|
UserID: 9,
|
||||||
|
TempDir: tmpDir,
|
||||||
|
Status: "completed",
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, _, err := s.validateUploadSessionForJobCreation("s1", 9); err == nil {
|
||||||
|
t.Fatal("expected error when context.tar is missing")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateUploadSessionForJobCreation_NotReady(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
s := &Manager{
|
||||||
|
uploadSessions: map[string]*UploadSession{
|
||||||
|
"s1": {
|
||||||
|
SessionID: "s1",
|
||||||
|
UserID: 9,
|
||||||
|
TempDir: tmpDir,
|
||||||
|
Status: "processing",
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, err := s.validateUploadSessionForJobCreation("s1", 9)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for session that is not completed")
|
||||||
|
}
|
||||||
|
sessionErr, ok := err.(*uploadSessionValidationError)
|
||||||
|
if !ok || sessionErr.Code != uploadSessionNotReadyCode {
|
||||||
|
t.Fatalf("expected %s validation error, got %#v", uploadSessionNotReadyCode, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateUploadSessionForJobCreation_Success(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
contextPath := filepath.Join(tmpDir, "context.tar")
|
||||||
|
if err := os.WriteFile(contextPath, []byte("tar-bytes"), 0644); err != nil {
|
||||||
|
t.Fatalf("write context.tar: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &Manager{
|
||||||
|
uploadSessions: map[string]*UploadSession{
|
||||||
|
"s1": {
|
||||||
|
SessionID: "s1",
|
||||||
|
UserID: 9,
|
||||||
|
TempDir: tmpDir,
|
||||||
|
Status: "completed",
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
session, gotPath, err := s.validateUploadSessionForJobCreation("s1", 9)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected valid session, got error: %v", err)
|
||||||
|
}
|
||||||
|
if session == nil || gotPath != contextPath {
|
||||||
|
t.Fatalf("unexpected result: session=%v path=%q", session, gotPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -30,27 +30,22 @@ import (
|
|||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Configuration constants
|
// Configuration constants (non-configurable infrastructure values)
|
||||||
const (
|
const (
|
||||||
// WebSocket timeouts
|
// WebSocket timeouts
|
||||||
WSReadDeadline = 90 * time.Second
|
WSReadDeadline = 90 * time.Second
|
||||||
WSPingInterval = 30 * time.Second
|
WSPingInterval = 30 * time.Second
|
||||||
WSWriteDeadline = 10 * time.Second
|
WSWriteDeadline = 10 * time.Second
|
||||||
|
|
||||||
// Task timeouts
|
// Infrastructure timers
|
||||||
RenderTimeout = 60 * 60 // 1 hour for frame rendering
|
|
||||||
VideoEncodeTimeout = 60 * 60 * 24 // 24 hours for encoding
|
|
||||||
|
|
||||||
// Limits
|
|
||||||
MaxUploadSize = 50 << 30 // 50 GB
|
|
||||||
RunnerHeartbeatTimeout = 90 * time.Second
|
RunnerHeartbeatTimeout = 90 * time.Second
|
||||||
TaskDistributionInterval = 10 * time.Second
|
TaskDistributionInterval = 10 * time.Second
|
||||||
ProgressUpdateThrottle = 2 * time.Second
|
ProgressUpdateThrottle = 2 * time.Second
|
||||||
|
|
||||||
// Cookie settings
|
|
||||||
SessionCookieMaxAge = 86400 // 24 hours
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Operational limits are loaded from database config at Manager initialization.
|
||||||
|
// Defaults are defined in internal/config/config.go convenience methods.
|
||||||
|
|
||||||
// Manager represents the manager server
|
// Manager represents the manager server
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
db *database.DB
|
db *database.DB
|
||||||
@@ -109,6 +104,12 @@ type Manager struct {
|
|||||||
|
|
||||||
// Server start time for health checks
|
// Server start time for health checks
|
||||||
startTime time.Time
|
startTime time.Time
|
||||||
|
|
||||||
|
// Configurable operational values loaded from config
|
||||||
|
renderTimeout int // seconds
|
||||||
|
videoEncodeTimeout int // seconds
|
||||||
|
maxUploadSize int64 // bytes
|
||||||
|
sessionCookieMaxAge int // seconds
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClientConnection represents a client WebSocket connection with subscriptions
|
// ClientConnection represents a client WebSocket connection with subscriptions
|
||||||
@@ -166,6 +167,11 @@ func NewManager(db *database.DB, cfg *config.Config, auth *authpkg.Auth, storage
|
|||||||
router: chi.NewRouter(),
|
router: chi.NewRouter(),
|
||||||
ui: ui,
|
ui: ui,
|
||||||
startTime: time.Now(),
|
startTime: time.Now(),
|
||||||
|
|
||||||
|
renderTimeout: cfg.RenderTimeoutSeconds(),
|
||||||
|
videoEncodeTimeout: cfg.EncodeTimeoutSeconds(),
|
||||||
|
maxUploadSize: cfg.MaxUploadBytes(),
|
||||||
|
sessionCookieMaxAge: cfg.SessionCookieMaxAgeSec(),
|
||||||
wsUpgrader: websocket.Upgrader{
|
wsUpgrader: websocket.Upgrader{
|
||||||
CheckOrigin: checkWebSocketOrigin,
|
CheckOrigin: checkWebSocketOrigin,
|
||||||
ReadBufferSize: 1024,
|
ReadBufferSize: 1024,
|
||||||
@@ -189,6 +195,10 @@ func NewManager(db *database.DB, cfg *config.Config, auth *authpkg.Auth, storage
|
|||||||
jobStatusUpdateMu: make(map[int64]*sync.Mutex),
|
jobStatusUpdateMu: make(map[int64]*sync.Mutex),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize rate limiters from config
|
||||||
|
apiRateLimiter = NewRateLimiter(cfg.APIRateLimitPerMinute(), time.Minute)
|
||||||
|
authRateLimiter = NewRateLimiter(cfg.AuthRateLimitPerMinute(), time.Minute)
|
||||||
|
|
||||||
// Check for required external tools
|
// Check for required external tools
|
||||||
if err := s.checkRequiredTools(); err != nil {
|
if err := s.checkRequiredTools(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -267,6 +277,7 @@ type RateLimiter struct {
|
|||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
limit int // max requests
|
limit int // max requests
|
||||||
window time.Duration // time window
|
window time.Duration // time window
|
||||||
|
stopChan chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRateLimiter creates a new rate limiter
|
// NewRateLimiter creates a new rate limiter
|
||||||
@@ -275,12 +286,17 @@ func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
|
|||||||
requests: make(map[string][]time.Time),
|
requests: make(map[string][]time.Time),
|
||||||
limit: limit,
|
limit: limit,
|
||||||
window: window,
|
window: window,
|
||||||
|
stopChan: make(chan struct{}),
|
||||||
}
|
}
|
||||||
// Start cleanup goroutine
|
|
||||||
go rl.cleanup()
|
go rl.cleanup()
|
||||||
return rl
|
return rl
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stop shuts down the cleanup goroutine.
|
||||||
|
func (rl *RateLimiter) Stop() {
|
||||||
|
close(rl.stopChan)
|
||||||
|
}
|
||||||
|
|
||||||
// Allow checks if a request from the given IP is allowed
|
// Allow checks if a request from the given IP is allowed
|
||||||
func (rl *RateLimiter) Allow(ip string) bool {
|
func (rl *RateLimiter) Allow(ip string) bool {
|
||||||
rl.mu.Lock()
|
rl.mu.Lock()
|
||||||
@@ -313,7 +329,11 @@ func (rl *RateLimiter) Allow(ip string) bool {
|
|||||||
// cleanup periodically removes old entries
|
// cleanup periodically removes old entries
|
||||||
func (rl *RateLimiter) cleanup() {
|
func (rl *RateLimiter) cleanup() {
|
||||||
ticker := time.NewTicker(5 * time.Minute)
|
ticker := time.NewTicker(5 * time.Minute)
|
||||||
for range ticker.C {
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
rl.mu.Lock()
|
rl.mu.Lock()
|
||||||
cutoff := time.Now().Add(-rl.window)
|
cutoff := time.Now().Add(-rl.window)
|
||||||
for ip, reqs := range rl.requests {
|
for ip, reqs := range rl.requests {
|
||||||
@@ -330,15 +350,16 @@ func (rl *RateLimiter) cleanup() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
rl.mu.Unlock()
|
rl.mu.Unlock()
|
||||||
|
case <-rl.stopChan:
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Global rate limiters for different endpoint types
|
// Rate limiters — initialized per Manager instance in NewManager.
|
||||||
var (
|
var (
|
||||||
// General API rate limiter: 100 requests per minute per IP
|
apiRateLimiter *RateLimiter
|
||||||
apiRateLimiter = NewRateLimiter(100, time.Minute)
|
authRateLimiter *RateLimiter
|
||||||
// Auth rate limiter: 10 requests per minute per IP (stricter for login attempts)
|
|
||||||
authRateLimiter = NewRateLimiter(10, time.Minute)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// rateLimitMiddleware applies rate limiting based on client IP
|
// rateLimitMiddleware applies rate limiting based on client IP
|
||||||
@@ -609,18 +630,24 @@ func (s *Manager) respondError(w http.ResponseWriter, status int, message string
|
|||||||
s.respondJSON(w, status, map[string]string{"error": message})
|
s.respondJSON(w, status, map[string]string{"error": message})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Manager) respondErrorWithCode(w http.ResponseWriter, status int, code, message string) {
|
||||||
|
s.respondJSON(w, status, map[string]string{
|
||||||
|
"error": message,
|
||||||
|
"code": code,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// createSessionCookie creates a secure session cookie with appropriate flags for the environment
|
// createSessionCookie creates a secure session cookie with appropriate flags for the environment
|
||||||
func createSessionCookie(sessionID string) *http.Cookie {
|
func (s *Manager) createSessionCookie(sessionID string) *http.Cookie {
|
||||||
cookie := &http.Cookie{
|
cookie := &http.Cookie{
|
||||||
Name: "session_id",
|
Name: "session_id",
|
||||||
Value: sessionID,
|
Value: sessionID,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
MaxAge: SessionCookieMaxAge,
|
MaxAge: s.sessionCookieMaxAge,
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
}
|
}
|
||||||
|
|
||||||
// In production mode, set Secure flag to require HTTPS
|
|
||||||
if authpkg.IsProductionMode() {
|
if authpkg.IsProductionMode() {
|
||||||
cookie.Secure = true
|
cookie.Secure = true
|
||||||
}
|
}
|
||||||
@@ -712,7 +739,7 @@ func (s *Manager) handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
sessionID := s.auth.CreateSession(session)
|
sessionID := s.auth.CreateSession(session)
|
||||||
http.SetCookie(w, createSessionCookie(sessionID))
|
http.SetCookie(w, s.createSessionCookie(sessionID))
|
||||||
|
|
||||||
http.Redirect(w, r, "/", http.StatusFound)
|
http.Redirect(w, r, "/", http.StatusFound)
|
||||||
}
|
}
|
||||||
@@ -745,7 +772,7 @@ func (s *Manager) handleDiscordCallback(w http.ResponseWriter, r *http.Request)
|
|||||||
}
|
}
|
||||||
|
|
||||||
sessionID := s.auth.CreateSession(session)
|
sessionID := s.auth.CreateSession(session)
|
||||||
http.SetCookie(w, createSessionCookie(sessionID))
|
http.SetCookie(w, s.createSessionCookie(sessionID))
|
||||||
|
|
||||||
http.Redirect(w, r, "/", http.StatusFound)
|
http.Redirect(w, r, "/", http.StatusFound)
|
||||||
}
|
}
|
||||||
@@ -838,7 +865,7 @@ func (s *Manager) handleLocalRegister(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
sessionID := s.auth.CreateSession(session)
|
sessionID := s.auth.CreateSession(session)
|
||||||
http.SetCookie(w, createSessionCookie(sessionID))
|
http.SetCookie(w, s.createSessionCookie(sessionID))
|
||||||
|
|
||||||
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
|
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
|
||||||
"message": "Registration successful",
|
"message": "Registration successful",
|
||||||
@@ -875,7 +902,7 @@ func (s *Manager) handleLocalLogin(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
sessionID := s.auth.CreateSession(session)
|
sessionID := s.auth.CreateSession(session)
|
||||||
http.SetCookie(w, createSessionCookie(sessionID))
|
http.SetCookie(w, s.createSessionCookie(sessionID))
|
||||||
|
|
||||||
s.respondJSON(w, http.StatusOK, map[string]interface{}{
|
s.respondJSON(w, http.StatusOK, map[string]interface{}{
|
||||||
"message": "Login successful",
|
"message": "Login successful",
|
||||||
|
|||||||
50
internal/manager/manager_test.go
Normal file
50
internal/manager/manager_test.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCheckWebSocketOrigin_DevelopmentAllowsOrigin(t *testing.T) {
|
||||||
|
t.Setenv("PRODUCTION", "false")
|
||||||
|
req := httptest.NewRequest("GET", "http://localhost/ws", nil)
|
||||||
|
req.Host = "localhost:8080"
|
||||||
|
req.Header.Set("Origin", "http://example.com")
|
||||||
|
if !checkWebSocketOrigin(req) {
|
||||||
|
t.Fatal("expected development mode to allow origin")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckWebSocketOrigin_ProductionSameHostAllowed(t *testing.T) {
|
||||||
|
t.Setenv("PRODUCTION", "true")
|
||||||
|
t.Setenv("ALLOWED_ORIGINS", "")
|
||||||
|
req := httptest.NewRequest("GET", "http://localhost/ws", nil)
|
||||||
|
req.Host = "localhost:8080"
|
||||||
|
req.Header.Set("Origin", "http://localhost:8080")
|
||||||
|
if !checkWebSocketOrigin(req) {
|
||||||
|
t.Fatal("expected same-host origin to be allowed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRespondErrorWithCode_IncludesCodeField(t *testing.T) {
|
||||||
|
s := &Manager{}
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
s.respondErrorWithCode(rr, http.StatusBadRequest, "UPLOAD_SESSION_EXPIRED", "Upload session expired.")
|
||||||
|
|
||||||
|
if rr.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("status = %d, want %d", rr.Code, http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
var payload map[string]string
|
||||||
|
if err := json.Unmarshal(rr.Body.Bytes(), &payload); err != nil {
|
||||||
|
t.Fatalf("failed to decode response: %v", err)
|
||||||
|
}
|
||||||
|
if payload["code"] != "UPLOAD_SESSION_EXPIRED" {
|
||||||
|
t.Fatalf("unexpected code: %q", payload["code"])
|
||||||
|
}
|
||||||
|
if payload["error"] == "" {
|
||||||
|
t.Fatal("expected non-empty error message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -19,6 +19,9 @@ import (
|
|||||||
"jiggablend/pkg/types"
|
"jiggablend/pkg/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var runMetadataCommand = executils.RunCommand
|
||||||
|
var resolveMetadataBlenderPath = resolveBlenderBinaryPath
|
||||||
|
|
||||||
// handleGetJobMetadata retrieves metadata for a job
|
// handleGetJobMetadata retrieves metadata for a job
|
||||||
func (s *Manager) handleGetJobMetadata(w http.ResponseWriter, r *http.Request) {
|
func (s *Manager) handleGetJobMetadata(w http.ResponseWriter, r *http.Request) {
|
||||||
userID, err := getUserID(r)
|
userID, err := getUserID(r)
|
||||||
@@ -141,16 +144,24 @@ func (s *Manager) extractMetadataFromContext(jobID int64) (*types.BlendMetadata,
|
|||||||
return nil, fmt.Errorf("failed to create extraction script: %w", err)
|
return nil, fmt.Errorf("failed to create extraction script: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make blend file path relative to tmpDir to avoid path resolution issues
|
// Use absolute paths to avoid path normalization issues with relative traversal.
|
||||||
blendFileRel, err := filepath.Rel(tmpDir, blendFile)
|
blendFileAbs, err := filepath.Abs(blendFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get relative path for blend file: %w", err)
|
return nil, fmt.Errorf("failed to get absolute path for blend file: %w", err)
|
||||||
|
}
|
||||||
|
scriptPathAbs, err := filepath.Abs(scriptPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get absolute path for extraction script: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute Blender with Python script using executils
|
// Execute Blender with Python script using executils
|
||||||
result, err := executils.RunCommand(
|
blenderBinary, err := resolveMetadataBlenderPath("blender")
|
||||||
"blender",
|
if err != nil {
|
||||||
[]string{"-b", blendFileRel, "--python", "extract_metadata.py"},
|
return nil, err
|
||||||
|
}
|
||||||
|
result, err := runMetadataCommand(
|
||||||
|
blenderBinary,
|
||||||
|
[]string{"-b", blendFileAbs, "--python", scriptPathAbs},
|
||||||
tmpDir,
|
tmpDir,
|
||||||
nil, // inherit environment
|
nil, // inherit environment
|
||||||
jobID,
|
jobID,
|
||||||
@@ -225,8 +236,17 @@ func (s *Manager) extractTar(tarPath, destDir string) error {
|
|||||||
return fmt.Errorf("failed to read tar header: %w", err)
|
return fmt.Errorf("failed to read tar header: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sanitize path to prevent directory traversal
|
// Sanitize path to prevent directory traversal. TAR stores "/" separators, so normalize first.
|
||||||
target := filepath.Join(destDir, header.Name)
|
normalizedHeaderPath := filepath.FromSlash(header.Name)
|
||||||
|
cleanHeaderPath := filepath.Clean(normalizedHeaderPath)
|
||||||
|
if cleanHeaderPath == "." {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if filepath.IsAbs(cleanHeaderPath) || strings.HasPrefix(cleanHeaderPath, ".."+string(os.PathSeparator)) || cleanHeaderPath == ".." {
|
||||||
|
log.Printf("ERROR: Invalid file path in TAR - header: %s", header.Name)
|
||||||
|
return fmt.Errorf("invalid file path in archive: %s", header.Name)
|
||||||
|
}
|
||||||
|
target := filepath.Join(destDir, cleanHeaderPath)
|
||||||
|
|
||||||
// Ensure target is within destDir
|
// Ensure target is within destDir
|
||||||
cleanTarget := filepath.Clean(target)
|
cleanTarget := filepath.Clean(target)
|
||||||
@@ -237,14 +257,14 @@ func (s *Manager) extractTar(tarPath, destDir string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create parent directories
|
// Create parent directories
|
||||||
if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil {
|
if err := os.MkdirAll(filepath.Dir(cleanTarget), 0755); err != nil {
|
||||||
return fmt.Errorf("failed to create directory: %w", err)
|
return fmt.Errorf("failed to create directory: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write file
|
// Write file
|
||||||
switch header.Typeflag {
|
switch header.Typeflag {
|
||||||
case tar.TypeReg:
|
case tar.TypeReg:
|
||||||
outFile, err := os.Create(target)
|
outFile, err := os.Create(cleanTarget)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create file: %w", err)
|
return fmt.Errorf("failed to create file: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
98
internal/manager/metadata_test.go
Normal file
98
internal/manager/metadata_test.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"archive/tar"
|
||||||
|
"bytes"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"jiggablend/internal/storage"
|
||||||
|
"jiggablend/pkg/executils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractTar_ExtractsRegularFile(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
tw := tar.NewWriter(&buf)
|
||||||
|
_ = tw.WriteHeader(&tar.Header{Name: "ctx/scene.blend", Mode: 0644, Size: 4, Typeflag: tar.TypeReg})
|
||||||
|
_, _ = tw.Write([]byte("data"))
|
||||||
|
_ = tw.Close()
|
||||||
|
|
||||||
|
tarPath := filepath.Join(t.TempDir(), "ctx.tar")
|
||||||
|
if err := os.WriteFile(tarPath, buf.Bytes(), 0644); err != nil {
|
||||||
|
t.Fatalf("write tar: %v", err)
|
||||||
|
}
|
||||||
|
dest := t.TempDir()
|
||||||
|
m := &Manager{}
|
||||||
|
if err := m.extractTar(tarPath, dest); err != nil {
|
||||||
|
t.Fatalf("extractTar failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(filepath.Join(dest, "ctx", "scene.blend")); err != nil {
|
||||||
|
t.Fatalf("expected extracted file: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractTar_RejectsTraversal(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
tw := tar.NewWriter(&buf)
|
||||||
|
_ = tw.WriteHeader(&tar.Header{Name: "../evil.txt", Mode: 0644, Size: 1, Typeflag: tar.TypeReg})
|
||||||
|
_, _ = tw.Write([]byte("x"))
|
||||||
|
_ = tw.Close()
|
||||||
|
|
||||||
|
tarPath := filepath.Join(t.TempDir(), "bad.tar")
|
||||||
|
if err := os.WriteFile(tarPath, buf.Bytes(), 0644); err != nil {
|
||||||
|
t.Fatalf("write tar: %v", err)
|
||||||
|
}
|
||||||
|
m := &Manager{}
|
||||||
|
if err := m.extractTar(tarPath, t.TempDir()); err == nil {
|
||||||
|
t.Fatal("expected path traversal error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractMetadataFromContext_UsesCommandSeam(t *testing.T) {
|
||||||
|
base := t.TempDir()
|
||||||
|
st, err := storage.NewStorage(base)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new storage: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
jobID := int64(42)
|
||||||
|
jobDir := st.JobPath(jobID)
|
||||||
|
if err := os.MkdirAll(jobDir, 0755); err != nil {
|
||||||
|
t.Fatalf("mkdir job dir: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
tw := tar.NewWriter(&buf)
|
||||||
|
_ = tw.WriteHeader(&tar.Header{Name: "scene.blend", Mode: 0644, Size: 4, Typeflag: tar.TypeReg})
|
||||||
|
_, _ = tw.Write([]byte("fake"))
|
||||||
|
_ = tw.Close()
|
||||||
|
if err := os.WriteFile(filepath.Join(jobDir, "context.tar"), buf.Bytes(), 0644); err != nil {
|
||||||
|
t.Fatalf("write context tar: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
origResolve := resolveMetadataBlenderPath
|
||||||
|
origRun := runMetadataCommand
|
||||||
|
resolveMetadataBlenderPath = func(_ string) (string, error) { return "/usr/bin/blender", nil }
|
||||||
|
runMetadataCommand = func(_ string, _ []string, _ string, _ []string, _ int64, _ *executils.ProcessTracker) (*executils.CommandResult, error) {
|
||||||
|
return &executils.CommandResult{
|
||||||
|
Stdout: `noise
|
||||||
|
{"frame_start":1,"frame_end":3,"has_negative_frames":false,"render_settings":{"resolution_x":1920,"resolution_y":1080,"frame_rate":24,"output_format":"PNG","engine":"CYCLES"},"scene_info":{"camera_count":1,"object_count":2,"material_count":3}}
|
||||||
|
done`,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
resolveMetadataBlenderPath = origResolve
|
||||||
|
runMetadataCommand = origRun
|
||||||
|
}()
|
||||||
|
|
||||||
|
m := &Manager{storage: st}
|
||||||
|
meta, err := m.extractMetadataFromContext(jobID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("extractMetadataFromContext failed: %v", err)
|
||||||
|
}
|
||||||
|
if meta.FrameStart != 1 || meta.FrameEnd != 3 {
|
||||||
|
t.Fatalf("unexpected metadata: %+v", *meta)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -3,6 +3,7 @@ package api
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
"html/template"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -92,13 +93,17 @@ func newUIRenderer() (*uiRenderer, error) {
|
|||||||
func (r *uiRenderer) render(w http.ResponseWriter, data pageData) {
|
func (r *uiRenderer) render(w http.ResponseWriter, data pageData) {
|
||||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
if err := r.templates.ExecuteTemplate(w, "base", data); err != nil {
|
if err := r.templates.ExecuteTemplate(w, "base", data); err != nil {
|
||||||
|
log.Printf("Template render error: %v", err)
|
||||||
http.Error(w, "template render error", http.StatusInternalServerError)
|
http.Error(w, "template render error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *uiRenderer) renderTemplate(w http.ResponseWriter, templateName string, data interface{}) {
|
func (r *uiRenderer) renderTemplate(w http.ResponseWriter, templateName string, data interface{}) {
|
||||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
if err := r.templates.ExecuteTemplate(w, templateName, data); err != nil {
|
if err := r.templates.ExecuteTemplate(w, templateName, data); err != nil {
|
||||||
|
log.Printf("Template render error for %s: %v", templateName, err)
|
||||||
http.Error(w, "template render error", http.StatusInternalServerError)
|
http.Error(w, "template render error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -765,7 +765,7 @@ func (s *Manager) handleDownloadJobContext(w http.ResponseWriter, r *http.Reques
|
|||||||
|
|
||||||
// Set appropriate headers for tar file
|
// Set appropriate headers for tar file
|
||||||
w.Header().Set("Content-Type", "application/x-tar")
|
w.Header().Set("Content-Type", "application/x-tar")
|
||||||
w.Header().Set("Content-Disposition", "attachment; filename=context.tar")
|
w.Header().Set("Content-Disposition", "attachment; filename=\"context.tar\"")
|
||||||
|
|
||||||
// Stream the file to the response
|
// Stream the file to the response
|
||||||
io.Copy(w, file)
|
io.Copy(w, file)
|
||||||
@@ -821,7 +821,7 @@ func (s *Manager) handleDownloadJobContextWithToken(w http.ResponseWriter, r *ht
|
|||||||
|
|
||||||
// Set appropriate headers for tar file
|
// Set appropriate headers for tar file
|
||||||
w.Header().Set("Content-Type", "application/x-tar")
|
w.Header().Set("Content-Type", "application/x-tar")
|
||||||
w.Header().Set("Content-Disposition", "attachment; filename=context.tar")
|
w.Header().Set("Content-Disposition", "attachment; filename=\"context.tar\"")
|
||||||
|
|
||||||
// Stream the file to the response
|
// Stream the file to the response
|
||||||
io.Copy(w, file)
|
io.Copy(w, file)
|
||||||
@@ -836,7 +836,7 @@ func (s *Manager) handleUploadFileFromRunner(w http.ResponseWriter, r *http.Requ
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = r.ParseMultipartForm(MaxUploadSize) // 50 GB (for large output files)
|
err = r.ParseMultipartForm(s.maxUploadSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Failed to parse multipart form: %v", err))
|
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Failed to parse multipart form: %v", err))
|
||||||
return
|
return
|
||||||
@@ -944,7 +944,7 @@ func (s *Manager) handleUploadFileWithToken(w http.ResponseWriter, r *http.Reque
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = r.ParseMultipartForm(MaxUploadSize) // 50 GB (for large output files)
|
err = r.ParseMultipartForm(s.maxUploadSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Failed to parse multipart form: %v", err))
|
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Failed to parse multipart form: %v", err))
|
||||||
return
|
return
|
||||||
@@ -1228,7 +1228,7 @@ func (s *Manager) handleDownloadFileForRunner(w http.ResponseWriter, r *http.Req
|
|||||||
|
|
||||||
// Set headers
|
// Set headers
|
||||||
w.Header().Set("Content-Type", contentType)
|
w.Header().Set("Content-Type", contentType)
|
||||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", decodedFileName))
|
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", decodedFileName))
|
||||||
|
|
||||||
// Stream file
|
// Stream file
|
||||||
io.Copy(w, file)
|
io.Copy(w, file)
|
||||||
@@ -1476,40 +1476,33 @@ func (s *Manager) handleRunnerJobWebSocket(w http.ResponseWriter, r *http.Reques
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "runner_heartbeat":
|
case "runner_heartbeat":
|
||||||
// Lookup runner ID from job's assigned_runner_id
|
s.handleWSRunnerHeartbeat(conn, jobID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleWSRunnerHeartbeat processes a runner heartbeat received over a job WebSocket.
|
||||||
|
func (s *Manager) handleWSRunnerHeartbeat(conn *websocket.Conn, jobID int64) {
|
||||||
var assignedRunnerID sql.NullInt64
|
var assignedRunnerID sql.NullInt64
|
||||||
err := s.db.With(func(db *sql.DB) error {
|
err := s.db.With(func(db *sql.DB) error {
|
||||||
return db.QueryRow(
|
return db.QueryRow(
|
||||||
"SELECT assigned_runner_id FROM jobs WHERE id = ?",
|
"SELECT assigned_runner_id FROM jobs WHERE id = ?", jobID,
|
||||||
jobID,
|
|
||||||
).Scan(&assignedRunnerID)
|
).Scan(&assignedRunnerID)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to lookup runner for job %d heartbeat: %v", jobID, err)
|
log.Printf("Failed to lookup runner for job %d heartbeat: %v", jobID, err)
|
||||||
// Send error response
|
s.sendWebSocketMessage(conn, map[string]interface{}{"type": "error", "message": "Failed to process heartbeat"})
|
||||||
response := map[string]interface{}{
|
return
|
||||||
"type": "error",
|
|
||||||
"message": "Failed to process heartbeat",
|
|
||||||
}
|
|
||||||
s.sendWebSocketMessage(conn, response)
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !assignedRunnerID.Valid {
|
if !assignedRunnerID.Valid {
|
||||||
log.Printf("Job %d has no assigned runner, skipping heartbeat update", jobID)
|
log.Printf("Job %d has no assigned runner, skipping heartbeat update", jobID)
|
||||||
// Send acknowledgment but no database update
|
s.sendWebSocketMessage(conn, map[string]interface{}{"type": "heartbeat_ack", "timestamp": time.Now().Unix(), "message": "No assigned runner for this job"})
|
||||||
response := map[string]interface{}{
|
return
|
||||||
"type": "heartbeat_ack",
|
|
||||||
"timestamp": time.Now().Unix(),
|
|
||||||
"message": "No assigned runner for this job",
|
|
||||||
}
|
|
||||||
s.sendWebSocketMessage(conn, response)
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
runnerID := assignedRunnerID.Int64
|
runnerID := assignedRunnerID.Int64
|
||||||
|
|
||||||
// Update runner heartbeat
|
|
||||||
err = s.db.With(func(db *sql.DB) error {
|
err = s.db.With(func(db *sql.DB) error {
|
||||||
_, err := db.Exec(
|
_, err := db.Exec(
|
||||||
"UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?",
|
"UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?",
|
||||||
@@ -1519,25 +1512,11 @@ func (s *Manager) handleRunnerJobWebSocket(w http.ResponseWriter, r *http.Reques
|
|||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to update runner %d heartbeat for job %d: %v", runnerID, jobID, err)
|
log.Printf("Failed to update runner %d heartbeat for job %d: %v", runnerID, jobID, err)
|
||||||
// Send error response
|
s.sendWebSocketMessage(conn, map[string]interface{}{"type": "error", "message": "Failed to update heartbeat"})
|
||||||
response := map[string]interface{}{
|
return
|
||||||
"type": "error",
|
|
||||||
"message": "Failed to update heartbeat",
|
|
||||||
}
|
|
||||||
s.sendWebSocketMessage(conn, response)
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send acknowledgment
|
s.sendWebSocketMessage(conn, map[string]interface{}{"type": "heartbeat_ack", "timestamp": time.Now().Unix()})
|
||||||
response := map[string]interface{}{
|
|
||||||
"type": "heartbeat_ack",
|
|
||||||
"timestamp": time.Now().Unix(),
|
|
||||||
}
|
|
||||||
s.sendWebSocketMessage(conn, response)
|
|
||||||
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleWebSocketLog handles log entries from WebSocket
|
// handleWebSocketLog handles log entries from WebSocket
|
||||||
@@ -1948,162 +1927,164 @@ func (s *Manager) cleanupJobStatusUpdateMutex(jobID int64) {
|
|||||||
// This function is serialized per jobID to prevent race conditions when multiple tasks
|
// This function is serialized per jobID to prevent race conditions when multiple tasks
|
||||||
// complete concurrently and trigger status updates simultaneously.
|
// complete concurrently and trigger status updates simultaneously.
|
||||||
func (s *Manager) updateJobStatusFromTasks(jobID int64) {
|
func (s *Manager) updateJobStatusFromTasks(jobID int64) {
|
||||||
// Serialize updates per job to prevent race conditions
|
|
||||||
mu := s.getJobStatusUpdateMutex(jobID)
|
mu := s.getJobStatusUpdateMutex(jobID)
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
defer mu.Unlock()
|
defer mu.Unlock()
|
||||||
|
|
||||||
now := time.Now()
|
currentStatus, err := s.getJobStatus(jobID)
|
||||||
|
|
||||||
// All jobs now use parallel runners (one task per frame), so we always use task-based progress
|
|
||||||
|
|
||||||
// Get current job status to detect changes
|
|
||||||
var currentStatus string
|
|
||||||
err := s.db.With(func(conn *sql.DB) error {
|
|
||||||
return conn.QueryRow(`SELECT status FROM jobs WHERE id = ?`, jobID).Scan(¤tStatus)
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to get current job status for job %d: %v", jobID, err)
|
log.Printf("Failed to get current job status for job %d: %v", jobID, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cancellation is terminal from the user's perspective.
|
|
||||||
// Do not allow asynchronous task updates to revive cancelled jobs.
|
|
||||||
if currentStatus == string(types.JobStatusCancelled) {
|
if currentStatus == string(types.JobStatusCancelled) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count total tasks and completed tasks
|
counts, err := s.getJobTaskCounts(jobID)
|
||||||
var totalTasks, completedTasks int
|
if err != nil {
|
||||||
err = s.db.With(func(conn *sql.DB) error {
|
log.Printf("Failed to count tasks for job %d: %v", jobID, err)
|
||||||
err := conn.QueryRow(
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
progress := counts.progress()
|
||||||
|
|
||||||
|
if counts.pendingOrRunning == 0 && counts.total > 0 {
|
||||||
|
s.handleAllTasksFinished(jobID, currentStatus, counts, progress)
|
||||||
|
} else {
|
||||||
|
s.handleTasksInProgress(jobID, currentStatus, counts, progress)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// jobTaskCounts holds task state counts for a job.
|
||||||
|
type jobTaskCounts struct {
|
||||||
|
total int
|
||||||
|
completed int
|
||||||
|
pendingOrRunning int
|
||||||
|
failed int
|
||||||
|
running int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *jobTaskCounts) progress() float64 {
|
||||||
|
if c.total == 0 {
|
||||||
|
return 0.0
|
||||||
|
}
|
||||||
|
return float64(c.completed) / float64(c.total) * 100.0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Manager) getJobStatus(jobID int64) (string, error) {
|
||||||
|
var status string
|
||||||
|
err := s.db.With(func(conn *sql.DB) error {
|
||||||
|
return conn.QueryRow(`SELECT status FROM jobs WHERE id = ?`, jobID).Scan(&status)
|
||||||
|
})
|
||||||
|
return status, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Manager) getJobTaskCounts(jobID int64) (*jobTaskCounts, error) {
|
||||||
|
c := &jobTaskCounts{}
|
||||||
|
err := s.db.With(func(conn *sql.DB) error {
|
||||||
|
if err := conn.QueryRow(
|
||||||
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status IN (?, ?, ?, ?)`,
|
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status IN (?, ?, ?, ?)`,
|
||||||
jobID, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted, types.TaskStatusFailed,
|
jobID, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted, types.TaskStatusFailed,
|
||||||
).Scan(&totalTasks)
|
).Scan(&c.total); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return conn.QueryRow(
|
if err := conn.QueryRow(
|
||||||
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
|
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
|
||||||
jobID, types.TaskStatusCompleted,
|
jobID, types.TaskStatusCompleted,
|
||||||
).Scan(&completedTasks)
|
).Scan(&c.completed); err != nil {
|
||||||
})
|
return err
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to count completed tasks for job %d: %v", jobID, err)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
if err := conn.QueryRow(
|
||||||
// Calculate progress
|
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status IN (?, ?)`,
|
||||||
var progress float64
|
|
||||||
if totalTasks == 0 {
|
|
||||||
// All tasks cancelled or no tasks, set progress to 0
|
|
||||||
progress = 0.0
|
|
||||||
} else {
|
|
||||||
// Standard task-based progress
|
|
||||||
progress = float64(completedTasks) / float64(totalTasks) * 100.0
|
|
||||||
}
|
|
||||||
|
|
||||||
var jobStatus string
|
|
||||||
|
|
||||||
// Check if all non-cancelled tasks are completed
|
|
||||||
var pendingOrRunningTasks int
|
|
||||||
err = s.db.With(func(conn *sql.DB) error {
|
|
||||||
return conn.QueryRow(
|
|
||||||
`SELECT COUNT(*) FROM tasks
|
|
||||||
WHERE job_id = ? AND status IN (?, ?)`,
|
|
||||||
jobID, types.TaskStatusPending, types.TaskStatusRunning,
|
jobID, types.TaskStatusPending, types.TaskStatusRunning,
|
||||||
).Scan(&pendingOrRunningTasks)
|
).Scan(&c.pendingOrRunning); err != nil {
|
||||||
})
|
return err
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to count pending/running tasks for job %d: %v", jobID, err)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
if err := conn.QueryRow(
|
||||||
if pendingOrRunningTasks == 0 && totalTasks > 0 {
|
|
||||||
// All tasks are either completed or failed/cancelled
|
|
||||||
// Check if any tasks failed
|
|
||||||
var failedTasks int
|
|
||||||
s.db.With(func(conn *sql.DB) error {
|
|
||||||
conn.QueryRow(
|
|
||||||
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
|
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
|
||||||
jobID, types.TaskStatusFailed,
|
jobID, types.TaskStatusFailed,
|
||||||
).Scan(&failedTasks)
|
).Scan(&c.failed); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := conn.QueryRow(
|
||||||
|
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
|
||||||
|
jobID, types.TaskStatusRunning,
|
||||||
|
).Scan(&c.running); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
return c, err
|
||||||
|
}
|
||||||
|
|
||||||
if failedTasks > 0 {
|
// handleAllTasksFinished handles the case where no pending/running tasks remain.
|
||||||
// Some tasks failed - check if job has retries left
|
func (s *Manager) handleAllTasksFinished(jobID int64, currentStatus string, counts *jobTaskCounts, progress float64) {
|
||||||
|
now := time.Now()
|
||||||
|
var jobStatus string
|
||||||
|
|
||||||
|
if counts.failed > 0 {
|
||||||
|
jobStatus = s.handleFailedTasks(jobID, currentStatus, &progress)
|
||||||
|
if jobStatus == "" {
|
||||||
|
return // retry handled; early exit
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
jobStatus = string(types.JobStatusCompleted)
|
||||||
|
progress = 100.0
|
||||||
|
}
|
||||||
|
|
||||||
|
s.setJobFinalStatus(jobID, currentStatus, jobStatus, progress, now, counts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleFailedTasks decides whether to retry or mark the job failed.
|
||||||
|
// Returns "" if a retry was triggered (caller should return early),
|
||||||
|
// or the final status string.
|
||||||
|
func (s *Manager) handleFailedTasks(jobID int64, currentStatus string, progress *float64) string {
|
||||||
var retryCount, maxRetries int
|
var retryCount, maxRetries int
|
||||||
err := s.db.With(func(conn *sql.DB) error {
|
err := s.db.With(func(conn *sql.DB) error {
|
||||||
return conn.QueryRow(
|
return conn.QueryRow(
|
||||||
`SELECT retry_count, max_retries FROM jobs WHERE id = ?`,
|
`SELECT retry_count, max_retries FROM jobs WHERE id = ?`, jobID,
|
||||||
jobID,
|
|
||||||
).Scan(&retryCount, &maxRetries)
|
).Scan(&retryCount, &maxRetries)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to get retry info for job %d: %v", jobID, err)
|
log.Printf("Failed to get retry info for job %d: %v", jobID, err)
|
||||||
// Fall back to marking job as failed
|
return string(types.JobStatusFailed)
|
||||||
jobStatus = string(types.JobStatusFailed)
|
}
|
||||||
} else if retryCount < maxRetries {
|
|
||||||
// Job has retries left - reset failed tasks and redistribute
|
if retryCount < maxRetries {
|
||||||
if err := s.resetFailedTasksAndRedistribute(jobID); err != nil {
|
if err := s.resetFailedTasksAndRedistribute(jobID); err != nil {
|
||||||
log.Printf("Failed to reset failed tasks for job %d: %v", jobID, err)
|
log.Printf("Failed to reset failed tasks for job %d: %v", jobID, err)
|
||||||
// If reset fails, mark job as failed
|
return string(types.JobStatusFailed)
|
||||||
jobStatus = string(types.JobStatusFailed)
|
|
||||||
} else {
|
|
||||||
// Tasks reset successfully - job remains in running/pending state
|
|
||||||
// Don't update job status, just update progress
|
|
||||||
jobStatus = currentStatus // Keep current status
|
|
||||||
// Recalculate progress after reset (failed tasks are now pending again)
|
|
||||||
var newTotalTasks, newCompletedTasks int
|
|
||||||
s.db.With(func(conn *sql.DB) error {
|
|
||||||
conn.QueryRow(
|
|
||||||
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status IN (?, ?, ?, ?)`,
|
|
||||||
jobID, types.TaskStatusPending, types.TaskStatusRunning, types.TaskStatusCompleted, types.TaskStatusFailed,
|
|
||||||
).Scan(&newTotalTasks)
|
|
||||||
conn.QueryRow(
|
|
||||||
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
|
|
||||||
jobID, types.TaskStatusCompleted,
|
|
||||||
).Scan(&newCompletedTasks)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if newTotalTasks > 0 {
|
|
||||||
progress = float64(newCompletedTasks) / float64(newTotalTasks) * 100.0
|
|
||||||
}
|
}
|
||||||
// Update progress only
|
// Recalculate progress after reset
|
||||||
err := s.db.With(func(conn *sql.DB) error {
|
counts, err := s.getJobTaskCounts(jobID)
|
||||||
_, err := conn.Exec(
|
if err == nil && counts.total > 0 {
|
||||||
`UPDATE jobs SET progress = ? WHERE id = ?`,
|
*progress = counts.progress()
|
||||||
progress, jobID,
|
}
|
||||||
)
|
err = s.db.With(func(conn *sql.DB) error {
|
||||||
|
_, err := conn.Exec(`UPDATE jobs SET progress = ? WHERE id = ?`, *progress, jobID)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to update job %d progress: %v", jobID, err)
|
log.Printf("Failed to update job %d progress: %v", jobID, err)
|
||||||
} else {
|
} else {
|
||||||
// Broadcast job update via WebSocket
|
|
||||||
s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{
|
s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{
|
||||||
"status": jobStatus,
|
"status": currentStatus,
|
||||||
"progress": progress,
|
"progress": *progress,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return // Exit early since we've handled the retry
|
return "" // retry handled
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// No retries left - mark job as failed and cancel active tasks
|
// No retries left
|
||||||
jobStatus = string(types.JobStatusFailed)
|
|
||||||
if err := s.cancelActiveTasksForJob(jobID); err != nil {
|
if err := s.cancelActiveTasksForJob(jobID); err != nil {
|
||||||
log.Printf("Failed to cancel active tasks for job %d: %v", jobID, err)
|
log.Printf("Failed to cancel active tasks for job %d: %v", jobID, err)
|
||||||
}
|
}
|
||||||
}
|
return string(types.JobStatusFailed)
|
||||||
} else {
|
|
||||||
// All tasks completed successfully
|
|
||||||
jobStatus = string(types.JobStatusCompleted)
|
|
||||||
progress = 100.0 // Ensure progress is 100% when all tasks complete
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update job status (if we didn't return early from retry logic)
|
// setJobFinalStatus persists the terminal job status and broadcasts the update.
|
||||||
if jobStatus != "" {
|
func (s *Manager) setJobFinalStatus(jobID int64, currentStatus, jobStatus string, progress float64, now time.Time, counts *jobTaskCounts) {
|
||||||
err := s.db.With(func(conn *sql.DB) error {
|
err := s.db.With(func(conn *sql.DB) error {
|
||||||
_, err := conn.Exec(
|
_, err := conn.Exec(
|
||||||
`UPDATE jobs SET status = ?, progress = ?, completed_at = ? WHERE id = ?`,
|
`UPDATE jobs SET status = ?, progress = ?, completed_at = ? WHERE id = ?`,
|
||||||
@@ -2113,44 +2094,30 @@ func (s *Manager) updateJobStatusFromTasks(jobID int64) {
|
|||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to update job %d status to %s: %v", jobID, jobStatus, err)
|
log.Printf("Failed to update job %d status to %s: %v", jobID, jobStatus, err)
|
||||||
} else {
|
return
|
||||||
// Only log if status actually changed
|
}
|
||||||
if currentStatus != jobStatus {
|
if currentStatus != jobStatus {
|
||||||
log.Printf("Updated job %d status from %s to %s (progress: %.1f%%, completed tasks: %d/%d)", jobID, currentStatus, jobStatus, progress, completedTasks, totalTasks)
|
log.Printf("Updated job %d status from %s to %s (progress: %.1f%%, completed tasks: %d/%d)", jobID, currentStatus, jobStatus, progress, counts.completed, counts.total)
|
||||||
}
|
}
|
||||||
// Broadcast job update via WebSocket
|
|
||||||
s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{
|
s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{
|
||||||
"status": jobStatus,
|
"status": jobStatus,
|
||||||
"progress": progress,
|
"progress": progress,
|
||||||
"completed_at": now,
|
"completed_at": now,
|
||||||
})
|
})
|
||||||
// Clean up mutex for jobs in final states (completed or failed)
|
|
||||||
// No more status updates will occur for these jobs
|
|
||||||
if jobStatus == string(types.JobStatusCompleted) || jobStatus == string(types.JobStatusFailed) {
|
if jobStatus == string(types.JobStatusCompleted) || jobStatus == string(types.JobStatusFailed) {
|
||||||
s.cleanupJobStatusUpdateMutex(jobID)
|
s.cleanupJobStatusUpdateMutex(jobID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Encode tasks are now created immediately when the job is created
|
// handleTasksInProgress handles the case where tasks are still pending or running.
|
||||||
// with a condition that prevents assignment until all render tasks are completed.
|
func (s *Manager) handleTasksInProgress(jobID int64, currentStatus string, counts *jobTaskCounts, progress float64) {
|
||||||
// No need to create them here anymore.
|
now := time.Now()
|
||||||
} else {
|
var jobStatus string
|
||||||
// Job has pending or running tasks - determine if it's running or still pending
|
|
||||||
var runningTasks int
|
|
||||||
s.db.With(func(conn *sql.DB) error {
|
|
||||||
conn.QueryRow(
|
|
||||||
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
|
|
||||||
jobID, types.TaskStatusRunning,
|
|
||||||
).Scan(&runningTasks)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if runningTasks > 0 {
|
if counts.running > 0 {
|
||||||
// Has running tasks - job is running
|
|
||||||
jobStatus = string(types.JobStatusRunning)
|
jobStatus = string(types.JobStatusRunning)
|
||||||
var startedAt sql.NullTime
|
|
||||||
s.db.With(func(conn *sql.DB) error {
|
s.db.With(func(conn *sql.DB) error {
|
||||||
|
var startedAt sql.NullTime
|
||||||
conn.QueryRow(`SELECT started_at FROM jobs WHERE id = ?`, jobID).Scan(&startedAt)
|
conn.QueryRow(`SELECT started_at FROM jobs WHERE id = ?`, jobID).Scan(&startedAt)
|
||||||
if !startedAt.Valid {
|
if !startedAt.Valid {
|
||||||
conn.Exec(`UPDATE jobs SET started_at = ? WHERE id = ?`, now, jobID)
|
conn.Exec(`UPDATE jobs SET started_at = ? WHERE id = ?`, now, jobID)
|
||||||
@@ -2158,7 +2125,6 @@ func (s *Manager) updateJobStatusFromTasks(jobID int64) {
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
// All tasks are pending - job is pending
|
|
||||||
jobStatus = string(types.JobStatusPending)
|
jobStatus = string(types.JobStatusPending)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2171,19 +2137,17 @@ func (s *Manager) updateJobStatusFromTasks(jobID int64) {
|
|||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to update job %d status to %s: %v", jobID, jobStatus, err)
|
log.Printf("Failed to update job %d status to %s: %v", jobID, jobStatus, err)
|
||||||
} else {
|
return
|
||||||
// Only log if status actually changed
|
}
|
||||||
if currentStatus != jobStatus {
|
if currentStatus != jobStatus {
|
||||||
log.Printf("Updated job %d status from %s to %s (progress: %.1f%%, completed: %d/%d, pending: %d, running: %d)", jobID, currentStatus, jobStatus, progress, completedTasks, totalTasks, pendingOrRunningTasks-runningTasks, runningTasks)
|
pending := counts.pendingOrRunning - counts.running
|
||||||
|
log.Printf("Updated job %d status from %s to %s (progress: %.1f%%, completed: %d/%d, pending: %d, running: %d)", jobID, currentStatus, jobStatus, progress, counts.completed, counts.total, pending, counts.running)
|
||||||
}
|
}
|
||||||
// Broadcast job update during execution (not just on completion)
|
|
||||||
s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{
|
s.broadcastJobUpdate(jobID, "job_update", map[string]interface{}{
|
||||||
"status": jobStatus,
|
"status": jobStatus,
|
||||||
"progress": progress,
|
"progress": progress,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// broadcastLogToFrontend broadcasts log to connected frontend clients
|
// broadcastLogToFrontend broadcasts log to connected frontend clients
|
||||||
func (s *Manager) broadcastLogToFrontend(taskID int64, logEntry WSLogEntry) {
|
func (s *Manager) broadcastLogToFrontend(taskID int64, logEntry WSLogEntry) {
|
||||||
|
|||||||
21
internal/manager/runners_test.go
Normal file
21
internal/manager/runners_test.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestParseBlenderFrame(t *testing.T) {
|
||||||
|
frame, ok := parseBlenderFrame("Info Fra:2470 Mem:12.00M")
|
||||||
|
if !ok || frame != 2470 {
|
||||||
|
t.Fatalf("parseBlenderFrame() = (%d,%v), want (2470,true)", frame, ok)
|
||||||
|
}
|
||||||
|
if _, ok := parseBlenderFrame("no frame here"); ok {
|
||||||
|
t.Fatal("expected parse to fail for non-frame text")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJobTaskCounts_Progress(t *testing.T) {
|
||||||
|
c := &jobTaskCounts{total: 10, completed: 4}
|
||||||
|
if got := c.progress(); got != 40 {
|
||||||
|
t.Fatalf("progress() = %v, want 40", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
44
internal/runner/api/jobconn_test.go
Normal file
44
internal/runner/api/jobconn_test.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestJobConnection_ConnectAndClose(t *testing.T) {
|
||||||
|
upgrader := websocket.Upgrader{}
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
var msg map[string]interface{}
|
||||||
|
if err := conn.ReadJSON(&msg); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if msg["type"] == "auth" {
|
||||||
|
_ = conn.WriteJSON(map[string]string{"type": "auth_ok"})
|
||||||
|
}
|
||||||
|
// Keep open briefly so client can mark connected.
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
jc := NewJobConnection()
|
||||||
|
managerURL := strings.Replace(server.URL, "http://", "http://", 1)
|
||||||
|
if err := jc.Connect(managerURL, "/job/1", "token123"); err != nil {
|
||||||
|
t.Fatalf("Connect failed: %v", err)
|
||||||
|
}
|
||||||
|
if !jc.IsConnected() {
|
||||||
|
t.Fatal("expected connection to be marked connected")
|
||||||
|
}
|
||||||
|
jc.Close()
|
||||||
|
}
|
||||||
|
|
||||||
@@ -241,8 +241,8 @@ func (m *ManagerClient) DownloadContext(contextPath, jobToken string) (io.ReadCl
|
|||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
defer resp.Body.Close()
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
resp.Body.Close()
|
|
||||||
return nil, fmt.Errorf("context download failed with status %d: %s", resp.StatusCode, string(body))
|
return nil, fmt.Errorf("context download failed with status %d: %s", resp.StatusCode, string(body))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -435,8 +435,8 @@ func (m *ManagerClient) DownloadBlender(version string) (io.ReadCloser, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
defer resp.Body.Close()
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
resp.Body.Close()
|
|
||||||
return nil, fmt.Errorf("failed to download blender: status %d, body: %s", resp.StatusCode, string(body))
|
return nil, fmt.Errorf("failed to download blender: status %d, body: %s", resp.StatusCode, string(body))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
45
internal/runner/api/manager_test.go
Normal file
45
internal/runner/api/manager_test.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewManagerClient_TrimsTrailingSlash(t *testing.T) {
|
||||||
|
c := NewManagerClient("http://example.com/")
|
||||||
|
if c.GetBaseURL() != "http://example.com" {
|
||||||
|
t.Fatalf("unexpected base url: %q", c.GetBaseURL())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDoRequest_SetsAuthorizationHeader(t *testing.T) {
|
||||||
|
var authHeader string
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
authHeader = r.Header.Get("Authorization")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]bool{"ok": true})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewManagerClient(ts.URL)
|
||||||
|
c.SetCredentials(1, "abc123")
|
||||||
|
|
||||||
|
resp, err := c.Request(http.MethodGet, "/x", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if authHeader != "Bearer abc123" {
|
||||||
|
t.Fatalf("unexpected Authorization header: %q", authHeader)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequest_RequiresAuth(t *testing.T) {
|
||||||
|
c := NewManagerClient("http://example.com")
|
||||||
|
if _, err := c.Request(http.MethodGet, "/x", nil); err == nil {
|
||||||
|
t.Fatal("expected auth error when api key is missing")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -44,8 +45,12 @@ func (m *Manager) GetBinaryPath(version string) (string, error) {
|
|||||||
if binaryInfo, err := os.Stat(binaryPath); err == nil {
|
if binaryInfo, err := os.Stat(binaryPath); err == nil {
|
||||||
// Verify it's actually a file (not a directory)
|
// Verify it's actually a file (not a directory)
|
||||||
if !binaryInfo.IsDir() {
|
if !binaryInfo.IsDir() {
|
||||||
log.Printf("Found existing Blender %s installation at %s", version, binaryPath)
|
absBinaryPath, err := ResolveBinaryPath(binaryPath)
|
||||||
return binaryPath, nil
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
log.Printf("Found existing Blender %s installation at %s", version, absBinaryPath)
|
||||||
|
return absBinaryPath, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Version folder exists but binary is missing - might be incomplete installation
|
// Version folder exists but binary is missing - might be incomplete installation
|
||||||
@@ -72,20 +77,50 @@ func (m *Manager) GetBinaryPath(version string) (string, error) {
|
|||||||
return "", fmt.Errorf("blender binary not found after extraction")
|
return "", fmt.Errorf("blender binary not found after extraction")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Blender %s installed at %s", version, binaryPath)
|
absBinaryPath, err := ResolveBinaryPath(binaryPath)
|
||||||
return binaryPath, nil
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Blender %s installed at %s", version, absBinaryPath)
|
||||||
|
return absBinaryPath, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetBinaryForJob returns the Blender binary path for a job.
|
// GetBinaryForJob returns the Blender binary path for a job.
|
||||||
// Uses the version from metadata or falls back to system blender.
|
// Uses the version from metadata or falls back to system blender.
|
||||||
func (m *Manager) GetBinaryForJob(version string) (string, error) {
|
func (m *Manager) GetBinaryForJob(version string) (string, error) {
|
||||||
if version == "" {
|
if version == "" {
|
||||||
return "blender", nil // System blender
|
return ResolveBinaryPath("blender")
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.GetBinaryPath(version)
|
return m.GetBinaryPath(version)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ResolveBinaryPath resolves a Blender executable to an absolute path.
|
||||||
|
func ResolveBinaryPath(blenderBinary string) (string, error) {
|
||||||
|
if blenderBinary == "" {
|
||||||
|
return "", fmt.Errorf("blender binary path is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(blenderBinary, string(filepath.Separator)) {
|
||||||
|
absPath, err := filepath.Abs(blenderBinary)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to resolve blender binary path %q: %w", blenderBinary, err)
|
||||||
|
}
|
||||||
|
return absPath, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
resolvedPath, err := exec.LookPath(blenderBinary)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to locate blender binary %q in PATH: %w", blenderBinary, err)
|
||||||
|
}
|
||||||
|
absPath, err := filepath.Abs(resolvedPath)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to resolve blender binary path %q: %w", resolvedPath, err)
|
||||||
|
}
|
||||||
|
return absPath, nil
|
||||||
|
}
|
||||||
|
|
||||||
// TarballEnv returns a copy of baseEnv with LD_LIBRARY_PATH set so that a
|
// TarballEnv returns a copy of baseEnv with LD_LIBRARY_PATH set so that a
|
||||||
// tarball Blender installation can find its bundled libs (e.g. lib/python3.x).
|
// tarball Blender installation can find its bundled libs (e.g. lib/python3.x).
|
||||||
// If blenderBinary is the system "blender" or has no path component, baseEnv is
|
// If blenderBinary is the system "blender" or has no path component, baseEnv is
|
||||||
|
|||||||
34
internal/runner/blender/binary_test.go
Normal file
34
internal/runner/blender/binary_test.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
package blender
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestResolveBinaryPath_AbsoluteLikePath(t *testing.T) {
|
||||||
|
got, err := ResolveBinaryPath("./blender")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ResolveBinaryPath failed: %v", err)
|
||||||
|
}
|
||||||
|
if !filepath.IsAbs(got) {
|
||||||
|
t.Fatalf("expected absolute path, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveBinaryPath_Empty(t *testing.T) {
|
||||||
|
if _, err := ResolveBinaryPath(""); err == nil {
|
||||||
|
t.Fatal("expected error for empty blender binary")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTarballEnv_SetsAndExtendsLDLibraryPath(t *testing.T) {
|
||||||
|
bin := filepath.Join(string(os.PathSeparator), "tmp", "blender", "blender")
|
||||||
|
got := TarballEnv(bin, []string{"A=B", "LD_LIBRARY_PATH=/old"})
|
||||||
|
joined := strings.Join(got, "\n")
|
||||||
|
if !strings.Contains(joined, "LD_LIBRARY_PATH=/tmp/blender/lib:/old") {
|
||||||
|
t.Fatalf("expected LD_LIBRARY_PATH to include blender lib, got %v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -1,45 +1,116 @@
|
|||||||
// Package blender: GPU backend detection for HIP vs NVIDIA.
|
// Package blender: host GPU backend detection for AMD/NVIDIA/Intel.
|
||||||
package blender
|
package blender
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"jiggablend/pkg/scripts"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// DetectGPUBackends runs a minimal Blender script to detect whether HIP (AMD) and/or
|
// DetectGPUBackends detects whether AMD, NVIDIA, and/or Intel GPUs are available
|
||||||
// NVIDIA (CUDA/OptiX) devices are available. Use this to decide whether to force CPU
|
// using host-level hardware probing only.
|
||||||
// for Blender < 4.x (only force when HIP is present, since HIP has no official support pre-4).
|
func DetectGPUBackends() (hasAMD, hasNVIDIA, hasIntel bool, ok bool) {
|
||||||
func DetectGPUBackends(blenderBinary, scriptDir string) (hasHIP, hasNVIDIA bool, err error) {
|
return detectGPUBackendsFromHost()
|
||||||
scriptPath := filepath.Join(scriptDir, "detect_gpu_backends.py")
|
|
||||||
if err := os.WriteFile(scriptPath, []byte(scripts.DetectGPUBackends), 0644); err != nil {
|
|
||||||
return false, false, fmt.Errorf("write detection script: %w", err)
|
|
||||||
}
|
}
|
||||||
defer os.Remove(scriptPath)
|
|
||||||
|
|
||||||
env := TarballEnv(blenderBinary, os.Environ())
|
func detectGPUBackendsFromHost() (hasAMD, hasNVIDIA, hasIntel bool, ok bool) {
|
||||||
cmd := exec.Command(blenderBinary, "-b", "--python", scriptPath)
|
if amd, nvidia, intel, found := detectGPUBackendsFromDRM(); found {
|
||||||
cmd.Env = env
|
return amd, nvidia, intel, true
|
||||||
cmd.Dir = scriptDir
|
}
|
||||||
out, err := cmd.CombinedOutput()
|
if amd, nvidia, intel, found := detectGPUBackendsFromLSPCI(); found {
|
||||||
|
return amd, nvidia, intel, true
|
||||||
|
}
|
||||||
|
return false, false, false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func detectGPUBackendsFromDRM() (hasAMD, hasNVIDIA, hasIntel bool, ok bool) {
|
||||||
|
entries, err := os.ReadDir("/sys/class/drm")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, false, fmt.Errorf("run blender detection: %w (output: %s)", err, string(out))
|
return false, false, false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, entry := range entries {
|
||||||
|
name := entry.Name()
|
||||||
|
if !isDRMCardNode(name) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
vendorPath := filepath.Join("/sys/class/drm", name, "device", "vendor")
|
||||||
|
vendorRaw, err := os.ReadFile(vendorPath)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
vendor := strings.TrimSpace(strings.ToLower(string(vendorRaw)))
|
||||||
|
switch vendor {
|
||||||
|
case "0x1002":
|
||||||
|
hasAMD = true
|
||||||
|
ok = true
|
||||||
|
case "0x10de":
|
||||||
|
hasNVIDIA = true
|
||||||
|
ok = true
|
||||||
|
case "0x8086":
|
||||||
|
hasIntel = true
|
||||||
|
ok = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return hasAMD, hasNVIDIA, hasIntel, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func isDRMCardNode(name string) bool {
|
||||||
|
if !strings.HasPrefix(name, "card") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if strings.Contains(name, "-") {
|
||||||
|
// Connector entries like card0-DP-1 are not GPU device nodes.
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(name) <= len("card") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, err := strconv.Atoi(strings.TrimPrefix(name, "card"))
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func detectGPUBackendsFromLSPCI() (hasAMD, hasNVIDIA, hasIntel bool, ok bool) {
|
||||||
|
if _, err := exec.LookPath("lspci"); err != nil {
|
||||||
|
return false, false, false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := exec.Command("lspci").CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return false, false, false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
scanner := bufio.NewScanner(strings.NewReader(string(out)))
|
scanner := bufio.NewScanner(strings.NewReader(string(out)))
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := strings.TrimSpace(scanner.Text())
|
line := strings.ToLower(strings.TrimSpace(scanner.Text()))
|
||||||
switch line {
|
if !isGPUControllerLine(line) {
|
||||||
case "HAS_HIP":
|
continue
|
||||||
hasHIP = true
|
}
|
||||||
case "HAS_NVIDIA":
|
|
||||||
|
if strings.Contains(line, "nvidia") {
|
||||||
hasNVIDIA = true
|
hasNVIDIA = true
|
||||||
|
ok = true
|
||||||
|
}
|
||||||
|
if strings.Contains(line, "amd") || strings.Contains(line, "ati") || strings.Contains(line, "radeon") {
|
||||||
|
hasAMD = true
|
||||||
|
ok = true
|
||||||
|
}
|
||||||
|
if strings.Contains(line, "intel") {
|
||||||
|
hasIntel = true
|
||||||
|
ok = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return hasHIP, hasNVIDIA, scanner.Err()
|
|
||||||
|
return hasAMD, hasNVIDIA, hasIntel, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func isGPUControllerLine(line string) bool {
|
||||||
|
return strings.Contains(line, "vga compatible controller") ||
|
||||||
|
strings.Contains(line, "3d controller") ||
|
||||||
|
strings.Contains(line, "display controller")
|
||||||
}
|
}
|
||||||
|
|||||||
32
internal/runner/blender/detect_test.go
Normal file
32
internal/runner/blender/detect_test.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package blender
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestIsDRMCardNode(t *testing.T) {
|
||||||
|
tests := map[string]bool{
|
||||||
|
"card0": true,
|
||||||
|
"card12": true,
|
||||||
|
"card": false,
|
||||||
|
"card0-DP-1": false,
|
||||||
|
"renderD128": false,
|
||||||
|
"foo": false,
|
||||||
|
}
|
||||||
|
for in, want := range tests {
|
||||||
|
if got := isDRMCardNode(in); got != want {
|
||||||
|
t.Fatalf("isDRMCardNode(%q) = %v, want %v", in, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsGPUControllerLine(t *testing.T) {
|
||||||
|
if !isGPUControllerLine("vga compatible controller: nvidia corp") {
|
||||||
|
t.Fatal("expected VGA controller line to match")
|
||||||
|
}
|
||||||
|
if !isGPUControllerLine("3d controller: amd") {
|
||||||
|
t.Fatal("expected 3d controller line to match")
|
||||||
|
}
|
||||||
|
if isGPUControllerLine("audio device: something") {
|
||||||
|
t.Fatal("audio line should not match")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
34
internal/runner/blender/logfilter_test.go
Normal file
34
internal/runner/blender/logfilter_test.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
package blender
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"jiggablend/pkg/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFilterLog_FiltersNoise(t *testing.T) {
|
||||||
|
cases := []string{
|
||||||
|
"",
|
||||||
|
"--------------------------------------------------------------------",
|
||||||
|
"Failed to add relation foo",
|
||||||
|
"BKE_modifier_set_error",
|
||||||
|
"Depth Type Name",
|
||||||
|
}
|
||||||
|
for _, in := range cases {
|
||||||
|
filtered, level := FilterLog(in)
|
||||||
|
if !filtered {
|
||||||
|
t.Fatalf("expected filtered for %q", in)
|
||||||
|
}
|
||||||
|
if level != types.LogLevelInfo {
|
||||||
|
t.Fatalf("unexpected level for %q: %s", in, level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterLog_KeepsNormalLine(t *testing.T) {
|
||||||
|
filtered, _ := FilterLog("Rendering done.")
|
||||||
|
if filtered {
|
||||||
|
t.Fatal("normal line should not be filtered")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -1,143 +1,19 @@
|
|||||||
package blender
|
package blender
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"compress/gzip"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"os"
|
"jiggablend/pkg/blendfile"
|
||||||
"os/exec"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ParseVersionFromFile parses the Blender version that a .blend file was saved with.
|
// ParseVersionFromFile parses the Blender version that a .blend file was saved with.
|
||||||
// Returns major and minor version numbers.
|
// Returns major and minor version numbers.
|
||||||
|
// Delegates to the shared pkg/blendfile implementation.
|
||||||
func ParseVersionFromFile(blendPath string) (major, minor int, err error) {
|
func ParseVersionFromFile(blendPath string) (major, minor int, err error) {
|
||||||
file, err := os.Open(blendPath)
|
return blendfile.ParseVersionFromFile(blendPath)
|
||||||
if err != nil {
|
|
||||||
return 0, 0, fmt.Errorf("failed to open blend file: %w", err)
|
|
||||||
}
|
|
||||||
defer file.Close()
|
|
||||||
|
|
||||||
// Read the first 12 bytes of the blend file header
|
|
||||||
// Format: BLENDER-v<major><minor><patch> or BLENDER_v<major><minor><patch>
|
|
||||||
// The header is: "BLENDER" (7 bytes) + pointer size (1 byte: '-' for 64-bit, '_' for 32-bit)
|
|
||||||
// + endianness (1 byte: 'v' for little-endian, 'V' for big-endian)
|
|
||||||
// + version (3 bytes: e.g., "402" for 4.02)
|
|
||||||
header := make([]byte, 12)
|
|
||||||
n, err := file.Read(header)
|
|
||||||
if err != nil || n < 12 {
|
|
||||||
return 0, 0, fmt.Errorf("failed to read blend file header: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for BLENDER magic
|
|
||||||
if string(header[:7]) != "BLENDER" {
|
|
||||||
// Might be compressed - try to decompress
|
|
||||||
file.Seek(0, 0)
|
|
||||||
return parseCompressedVersion(file)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse version from bytes 9-11 (3 digits)
|
|
||||||
versionStr := string(header[9:12])
|
|
||||||
|
|
||||||
// Version format changed in Blender 3.0
|
|
||||||
// Pre-3.0: "279" = 2.79, "280" = 2.80
|
|
||||||
// 3.0+: "300" = 3.0, "402" = 4.02, "410" = 4.10
|
|
||||||
if len(versionStr) == 3 {
|
|
||||||
// First digit is major version
|
|
||||||
fmt.Sscanf(string(versionStr[0]), "%d", &major)
|
|
||||||
// Next two digits are minor version
|
|
||||||
fmt.Sscanf(versionStr[1:3], "%d", &minor)
|
|
||||||
}
|
|
||||||
|
|
||||||
return major, minor, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseCompressedVersion handles gzip and zstd compressed blend files.
|
|
||||||
func parseCompressedVersion(file *os.File) (major, minor int, err error) {
|
|
||||||
magic := make([]byte, 4)
|
|
||||||
if _, err := file.Read(magic); err != nil {
|
|
||||||
return 0, 0, err
|
|
||||||
}
|
|
||||||
file.Seek(0, 0)
|
|
||||||
|
|
||||||
if magic[0] == 0x1f && magic[1] == 0x8b {
|
|
||||||
// gzip compressed
|
|
||||||
gzReader, err := gzip.NewReader(file)
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0, fmt.Errorf("failed to create gzip reader: %w", err)
|
|
||||||
}
|
|
||||||
defer gzReader.Close()
|
|
||||||
|
|
||||||
header := make([]byte, 12)
|
|
||||||
n, err := gzReader.Read(header)
|
|
||||||
if err != nil || n < 12 {
|
|
||||||
return 0, 0, fmt.Errorf("failed to read compressed blend header: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if string(header[:7]) != "BLENDER" {
|
|
||||||
return 0, 0, fmt.Errorf("invalid blend file format")
|
|
||||||
}
|
|
||||||
|
|
||||||
versionStr := string(header[9:12])
|
|
||||||
if len(versionStr) == 3 {
|
|
||||||
fmt.Sscanf(string(versionStr[0]), "%d", &major)
|
|
||||||
fmt.Sscanf(versionStr[1:3], "%d", &minor)
|
|
||||||
}
|
|
||||||
|
|
||||||
return major, minor, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for zstd magic (Blender 3.0+): 0x28 0xB5 0x2F 0xFD
|
|
||||||
if magic[0] == 0x28 && magic[1] == 0xb5 && magic[2] == 0x2f && magic[3] == 0xfd {
|
|
||||||
return parseZstdVersion(file)
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0, 0, fmt.Errorf("unknown blend file format")
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseZstdVersion handles zstd-compressed blend files (Blender 3.0+).
|
|
||||||
// Uses zstd command line tool since Go doesn't have native zstd support.
|
|
||||||
func parseZstdVersion(file *os.File) (major, minor int, err error) {
|
|
||||||
file.Seek(0, 0)
|
|
||||||
|
|
||||||
cmd := exec.Command("zstd", "-d", "-c")
|
|
||||||
cmd.Stdin = file
|
|
||||||
|
|
||||||
stdout, err := cmd.StdoutPipe()
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0, fmt.Errorf("failed to create zstd stdout pipe: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := cmd.Start(); err != nil {
|
|
||||||
return 0, 0, fmt.Errorf("failed to start zstd decompression: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read just the header (12 bytes)
|
|
||||||
header := make([]byte, 12)
|
|
||||||
n, readErr := io.ReadFull(stdout, header)
|
|
||||||
|
|
||||||
// Kill the process early - we only need the header
|
|
||||||
cmd.Process.Kill()
|
|
||||||
cmd.Wait()
|
|
||||||
|
|
||||||
if readErr != nil || n < 12 {
|
|
||||||
return 0, 0, fmt.Errorf("failed to read zstd compressed blend header: %v", readErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if string(header[:7]) != "BLENDER" {
|
|
||||||
return 0, 0, fmt.Errorf("invalid blend file format in zstd archive")
|
|
||||||
}
|
|
||||||
|
|
||||||
versionStr := string(header[9:12])
|
|
||||||
if len(versionStr) == 3 {
|
|
||||||
fmt.Sscanf(string(versionStr[0]), "%d", &major)
|
|
||||||
fmt.Sscanf(versionStr[1:3], "%d", &minor)
|
|
||||||
}
|
|
||||||
|
|
||||||
return major, minor, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// VersionString returns a formatted version string like "4.2".
|
// VersionString returns a formatted version string like "4.2".
|
||||||
func VersionString(major, minor int) string {
|
func VersionString(major, minor int) string {
|
||||||
return fmt.Sprintf("%d.%d", major, minor)
|
return fmt.Sprintf("%d.%d", major, minor)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
10
internal/runner/blender/version_test.go
Normal file
10
internal/runner/blender/version_test.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
package blender
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestVersionString(t *testing.T) {
|
||||||
|
if got := VersionString(4, 2); got != "4.2" {
|
||||||
|
t.Fatalf("VersionString() = %q, want %q", got, "4.2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -46,18 +46,22 @@ type Runner struct {
|
|||||||
gpuLockedOut bool
|
gpuLockedOut bool
|
||||||
gpuLockedOutMu sync.RWMutex
|
gpuLockedOutMu sync.RWMutex
|
||||||
|
|
||||||
// hasHIP/hasNVIDIA are set at startup by running latest Blender to detect GPU backends.
|
// hasAMD/hasNVIDIA/hasIntel are set at startup by hardware/Blender GPU backend detection.
|
||||||
// Used to force CPU only for Blender < 4.x when HIP is present (no official HIP support pre-4).
|
// Used to force CPU only for Blender < 4.x when AMD is present (no official HIP support pre-4).
|
||||||
// gpuDetectionFailed is true when detection could not run; we then force CPU for all versions (we could not determine HIP vs NVIDIA).
|
// gpuDetectionFailed is true when detection could not run; we then force CPU for all versions.
|
||||||
gpuBackendMu sync.RWMutex
|
gpuBackendMu sync.RWMutex
|
||||||
hasHIP bool
|
hasAMD bool
|
||||||
hasNVIDIA bool
|
hasNVIDIA bool
|
||||||
|
hasIntel bool
|
||||||
gpuBackendProbed bool
|
gpuBackendProbed bool
|
||||||
gpuDetectionFailed bool
|
gpuDetectionFailed bool
|
||||||
|
|
||||||
|
// forceCPURendering forces CPU rendering for all jobs regardless of metadata/backend detection.
|
||||||
|
forceCPURendering bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new runner.
|
// New creates a new runner.
|
||||||
func New(managerURL, name, hostname string) *Runner {
|
func New(managerURL, name, hostname string, forceCPURendering bool) *Runner {
|
||||||
manager := api.NewManagerClient(managerURL)
|
manager := api.NewManagerClient(managerURL)
|
||||||
|
|
||||||
r := &Runner{
|
r := &Runner{
|
||||||
@@ -67,6 +71,8 @@ func New(managerURL, name, hostname string) *Runner {
|
|||||||
processes: executils.NewProcessTracker(),
|
processes: executils.NewProcessTracker(),
|
||||||
stopChan: make(chan struct{}),
|
stopChan: make(chan struct{}),
|
||||||
processors: make(map[string]tasks.Processor),
|
processors: make(map[string]tasks.Processor),
|
||||||
|
|
||||||
|
forceCPURendering: forceCPURendering,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate fingerprint
|
// Generate fingerprint
|
||||||
@@ -85,17 +91,16 @@ func (r *Runner) CheckRequiredTools() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var cachedCapabilities map[string]interface{} = nil
|
var (
|
||||||
|
cachedCapabilities map[string]interface{}
|
||||||
|
capabilitiesOnce sync.Once
|
||||||
|
)
|
||||||
|
|
||||||
// ProbeCapabilities detects hardware capabilities.
|
// ProbeCapabilities detects hardware capabilities.
|
||||||
func (r *Runner) ProbeCapabilities() map[string]interface{} {
|
func (r *Runner) ProbeCapabilities() map[string]interface{} {
|
||||||
if cachedCapabilities != nil {
|
capabilitiesOnce.Do(func() {
|
||||||
return cachedCapabilities
|
|
||||||
}
|
|
||||||
|
|
||||||
caps := make(map[string]interface{})
|
caps := make(map[string]interface{})
|
||||||
|
|
||||||
// Check for ffmpeg and probe encoding capabilities
|
|
||||||
if err := exec.Command("ffmpeg", "-version").Run(); err == nil {
|
if err := exec.Command("ffmpeg", "-version").Run(); err == nil {
|
||||||
caps["ffmpeg"] = true
|
caps["ffmpeg"] = true
|
||||||
} else {
|
} else {
|
||||||
@@ -103,7 +108,8 @@ func (r *Runner) ProbeCapabilities() map[string]interface{} {
|
|||||||
}
|
}
|
||||||
|
|
||||||
cachedCapabilities = caps
|
cachedCapabilities = caps
|
||||||
return caps
|
})
|
||||||
|
return cachedCapabilities
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register registers the runner with the manager.
|
// Register registers the runner with the manager.
|
||||||
@@ -133,52 +139,66 @@ func (r *Runner) Register(apiKey string) (int64, error) {
|
|||||||
return id, nil
|
return id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DetectAndStoreGPUBackends downloads the latest Blender from the manager (if needed),
|
// DetectAndStoreGPUBackends runs host-level backend detection and stores AMD/NVIDIA/Intel results.
|
||||||
// runs a detection script to see if HIP (AMD) and/or NVIDIA devices are available,
|
// Call after Register. Used so we only force CPU for Blender < 4.x when AMD is present.
|
||||||
// and stores the result. Call after Register. Used so we only force CPU for Blender < 4.x
|
|
||||||
// when the runner has HIP (no official HIP support pre-4); NVIDIA is allowed.
|
|
||||||
func (r *Runner) DetectAndStoreGPUBackends() {
|
func (r *Runner) DetectAndStoreGPUBackends() {
|
||||||
r.gpuBackendMu.Lock()
|
r.gpuBackendMu.Lock()
|
||||||
defer r.gpuBackendMu.Unlock()
|
defer r.gpuBackendMu.Unlock()
|
||||||
if r.gpuBackendProbed {
|
if r.gpuBackendProbed {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
latestVer, err := r.manager.GetLatestBlenderVersion()
|
hasAMD, hasNVIDIA, hasIntel, ok := blender.DetectGPUBackends()
|
||||||
if err != nil {
|
if !ok {
|
||||||
log.Printf("GPU backend detection failed (could not get latest Blender version: %v). All jobs will use CPU because we could not determine HIP vs NVIDIA.", err)
|
log.Printf("GPU backend detection failed (host probe unavailable). All jobs will use CPU because backend availability is unknown.")
|
||||||
r.gpuBackendProbed = true
|
r.gpuBackendProbed = true
|
||||||
r.gpuDetectionFailed = true
|
r.gpuDetectionFailed = true
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
binaryPath, err := r.blender.GetBinaryPath(latestVer)
|
|
||||||
if err != nil {
|
detectedTypes := 0
|
||||||
log.Printf("GPU backend detection failed (could not get Blender binary: %v). All jobs will use CPU because we could not determine HIP vs NVIDIA.", err)
|
if hasAMD {
|
||||||
r.gpuBackendProbed = true
|
detectedTypes++
|
||||||
r.gpuDetectionFailed = true
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
hasHIP, hasNVIDIA, err := blender.DetectGPUBackends(binaryPath, r.workspace.BaseDir())
|
if hasNVIDIA {
|
||||||
if err != nil {
|
detectedTypes++
|
||||||
log.Printf("GPU backend detection failed (script error: %v). All jobs will use CPU because we could not determine HIP vs NVIDIA.", err)
|
|
||||||
r.gpuBackendProbed = true
|
|
||||||
r.gpuDetectionFailed = true
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
r.hasHIP = hasHIP
|
if hasIntel {
|
||||||
|
detectedTypes++
|
||||||
|
}
|
||||||
|
if detectedTypes > 1 {
|
||||||
|
log.Printf("mixed GPU vendors detected (AMD=%v NVIDIA=%v INTEL=%v): multi-vendor setups may not work reliably, but runner will continue with GPU enabled", hasAMD, hasNVIDIA, hasIntel)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.hasAMD = hasAMD
|
||||||
r.hasNVIDIA = hasNVIDIA
|
r.hasNVIDIA = hasNVIDIA
|
||||||
|
r.hasIntel = hasIntel
|
||||||
r.gpuBackendProbed = true
|
r.gpuBackendProbed = true
|
||||||
r.gpuDetectionFailed = false
|
r.gpuDetectionFailed = false
|
||||||
log.Printf("GPU backend detection: HIP=%v NVIDIA=%v (Blender < 4.x will force CPU only when HIP is present)", hasHIP, hasNVIDIA)
|
log.Printf("GPU backend detection: AMD=%v NVIDIA=%v INTEL=%v (Blender < 4.x will force CPU only when AMD is present)", hasAMD, hasNVIDIA, hasIntel)
|
||||||
}
|
}
|
||||||
|
|
||||||
// HasHIP returns whether the runner detected HIP (AMD) devices. Used to force CPU for Blender < 4.x only when HIP is present.
|
// HasAMD returns whether the runner detected AMD devices. Used to force CPU for Blender < 4.x only when AMD is present.
|
||||||
func (r *Runner) HasHIP() bool {
|
func (r *Runner) HasAMD() bool {
|
||||||
r.gpuBackendMu.RLock()
|
r.gpuBackendMu.RLock()
|
||||||
defer r.gpuBackendMu.RUnlock()
|
defer r.gpuBackendMu.RUnlock()
|
||||||
return r.hasHIP
|
return r.hasAMD
|
||||||
}
|
}
|
||||||
|
|
||||||
// GPUDetectionFailed returns true when startup GPU backend detection could not run or failed. When true, all jobs use CPU because we could not determine HIP vs NVIDIA.
|
// HasNVIDIA returns whether the runner detected NVIDIA GPUs.
|
||||||
|
func (r *Runner) HasNVIDIA() bool {
|
||||||
|
r.gpuBackendMu.RLock()
|
||||||
|
defer r.gpuBackendMu.RUnlock()
|
||||||
|
return r.hasNVIDIA
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasIntel returns whether the runner detected Intel GPUs (e.g. Arc).
|
||||||
|
func (r *Runner) HasIntel() bool {
|
||||||
|
r.gpuBackendMu.RLock()
|
||||||
|
defer r.gpuBackendMu.RUnlock()
|
||||||
|
return r.hasIntel
|
||||||
|
}
|
||||||
|
|
||||||
|
// GPUDetectionFailed returns true when startup GPU backend detection could not run or failed. When true, all jobs use CPU because backend availability is unknown.
|
||||||
func (r *Runner) GPUDetectionFailed() bool {
|
func (r *Runner) GPUDetectionFailed() bool {
|
||||||
r.gpuBackendMu.RLock()
|
r.gpuBackendMu.RLock()
|
||||||
defer r.gpuBackendMu.RUnlock()
|
defer r.gpuBackendMu.RUnlock()
|
||||||
@@ -305,8 +325,11 @@ func (r *Runner) executeJob(job *api.NextJobResponse) (err error) {
|
|||||||
r.encoder,
|
r.encoder,
|
||||||
r.processes,
|
r.processes,
|
||||||
r.IsGPULockedOut(),
|
r.IsGPULockedOut(),
|
||||||
r.HasHIP(),
|
r.HasAMD(),
|
||||||
|
r.HasNVIDIA(),
|
||||||
|
r.HasIntel(),
|
||||||
r.GPUDetectionFailed(),
|
r.GPUDetectionFailed(),
|
||||||
|
r.forceCPURendering,
|
||||||
func() { r.SetGPULockedOut(true) },
|
func() { r.SetGPULockedOut(true) },
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
40
internal/runner/runner_test.go
Normal file
40
internal/runner/runner_test.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package runner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewRunner_InitializesFields(t *testing.T) {
|
||||||
|
r := New("http://localhost:8080", "runner-a", "host-a", false)
|
||||||
|
if r == nil {
|
||||||
|
t.Fatal("New should return a runner")
|
||||||
|
}
|
||||||
|
if r.name != "runner-a" || r.hostname != "host-a" {
|
||||||
|
t.Fatalf("unexpected runner identity: %q %q", r.name, r.hostname)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunner_GPUFlagsSetters(t *testing.T) {
|
||||||
|
r := New("http://localhost:8080", "runner-a", "host-a", false)
|
||||||
|
r.SetGPULockedOut(true)
|
||||||
|
if !r.IsGPULockedOut() {
|
||||||
|
t.Fatal("expected GPU lockout to be true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateFingerprint_PopulatesValue(t *testing.T) {
|
||||||
|
r := New("http://localhost:8080", "runner-a", "host-a", false)
|
||||||
|
r.generateFingerprint()
|
||||||
|
fp := r.GetFingerprint()
|
||||||
|
if fp == "" {
|
||||||
|
t.Fatal("fingerprint should not be empty")
|
||||||
|
}
|
||||||
|
if len(fp) != 64 {
|
||||||
|
t.Fatalf("fingerprint should be sha256 hex, got %q", fp)
|
||||||
|
}
|
||||||
|
if _, err := hex.DecodeString(fp); err != nil {
|
||||||
|
t.Fatalf("fingerprint should be valid hex: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -298,6 +298,9 @@ func (p *EncodeProcessor) Process(ctx *Context) error {
|
|||||||
ctx.Info(line)
|
ctx.Info(line)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
log.Printf("Error reading encode stdout: %v", err)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Stream stderr
|
// Stream stderr
|
||||||
@@ -311,6 +314,9 @@ func (p *EncodeProcessor) Process(ctx *Context) error {
|
|||||||
ctx.Warn(line)
|
ctx.Warn(line)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
log.Printf("Error reading encode stderr: %v", err)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = cmd.Wait()
|
err = cmd.Wait()
|
||||||
@@ -379,7 +385,7 @@ func (p *EncodeProcessor) Process(ctx *Context) error {
|
|||||||
func detectAlphaChannel(ctx *Context, filePath string) bool {
|
func detectAlphaChannel(ctx *Context, filePath string) bool {
|
||||||
// Use ffprobe to check pixel format and stream properties
|
// Use ffprobe to check pixel format and stream properties
|
||||||
// EXR files with alpha will have formats like gbrapf32le (RGBA) vs gbrpf32le (RGB)
|
// EXR files with alpha will have formats like gbrapf32le (RGBA) vs gbrpf32le (RGB)
|
||||||
cmd := exec.Command("ffprobe",
|
cmd := execCommand("ffprobe",
|
||||||
"-v", "error",
|
"-v", "error",
|
||||||
"-select_streams", "v:0",
|
"-select_streams", "v:0",
|
||||||
"-show_entries", "stream=pix_fmt:stream=codec_name",
|
"-show_entries", "stream=pix_fmt:stream=codec_name",
|
||||||
@@ -412,7 +418,7 @@ func detectAlphaChannel(ctx *Context, filePath string) bool {
|
|||||||
// detectHDR checks if an EXR file contains HDR content using ffprobe
|
// detectHDR checks if an EXR file contains HDR content using ffprobe
|
||||||
func detectHDR(ctx *Context, filePath string) bool {
|
func detectHDR(ctx *Context, filePath string) bool {
|
||||||
// First, check if the pixel format supports HDR (32-bit float)
|
// First, check if the pixel format supports HDR (32-bit float)
|
||||||
cmd := exec.Command("ffprobe",
|
cmd := execCommand("ffprobe",
|
||||||
"-v", "error",
|
"-v", "error",
|
||||||
"-select_streams", "v:0",
|
"-select_streams", "v:0",
|
||||||
"-show_entries", "stream=pix_fmt",
|
"-show_entries", "stream=pix_fmt",
|
||||||
@@ -440,7 +446,7 @@ func detectHDR(ctx *Context, filePath string) bool {
|
|||||||
// For 32-bit float EXR, sample pixels to check if values exceed SDR range (> 1.0)
|
// For 32-bit float EXR, sample pixels to check if values exceed SDR range (> 1.0)
|
||||||
// Use ffmpeg to extract pixel statistics - check max pixel values
|
// Use ffmpeg to extract pixel statistics - check max pixel values
|
||||||
// This is more efficient than sampling individual pixels
|
// This is more efficient than sampling individual pixels
|
||||||
cmd = exec.Command("ffmpeg",
|
cmd = execCommand("ffmpeg",
|
||||||
"-v", "error",
|
"-v", "error",
|
||||||
"-i", filePath,
|
"-i", filePath,
|
||||||
"-vf", "signalstats",
|
"-vf", "signalstats",
|
||||||
@@ -483,7 +489,7 @@ func detectHDRBySampling(ctx *Context, filePath string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, region := range sampleRegions {
|
for _, region := range sampleRegions {
|
||||||
cmd := exec.Command("ffmpeg",
|
cmd := execCommand("ffmpeg",
|
||||||
"-v", "error",
|
"-v", "error",
|
||||||
"-i", filePath,
|
"-i", filePath,
|
||||||
"-vf", fmt.Sprintf("%s,scale=1:1", region),
|
"-vf", fmt.Sprintf("%s,scale=1:1", region),
|
||||||
|
|||||||
120
internal/runner/tasks/encode_test.go
Normal file
120
internal/runner/tasks/encode_test.go
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
package tasks
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"math"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFloat32FromBytes(t *testing.T) {
|
||||||
|
got := float32FromBytes([]byte{0x00, 0x00, 0x80, 0x3f}) // 1.0 little-endian
|
||||||
|
if got != 1.0 {
|
||||||
|
t.Fatalf("float32FromBytes() = %v, want 1.0", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMax(t *testing.T) {
|
||||||
|
if got := max(1, 2); got != 2 {
|
||||||
|
t.Fatalf("max() = %v, want 2", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractFrameNumber(t *testing.T) {
|
||||||
|
if got := extractFrameNumber("render_0042.png"); got != 42 {
|
||||||
|
t.Fatalf("extractFrameNumber() = %d, want 42", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckFFmpegSizeError(t *testing.T) {
|
||||||
|
err := checkFFmpegSizeError("hardware does not support encoding at size ... constraints: width 128-4096 height 128-4096")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected a size error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetectAlphaChannel_UsesExecSeam(t *testing.T) {
|
||||||
|
orig := execCommand
|
||||||
|
execCommand = fakeExecCommand
|
||||||
|
defer func() { execCommand = orig }()
|
||||||
|
|
||||||
|
if !detectAlphaChannel(&Context{}, "/tmp/frame.exr") {
|
||||||
|
t.Fatal("expected alpha channel detection via mocked ffprobe output")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetectHDR_UsesExecSeam(t *testing.T) {
|
||||||
|
orig := execCommand
|
||||||
|
execCommand = fakeExecCommand
|
||||||
|
defer func() { execCommand = orig }()
|
||||||
|
|
||||||
|
if !detectHDR(&Context{}, "/tmp/frame.exr") {
|
||||||
|
t.Fatal("expected HDR detection via mocked ffmpeg sampling output")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func fakeExecCommand(command string, args ...string) *exec.Cmd {
|
||||||
|
cs := []string{"-test.run=TestExecHelperProcess", "--", command}
|
||||||
|
cs = append(cs, args...)
|
||||||
|
cmd := exec.Command(os.Args[0], cs...)
|
||||||
|
cmd.Env = append(os.Environ(), "GO_WANT_HELPER_PROCESS=1")
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecHelperProcess(t *testing.T) {
|
||||||
|
if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
idx := 0
|
||||||
|
for i, a := range os.Args {
|
||||||
|
if a == "--" {
|
||||||
|
idx = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if idx == 0 || idx+1 >= len(os.Args) {
|
||||||
|
os.Exit(2)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmdName := os.Args[idx+1]
|
||||||
|
cmdArgs := os.Args[idx+2:]
|
||||||
|
|
||||||
|
switch cmdName {
|
||||||
|
case "ffprobe":
|
||||||
|
if containsArg(cmdArgs, "stream=pix_fmt:stream=codec_name") {
|
||||||
|
_, _ = os.Stdout.WriteString("pix_fmt=gbrapf32le\ncodec_name=exr\n")
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
_, _ = os.Stdout.WriteString("gbrpf32le\n")
|
||||||
|
os.Exit(0)
|
||||||
|
case "ffmpeg":
|
||||||
|
if containsArg(cmdArgs, "signalstats") {
|
||||||
|
_, _ = os.Stderr.WriteString("signalstats failed")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
if containsArg(cmdArgs, "rawvideo") {
|
||||||
|
buf := make([]byte, 12)
|
||||||
|
binary.LittleEndian.PutUint32(buf[0:4], math.Float32bits(1.5))
|
||||||
|
binary.LittleEndian.PutUint32(buf[4:8], math.Float32bits(0.2))
|
||||||
|
binary.LittleEndian.PutUint32(buf[8:12], math.Float32bits(0.1))
|
||||||
|
_, _ = os.Stdout.Write(buf)
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
os.Exit(0)
|
||||||
|
default:
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsArg(args []string, target string) bool {
|
||||||
|
for _, a := range args {
|
||||||
|
if strings.Contains(a, target) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
7
internal/runner/tasks/exec_seams.go
Normal file
7
internal/runner/tasks/exec_seams.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
package tasks
|
||||||
|
|
||||||
|
import "os/exec"
|
||||||
|
|
||||||
|
// execCommand is a seam for process execution in tests.
|
||||||
|
var execCommand = exec.Command
|
||||||
|
|
||||||
@@ -11,8 +11,6 @@ import (
|
|||||||
"jiggablend/pkg/executils"
|
"jiggablend/pkg/executils"
|
||||||
"jiggablend/pkg/types"
|
"jiggablend/pkg/types"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -43,19 +41,25 @@ type Context struct {
|
|||||||
|
|
||||||
// GPULockedOut is set when the runner has detected a GPU error (e.g. HIP) and disables GPU for all jobs.
|
// GPULockedOut is set when the runner has detected a GPU error (e.g. HIP) and disables GPU for all jobs.
|
||||||
GPULockedOut bool
|
GPULockedOut bool
|
||||||
// HasHIP is true when the runner detected HIP (AMD) devices at startup. Used to force CPU for Blender < 4.x only when HIP is present.
|
// HasAMD is true when the runner detected AMD devices at startup.
|
||||||
HasHIP bool
|
HasAMD bool
|
||||||
// GPUDetectionFailed is true when startup GPU backend detection could not run; we force CPU for all versions (could not determine HIP vs NVIDIA).
|
// HasNVIDIA is true when the runner detected NVIDIA GPUs at startup.
|
||||||
|
HasNVIDIA bool
|
||||||
|
// HasIntel is true when the runner detected Intel GPUs (e.g. Arc) at startup.
|
||||||
|
HasIntel bool
|
||||||
|
// GPUDetectionFailed is true when startup GPU backend detection could not run; we force CPU for all versions (backend availability unknown).
|
||||||
GPUDetectionFailed bool
|
GPUDetectionFailed bool
|
||||||
// OnGPUError is called when a GPU error line is seen in render logs; typically sets runner GPU lockout.
|
// OnGPUError is called when a GPU error line is seen in render logs; typically sets runner GPU lockout.
|
||||||
OnGPUError func()
|
OnGPUError func()
|
||||||
|
// ForceCPURendering is a runner-level override that forces CPU rendering for all jobs.
|
||||||
|
ForceCPURendering bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ErrJobCancelled indicates the manager-side job was cancelled during execution.
|
// ErrJobCancelled indicates the manager-side job was cancelled during execution.
|
||||||
var ErrJobCancelled = errors.New("job cancelled")
|
var ErrJobCancelled = errors.New("job cancelled")
|
||||||
|
|
||||||
// NewContext creates a new task context. frameEnd should be >= frame; if 0 or less than frame, it is treated as single-frame (frameEnd = frame).
|
// NewContext creates a new task context. frameEnd should be >= frame; if 0 or less than frame, it is treated as single-frame (frameEnd = frame).
|
||||||
// gpuLockedOut is the runner's current GPU lockout state; hasHIP means the runner has HIP (AMD) devices (force CPU for Blender < 4.x only when true); gpuDetectionFailed means detection failed at startup (force CPU for all versions—could not determine HIP vs NVIDIA); onGPUError is called when a GPU error is detected in logs (may be nil).
|
// gpuLockedOut is the runner's current GPU lockout state; gpuDetectionFailed means detection failed at startup (force CPU for all versions); onGPUError is called when a GPU error is detected in logs (may be nil).
|
||||||
func NewContext(
|
func NewContext(
|
||||||
taskID, jobID int64,
|
taskID, jobID int64,
|
||||||
jobName string,
|
jobName string,
|
||||||
@@ -71,8 +75,11 @@ func NewContext(
|
|||||||
encoder *encoding.Selector,
|
encoder *encoding.Selector,
|
||||||
processes *executils.ProcessTracker,
|
processes *executils.ProcessTracker,
|
||||||
gpuLockedOut bool,
|
gpuLockedOut bool,
|
||||||
hasHIP bool,
|
hasAMD bool,
|
||||||
|
hasNVIDIA bool,
|
||||||
|
hasIntel bool,
|
||||||
gpuDetectionFailed bool,
|
gpuDetectionFailed bool,
|
||||||
|
forceCPURendering bool,
|
||||||
onGPUError func(),
|
onGPUError func(),
|
||||||
) *Context {
|
) *Context {
|
||||||
if frameEnd < frameStart {
|
if frameEnd < frameStart {
|
||||||
@@ -95,8 +102,11 @@ func NewContext(
|
|||||||
Encoder: encoder,
|
Encoder: encoder,
|
||||||
Processes: processes,
|
Processes: processes,
|
||||||
GPULockedOut: gpuLockedOut,
|
GPULockedOut: gpuLockedOut,
|
||||||
HasHIP: hasHIP,
|
HasAMD: hasAMD,
|
||||||
|
HasNVIDIA: hasNVIDIA,
|
||||||
|
HasIntel: hasIntel,
|
||||||
GPUDetectionFailed: gpuDetectionFailed,
|
GPUDetectionFailed: gpuDetectionFailed,
|
||||||
|
ForceCPURendering: forceCPURendering,
|
||||||
OnGPUError: onGPUError,
|
OnGPUError: onGPUError,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -179,23 +189,18 @@ func (c *Context) ShouldEnableExecution() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ShouldForceCPU returns true if GPU should be disabled and CPU rendering forced
|
// ShouldForceCPU returns true if GPU should be disabled and CPU rendering forced
|
||||||
// (runner GPU lockout, GPU detection failed at startup for any version, metadata force_cpu,
|
// (runner GPU lockout, GPU detection failed at startup, or metadata force_cpu).
|
||||||
// or Blender < 4.x when the runner has HIP).
|
|
||||||
func (c *Context) ShouldForceCPU() bool {
|
func (c *Context) ShouldForceCPU() bool {
|
||||||
|
if c.ForceCPURendering {
|
||||||
|
return true
|
||||||
|
}
|
||||||
if c.GPULockedOut {
|
if c.GPULockedOut {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
// Detection failed at startup: we could not determine HIP vs NVIDIA, so force CPU for all versions.
|
// Detection failed at startup: backend availability unknown, so force CPU for all versions.
|
||||||
if c.GPUDetectionFailed {
|
if c.GPUDetectionFailed {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
v := c.GetBlenderVersion()
|
|
||||||
major := parseBlenderMajor(v)
|
|
||||||
isPre4 := v != "" && major >= 0 && major < 4
|
|
||||||
// Blender < 4.x: force CPU when runner has HIP (no official HIP support).
|
|
||||||
if isPre4 && c.HasHIP {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if c.Metadata != nil && c.Metadata.RenderSettings.EngineSettings != nil {
|
if c.Metadata != nil && c.Metadata.RenderSettings.EngineSettings != nil {
|
||||||
if v, ok := c.Metadata.RenderSettings.EngineSettings["force_cpu"]; ok {
|
if v, ok := c.Metadata.RenderSettings.EngineSettings["force_cpu"]; ok {
|
||||||
if b, ok := v.(bool); ok && b {
|
if b, ok := v.(bool); ok && b {
|
||||||
@@ -206,21 +211,6 @@ func (c *Context) ShouldForceCPU() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseBlenderMajor returns the major version number from a string like "4.2.3" or "3.6".
|
|
||||||
// Returns -1 if the version cannot be parsed.
|
|
||||||
func parseBlenderMajor(version string) int {
|
|
||||||
version = strings.TrimSpace(version)
|
|
||||||
if version == "" {
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
parts := strings.SplitN(version, ".", 2)
|
|
||||||
major, err := strconv.Atoi(parts[0])
|
|
||||||
if err != nil {
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
return major
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsJobCancelled checks whether the manager marked this job as cancelled.
|
// IsJobCancelled checks whether the manager marked this job as cancelled.
|
||||||
func (c *Context) IsJobCancelled() (bool, error) {
|
func (c *Context) IsJobCancelled() (bool, error) {
|
||||||
if c.Manager == nil {
|
if c.Manager == nil {
|
||||||
|
|||||||
42
internal/runner/tasks/processor_test.go
Normal file
42
internal/runner/tasks/processor_test.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package tasks
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"jiggablend/pkg/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewContext_NormalizesFrameEnd(t *testing.T) {
|
||||||
|
ctx := NewContext(1, 2, "job", 10, 1, "render", "/tmp", "tok", nil, nil, nil, nil, nil, nil, nil, false, false, false, false, false, false, nil)
|
||||||
|
if ctx.FrameEnd != 10 {
|
||||||
|
t.Fatalf("expected FrameEnd to be normalized to Frame, got %d", ctx.FrameEnd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContext_GetOutputFormat_Default(t *testing.T) {
|
||||||
|
ctx := &Context{}
|
||||||
|
if got := ctx.GetOutputFormat(); got != "PNG" {
|
||||||
|
t.Fatalf("GetOutputFormat() = %q, want PNG", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContext_ShouldForceCPU(t *testing.T) {
|
||||||
|
ctx := &Context{ForceCPURendering: true}
|
||||||
|
if !ctx.ShouldForceCPU() {
|
||||||
|
t.Fatal("expected force cpu when runner-level flag is set")
|
||||||
|
}
|
||||||
|
|
||||||
|
force := true
|
||||||
|
ctx = &Context{Metadata: &types.BlendMetadata{RenderSettings: types.RenderSettings{EngineSettings: map[string]interface{}{"force_cpu": force}}}}
|
||||||
|
if !ctx.ShouldForceCPU() {
|
||||||
|
t.Fatal("expected force cpu when metadata requests it")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestErrJobCancelled_IsSentinel(t *testing.T) {
|
||||||
|
if !errors.Is(ErrJobCancelled, ErrJobCancelled) {
|
||||||
|
t.Fatal("sentinel error should be self-identical")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -88,6 +88,11 @@ func (p *RenderProcessor) Process(ctx *Context) error {
|
|||||||
ctx.Info("No Blender version specified, using system blender")
|
ctx.Info("No Blender version specified, using system blender")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
blenderBinary, err = blender.ResolveBinaryPath(blenderBinary)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to resolve blender binary: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Create output directory
|
// Create output directory
|
||||||
outputDir := filepath.Join(ctx.WorkDir, "output")
|
outputDir := filepath.Join(ctx.WorkDir, "output")
|
||||||
if err := os.MkdirAll(outputDir, 0755); err != nil {
|
if err := os.MkdirAll(outputDir, 0755); err != nil {
|
||||||
@@ -104,13 +109,10 @@ func (p *RenderProcessor) Process(ctx *Context) error {
|
|||||||
renderFormat := "EXR"
|
renderFormat := "EXR"
|
||||||
|
|
||||||
if ctx.ShouldForceCPU() {
|
if ctx.ShouldForceCPU() {
|
||||||
v := ctx.GetBlenderVersion()
|
if ctx.ForceCPURendering {
|
||||||
major := parseBlenderMajor(v)
|
ctx.Info("Runner compatibility flag is enabled: forcing CPU rendering for this job")
|
||||||
isPre4 := v != "" && major >= 0 && major < 4
|
} else if ctx.GPUDetectionFailed {
|
||||||
if ctx.GPUDetectionFailed {
|
ctx.Info("GPU backend detection failed at startup—we could not determine available GPU backends, so rendering will use CPU to avoid compatibility issues")
|
||||||
ctx.Info("GPU backend detection failed at startup—we could not determine whether this machine has HIP (AMD) or NVIDIA GPUs, so rendering will use CPU to avoid compatibility issues")
|
|
||||||
} else if isPre4 && ctx.HasHIP {
|
|
||||||
ctx.Info("Blender < 4.x has no official HIP support: using CPU rendering only")
|
|
||||||
} else {
|
} else {
|
||||||
ctx.Info("GPU lockout active: using CPU rendering only")
|
ctx.Info("GPU lockout active: using CPU rendering only")
|
||||||
}
|
}
|
||||||
@@ -205,8 +207,16 @@ func (p *RenderProcessor) createRenderScript(ctx *Context, renderFormat string)
|
|||||||
|
|
||||||
func (p *RenderProcessor) runBlender(ctx *Context, blenderBinary, blendFile, outputDir, renderFormat, blenderHome string) error {
|
func (p *RenderProcessor) runBlender(ctx *Context, blenderBinary, blendFile, outputDir, renderFormat, blenderHome string) error {
|
||||||
scriptPath := filepath.Join(ctx.WorkDir, "enable_gpu.py")
|
scriptPath := filepath.Join(ctx.WorkDir, "enable_gpu.py")
|
||||||
|
blendFileAbs, err := filepath.Abs(blendFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to resolve blend file path: %w", err)
|
||||||
|
}
|
||||||
|
scriptPathAbs, err := filepath.Abs(scriptPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to resolve blender script path: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
args := []string{"-b", blendFile, "--python", scriptPath}
|
args := []string{"-b", blendFileAbs, "--python", scriptPathAbs}
|
||||||
if ctx.ShouldEnableExecution() {
|
if ctx.ShouldEnableExecution() {
|
||||||
args = append(args, "--enable-autoexec")
|
args = append(args, "--enable-autoexec")
|
||||||
}
|
}
|
||||||
@@ -223,7 +233,7 @@ func (p *RenderProcessor) runBlender(ctx *Context, blenderBinary, blendFile, out
|
|||||||
args = append(args, "-f", fmt.Sprintf("%d", ctx.Frame))
|
args = append(args, "-f", fmt.Sprintf("%d", ctx.Frame))
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd := exec.Command(blenderBinary, args...)
|
cmd := execCommand(blenderBinary, args...)
|
||||||
cmd.Dir = ctx.WorkDir
|
cmd.Dir = ctx.WorkDir
|
||||||
|
|
||||||
// Set up environment: LD_LIBRARY_PATH for tarball Blender, then custom HOME
|
// Set up environment: LD_LIBRARY_PATH for tarball Blender, then custom HOME
|
||||||
@@ -274,6 +284,9 @@ func (p *RenderProcessor) runBlender(ctx *Context, blenderBinary, blendFile, out
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
log.Printf("Error reading stdout: %v", err)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Stream stderr and watch for GPU error lines
|
// Stream stderr and watch for GPU error lines
|
||||||
@@ -294,6 +307,9 @@ func (p *RenderProcessor) runBlender(ctx *Context, blenderBinary, blendFile, out
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
log.Printf("Error reading stderr: %v", err)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Wait for completion
|
// Wait for completion
|
||||||
|
|||||||
28
internal/runner/tasks/render_test.go
Normal file
28
internal/runner/tasks/render_test.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package tasks
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestCheckGPUErrorLine_TriggersCallback(t *testing.T) {
|
||||||
|
p := NewRenderProcessor()
|
||||||
|
triggered := false
|
||||||
|
ctx := &Context{
|
||||||
|
OnGPUError: func() { triggered = true },
|
||||||
|
}
|
||||||
|
p.checkGPUErrorLine(ctx, "Fatal: Illegal address in HIP kernel execution")
|
||||||
|
if !triggered {
|
||||||
|
t.Fatal("expected GPU error callback to be triggered")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckGPUErrorLine_IgnoresNormalLine(t *testing.T) {
|
||||||
|
p := NewRenderProcessor()
|
||||||
|
triggered := false
|
||||||
|
ctx := &Context{
|
||||||
|
OnGPUError: func() { triggered = true },
|
||||||
|
}
|
||||||
|
p.checkGPUErrorLine(ctx, "Render completed successfully")
|
||||||
|
if triggered {
|
||||||
|
t.Fatal("did not expect GPU callback for normal line")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -99,6 +99,11 @@ func ExtractTarStripPrefix(reader io.Reader, destDir string) error {
|
|||||||
|
|
||||||
targetPath := filepath.Join(destDir, name)
|
targetPath := filepath.Join(destDir, name)
|
||||||
|
|
||||||
|
// Sanitize path to prevent directory traversal
|
||||||
|
if !strings.HasPrefix(filepath.Clean(targetPath), filepath.Clean(destDir)+string(os.PathSeparator)) {
|
||||||
|
return fmt.Errorf("invalid file path in tar: %s", header.Name)
|
||||||
|
}
|
||||||
|
|
||||||
switch header.Typeflag {
|
switch header.Typeflag {
|
||||||
case tar.TypeDir:
|
case tar.TypeDir:
|
||||||
if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil {
|
if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil {
|
||||||
|
|||||||
125
internal/runner/workspace/archive_test.go
Normal file
125
internal/runner/workspace/archive_test.go
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
package workspace
|
||||||
|
|
||||||
|
import (
|
||||||
|
"archive/tar"
|
||||||
|
"bytes"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func createTarBuffer(files map[string]string) *bytes.Buffer {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
tw := tar.NewWriter(&buf)
|
||||||
|
for name, content := range files {
|
||||||
|
hdr := &tar.Header{
|
||||||
|
Name: name,
|
||||||
|
Mode: 0644,
|
||||||
|
Size: int64(len(content)),
|
||||||
|
}
|
||||||
|
tw.WriteHeader(hdr)
|
||||||
|
tw.Write([]byte(content))
|
||||||
|
}
|
||||||
|
tw.Close()
|
||||||
|
return &buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractTar(t *testing.T) {
|
||||||
|
destDir := t.TempDir()
|
||||||
|
|
||||||
|
buf := createTarBuffer(map[string]string{
|
||||||
|
"hello.txt": "world",
|
||||||
|
"sub/a.txt": "nested",
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := ExtractTar(buf, destDir); err != nil {
|
||||||
|
t.Fatalf("ExtractTar: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(filepath.Join(destDir, "hello.txt"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read hello.txt: %v", err)
|
||||||
|
}
|
||||||
|
if string(data) != "world" {
|
||||||
|
t.Errorf("hello.txt = %q, want %q", data, "world")
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err = os.ReadFile(filepath.Join(destDir, "sub", "a.txt"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read sub/a.txt: %v", err)
|
||||||
|
}
|
||||||
|
if string(data) != "nested" {
|
||||||
|
t.Errorf("sub/a.txt = %q, want %q", data, "nested")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractTarStripPrefix(t *testing.T) {
|
||||||
|
destDir := t.TempDir()
|
||||||
|
|
||||||
|
buf := createTarBuffer(map[string]string{
|
||||||
|
"toplevel/": "",
|
||||||
|
"toplevel/foo.txt": "bar",
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := ExtractTarStripPrefix(buf, destDir); err != nil {
|
||||||
|
t.Fatalf("ExtractTarStripPrefix: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(filepath.Join(destDir, "foo.txt"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read foo.txt: %v", err)
|
||||||
|
}
|
||||||
|
if string(data) != "bar" {
|
||||||
|
t.Errorf("foo.txt = %q, want %q", data, "bar")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractTarStripPrefix_PathTraversal(t *testing.T) {
|
||||||
|
destDir := t.TempDir()
|
||||||
|
|
||||||
|
buf := createTarBuffer(map[string]string{
|
||||||
|
"prefix/../../../etc/passwd": "pwned",
|
||||||
|
})
|
||||||
|
|
||||||
|
err := ExtractTarStripPrefix(buf, destDir)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for path traversal, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractTar_PathTraversal(t *testing.T) {
|
||||||
|
destDir := t.TempDir()
|
||||||
|
|
||||||
|
buf := createTarBuffer(map[string]string{
|
||||||
|
"../../../etc/passwd": "pwned",
|
||||||
|
})
|
||||||
|
|
||||||
|
err := ExtractTar(buf, destDir)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for path traversal, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractTarFile(t *testing.T) {
|
||||||
|
destDir := t.TempDir()
|
||||||
|
tarPath := filepath.Join(t.TempDir(), "archive.tar")
|
||||||
|
|
||||||
|
buf := createTarBuffer(map[string]string{
|
||||||
|
"hello.txt": "world",
|
||||||
|
})
|
||||||
|
if err := os.WriteFile(tarPath, buf.Bytes(), 0644); err != nil {
|
||||||
|
t.Fatalf("write tar file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ExtractTarFile(tarPath, destDir); err != nil {
|
||||||
|
t.Fatalf("ExtractTarFile: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := os.ReadFile(filepath.Join(destDir, "hello.txt"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read extracted file: %v", err)
|
||||||
|
}
|
||||||
|
if string(got) != "world" {
|
||||||
|
t.Fatalf("unexpected file content: %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
40
internal/runner/workspace/workspace_test.go
Normal file
40
internal/runner/workspace/workspace_test.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package workspace
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSanitizeName_ReplacesUnsafeChars(t *testing.T) {
|
||||||
|
got := sanitizeName("runner / with\\bad:chars")
|
||||||
|
if strings.ContainsAny(got, " /\\:") {
|
||||||
|
t.Fatalf("sanitizeName did not sanitize input: %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFindBlendFiles_IgnoresBlendSaveFiles(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, "scene.blend"), []byte("x"), 0644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, "scene.blend1"), []byte("x"), 0644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
files, err := FindBlendFiles(dir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("FindBlendFiles failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(files) != 1 || files[0] != "scene.blend" {
|
||||||
|
t.Fatalf("unexpected files: %#v", files)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFindFirstBlendFile_ReturnsErrorWhenMissing(t *testing.T) {
|
||||||
|
_, err := FindFirstBlendFile(t.TempDir())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when no blend file exists")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -82,6 +82,9 @@ func (s *Storage) JobPath(jobID int64) string {
|
|||||||
|
|
||||||
// SaveUpload saves an uploaded file
|
// SaveUpload saves an uploaded file
|
||||||
func (s *Storage) SaveUpload(jobID int64, filename string, reader io.Reader) (string, error) {
|
func (s *Storage) SaveUpload(jobID int64, filename string, reader io.Reader) (string, error) {
|
||||||
|
// Sanitize filename to prevent path traversal
|
||||||
|
filename = filepath.Base(filename)
|
||||||
|
|
||||||
jobPath := s.JobPath(jobID)
|
jobPath := s.JobPath(jobID)
|
||||||
if err := os.MkdirAll(jobPath, 0755); err != nil {
|
if err := os.MkdirAll(jobPath, 0755); err != nil {
|
||||||
return "", fmt.Errorf("failed to create job directory: %w", err)
|
return "", fmt.Errorf("failed to create job directory: %w", err)
|
||||||
|
|||||||
95
internal/storage/storage_test.go
Normal file
95
internal/storage/storage_test.go
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupStorage(t *testing.T) *Storage {
|
||||||
|
t.Helper()
|
||||||
|
dir := t.TempDir()
|
||||||
|
s, err := NewStorage(dir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewStorage: %v", err)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSaveUpload(t *testing.T) {
|
||||||
|
s := setupStorage(t)
|
||||||
|
path, err := s.SaveUpload(1, "test.blend", strings.NewReader("data"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SaveUpload: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read saved file: %v", err)
|
||||||
|
}
|
||||||
|
if string(data) != "data" {
|
||||||
|
t.Errorf("got %q, want %q", data, "data")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSaveUpload_PathTraversal(t *testing.T) {
|
||||||
|
s := setupStorage(t)
|
||||||
|
path, err := s.SaveUpload(1, "../../etc/passwd", strings.NewReader("evil"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SaveUpload: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// filepath.Base strips traversal, so the file should be inside the job dir
|
||||||
|
if !strings.HasPrefix(path, s.JobPath(1)) {
|
||||||
|
t.Errorf("saved file %q escaped job directory %q", path, s.JobPath(1))
|
||||||
|
}
|
||||||
|
|
||||||
|
if filepath.Base(path) != "passwd" {
|
||||||
|
t.Errorf("expected basename 'passwd', got %q", filepath.Base(path))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSaveOutput(t *testing.T) {
|
||||||
|
s := setupStorage(t)
|
||||||
|
path, err := s.SaveOutput(42, "output.png", strings.NewReader("img"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SaveOutput: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read saved output: %v", err)
|
||||||
|
}
|
||||||
|
if string(data) != "img" {
|
||||||
|
t.Errorf("got %q, want %q", data, "img")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFile(t *testing.T) {
|
||||||
|
s := setupStorage(t)
|
||||||
|
savedPath, err := s.SaveUpload(1, "readme.txt", strings.NewReader("hello"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SaveUpload: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := s.GetFile(savedPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetFile: %v", err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
buf := make([]byte, 64)
|
||||||
|
n, _ := f.Read(buf)
|
||||||
|
if string(buf[:n]) != "hello" {
|
||||||
|
t.Errorf("got %q, want %q", string(buf[:n]), "hello")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJobPath(t *testing.T) {
|
||||||
|
s := setupStorage(t)
|
||||||
|
path := s.JobPath(99)
|
||||||
|
if !strings.Contains(path, "99") {
|
||||||
|
t.Errorf("JobPath(99) = %q, expected to contain '99'", path)
|
||||||
|
}
|
||||||
|
}
|
||||||
123
pkg/blendfile/version.go
Normal file
123
pkg/blendfile/version.go
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
package blendfile
|
||||||
|
|
||||||
|
import (
|
||||||
|
"compress/gzip"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseVersionFromReader parses the Blender version from a reader.
|
||||||
|
// Returns major and minor version numbers.
|
||||||
|
//
|
||||||
|
// Blend file header layout (12 bytes):
|
||||||
|
//
|
||||||
|
// "BLENDER" (7) + pointer-size (1: '-'=64, '_'=32) + endian (1: 'v'=LE, 'V'=BE)
|
||||||
|
// + version (3 digits, e.g. "402" = 4.02)
|
||||||
|
//
|
||||||
|
// Supports uncompressed, gzip-compressed, and zstd-compressed blend files.
|
||||||
|
func ParseVersionFromReader(r io.ReadSeeker) (major, minor int, err error) {
|
||||||
|
header := make([]byte, 12)
|
||||||
|
n, err := r.Read(header)
|
||||||
|
if err != nil || n < 12 {
|
||||||
|
return 0, 0, fmt.Errorf("failed to read blend file header: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(header[:7]) != "BLENDER" {
|
||||||
|
r.Seek(0, 0)
|
||||||
|
return parseCompressedVersion(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
return parseVersionDigits(header[9:12])
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseVersionFromFile opens a blend file and parses the Blender version.
|
||||||
|
func ParseVersionFromFile(blendPath string) (major, minor int, err error) {
|
||||||
|
file, err := os.Open(blendPath)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, fmt.Errorf("failed to open blend file: %w", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
return ParseVersionFromReader(file)
|
||||||
|
}
|
||||||
|
|
||||||
|
// VersionString returns a formatted version string like "4.2".
|
||||||
|
func VersionString(major, minor int) string {
|
||||||
|
return fmt.Sprintf("%d.%d", major, minor)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseVersionDigits(versionBytes []byte) (major, minor int, err error) {
|
||||||
|
if len(versionBytes) != 3 {
|
||||||
|
return 0, 0, fmt.Errorf("expected 3 version digits, got %d", len(versionBytes))
|
||||||
|
}
|
||||||
|
fmt.Sscanf(string(versionBytes[0]), "%d", &major)
|
||||||
|
fmt.Sscanf(string(versionBytes[1:3]), "%d", &minor)
|
||||||
|
return major, minor, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseCompressedVersion(r io.ReadSeeker) (major, minor int, err error) {
|
||||||
|
magic := make([]byte, 4)
|
||||||
|
if _, err := r.Read(magic); err != nil {
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
r.Seek(0, 0)
|
||||||
|
|
||||||
|
// gzip: 0x1f 0x8b
|
||||||
|
if magic[0] == 0x1f && magic[1] == 0x8b {
|
||||||
|
gzReader, err := gzip.NewReader(r)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, fmt.Errorf("failed to create gzip reader: %w", err)
|
||||||
|
}
|
||||||
|
defer gzReader.Close()
|
||||||
|
|
||||||
|
header := make([]byte, 12)
|
||||||
|
n, err := gzReader.Read(header)
|
||||||
|
if err != nil || n < 12 {
|
||||||
|
return 0, 0, fmt.Errorf("failed to read compressed blend header: %w", err)
|
||||||
|
}
|
||||||
|
if string(header[:7]) != "BLENDER" {
|
||||||
|
return 0, 0, fmt.Errorf("invalid blend file format")
|
||||||
|
}
|
||||||
|
return parseVersionDigits(header[9:12])
|
||||||
|
}
|
||||||
|
|
||||||
|
// zstd: 0x28 0xB5 0x2F 0xFD
|
||||||
|
if magic[0] == 0x28 && magic[1] == 0xb5 && magic[2] == 0x2f && magic[3] == 0xfd {
|
||||||
|
return parseZstdVersion(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, 0, fmt.Errorf("unknown blend file format")
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseZstdVersion(r io.ReadSeeker) (major, minor int, err error) {
|
||||||
|
r.Seek(0, 0)
|
||||||
|
|
||||||
|
cmd := exec.Command("zstd", "-d", "-c")
|
||||||
|
cmd.Stdin = r
|
||||||
|
|
||||||
|
stdout, err := cmd.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, fmt.Errorf("failed to create zstd stdout pipe: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
return 0, 0, fmt.Errorf("failed to start zstd decompression: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
header := make([]byte, 12)
|
||||||
|
n, readErr := io.ReadFull(stdout, header)
|
||||||
|
|
||||||
|
cmd.Process.Kill()
|
||||||
|
cmd.Wait()
|
||||||
|
|
||||||
|
if readErr != nil || n < 12 {
|
||||||
|
return 0, 0, fmt.Errorf("failed to read zstd compressed blend header: %v", readErr)
|
||||||
|
}
|
||||||
|
if string(header[:7]) != "BLENDER" {
|
||||||
|
return 0, 0, fmt.Errorf("invalid blend file format in zstd archive")
|
||||||
|
}
|
||||||
|
|
||||||
|
return parseVersionDigits(header[9:12])
|
||||||
|
}
|
||||||
96
pkg/blendfile/version_test.go
Normal file
96
pkg/blendfile/version_test.go
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
package blendfile
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func makeBlendHeader(major, minor int) []byte {
|
||||||
|
header := make([]byte, 12)
|
||||||
|
copy(header[:7], "BLENDER")
|
||||||
|
header[7] = '-'
|
||||||
|
header[8] = 'v'
|
||||||
|
header[9] = byte('0' + major)
|
||||||
|
header[10] = byte('0' + minor/10)
|
||||||
|
header[11] = byte('0' + minor%10)
|
||||||
|
return header
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseVersionFromReader_Uncompressed(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
major int
|
||||||
|
minor int
|
||||||
|
wantMajor int
|
||||||
|
wantMinor int
|
||||||
|
}{
|
||||||
|
{"Blender 4.02", 4, 2, 4, 2},
|
||||||
|
{"Blender 3.06", 3, 6, 3, 6},
|
||||||
|
{"Blender 2.79", 2, 79, 2, 79},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
header := makeBlendHeader(tt.major, tt.minor)
|
||||||
|
r := bytes.NewReader(header)
|
||||||
|
|
||||||
|
major, minor, err := ParseVersionFromReader(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseVersionFromReader: %v", err)
|
||||||
|
}
|
||||||
|
if major != tt.wantMajor || minor != tt.wantMinor {
|
||||||
|
t.Errorf("got %d.%d, want %d.%d", major, minor, tt.wantMajor, tt.wantMinor)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseVersionFromReader_GzipCompressed(t *testing.T) {
|
||||||
|
header := makeBlendHeader(4, 2)
|
||||||
|
// Pad to ensure gzip has enough data for a full read
|
||||||
|
data := make([]byte, 128)
|
||||||
|
copy(data, header)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
gz := gzip.NewWriter(&buf)
|
||||||
|
gz.Write(data)
|
||||||
|
gz.Close()
|
||||||
|
|
||||||
|
r := bytes.NewReader(buf.Bytes())
|
||||||
|
major, minor, err := ParseVersionFromReader(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseVersionFromReader (gzip): %v", err)
|
||||||
|
}
|
||||||
|
if major != 4 || minor != 2 {
|
||||||
|
t.Errorf("got %d.%d, want 4.2", major, minor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseVersionFromReader_InvalidMagic(t *testing.T) {
|
||||||
|
data := []byte("NOT_BLEND_DATA_HERE")
|
||||||
|
r := bytes.NewReader(data)
|
||||||
|
|
||||||
|
_, _, err := ParseVersionFromReader(r)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for invalid magic, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseVersionFromReader_TooShort(t *testing.T) {
|
||||||
|
data := []byte("SHORT")
|
||||||
|
r := bytes.NewReader(data)
|
||||||
|
|
||||||
|
_, _, err := ParseVersionFromReader(r)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for short data, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVersionString(t *testing.T) {
|
||||||
|
got := VersionString(4, 2)
|
||||||
|
want := "4.2"
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("VersionString(4, 2) = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -361,6 +361,9 @@ func RunCommandWithStreaming(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if err := scanner.Err(); err != nil && !isBenignPipeReadError(err) {
|
||||||
|
logSender(taskID, types.LogLevelWarn, fmt.Sprintf("stdout read error: %v", err), stepName)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
@@ -375,6 +378,9 @@ func RunCommandWithStreaming(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if err := scanner.Err(); err != nil && !isBenignPipeReadError(err) {
|
||||||
|
logSender(taskID, types.LogLevelWarn, fmt.Sprintf("stderr read error: %v", err), stepName)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = cmd.Wait()
|
err = cmd.Wait()
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
"os/exec"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestIsBenignPipeReadError(t *testing.T) {
|
func TestIsBenignPipeReadError(t *testing.T) {
|
||||||
@@ -30,3 +32,24 @@ func TestIsBenignPipeReadError(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProcessTracker_TrackUntrack(t *testing.T) {
|
||||||
|
pt := NewProcessTracker()
|
||||||
|
cmd := exec.Command("sh", "-c", "sleep 1")
|
||||||
|
pt.Track(1, cmd)
|
||||||
|
if count := pt.Count(); count != 1 {
|
||||||
|
t.Fatalf("Count() = %d, want 1", count)
|
||||||
|
}
|
||||||
|
pt.Untrack(1)
|
||||||
|
if count := pt.Count(); count != 0 {
|
||||||
|
t.Fatalf("Count() = %d, want 0", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunCommandWithTimeout_TimesOut(t *testing.T) {
|
||||||
|
pt := NewProcessTracker()
|
||||||
|
_, err := RunCommandWithTimeout(200*time.Millisecond, "sh", []string{"-c", "sleep 2"}, "", nil, 99, pt)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected timeout error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,6 +11,3 @@ var UnhideObjects string
|
|||||||
//go:embed scripts/render_blender.py.template
|
//go:embed scripts/render_blender.py.template
|
||||||
var RenderBlenderTemplate string
|
var RenderBlenderTemplate string
|
||||||
|
|
||||||
//go:embed scripts/detect_gpu_backends.py
|
|
||||||
var DetectGPUBackends string
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,39 +0,0 @@
|
|||||||
# Minimal script to detect HIP (AMD) and NVIDIA (CUDA/OptiX) backends for Cycles.
|
|
||||||
# Run with: blender -b --python detect_gpu_backends.py
|
|
||||||
# Prints HAS_HIP and/or HAS_NVIDIA to stdout, one per line.
|
|
||||||
import sys
|
|
||||||
|
|
||||||
def main():
|
|
||||||
try:
|
|
||||||
prefs = bpy.context.preferences
|
|
||||||
if not hasattr(prefs, 'addons') or 'cycles' not in prefs.addons:
|
|
||||||
return
|
|
||||||
cprefs = prefs.addons['cycles'].preferences
|
|
||||||
has_hip = False
|
|
||||||
has_nvidia = False
|
|
||||||
for device_type in ('HIP', 'CUDA', 'OPTIX'):
|
|
||||||
try:
|
|
||||||
cprefs.compute_device_type = device_type
|
|
||||||
cprefs.refresh_devices()
|
|
||||||
devs = []
|
|
||||||
if hasattr(cprefs, 'get_devices'):
|
|
||||||
devs = cprefs.get_devices()
|
|
||||||
elif hasattr(cprefs, 'devices') and cprefs.devices:
|
|
||||||
devs = list(cprefs.devices) if hasattr(cprefs.devices, '__iter__') else [cprefs.devices]
|
|
||||||
if devs:
|
|
||||||
if device_type == 'HIP':
|
|
||||||
has_hip = True
|
|
||||||
if device_type in ('CUDA', 'OPTIX'):
|
|
||||||
has_nvidia = True
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
if has_hip:
|
|
||||||
print('HAS_HIP', flush=True)
|
|
||||||
if has_nvidia:
|
|
||||||
print('HAS_NVIDIA', flush=True)
|
|
||||||
except Exception as e:
|
|
||||||
print('ERROR', str(e), file=sys.stderr, flush=True)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
import bpy
|
|
||||||
main()
|
|
||||||
@@ -209,9 +209,10 @@ if current_engine == 'CYCLES':
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# Check all devices and choose the best GPU type
|
# Check all devices and choose the best GPU type.
|
||||||
# Device type preference order (most performant first)
|
# Explicit fallback policy: NVIDIA -> Intel -> AMD -> CPU.
|
||||||
device_type_preference = ['OPTIX', 'CUDA', 'HIP', 'ONEAPI', 'METAL']
|
# (OPTIX/CUDA are NVIDIA, ONEAPI is Intel, HIP/OPENCL are AMD)
|
||||||
|
device_type_preference = ['OPTIX', 'CUDA', 'ONEAPI', 'HIP', 'OPENCL']
|
||||||
gpu_available = False
|
gpu_available = False
|
||||||
best_device_type = None
|
best_device_type = None
|
||||||
best_gpu_devices = []
|
best_gpu_devices = []
|
||||||
@@ -343,16 +344,6 @@ if current_engine == 'CYCLES':
|
|||||||
scene.cycles.use_optix_denoising = True
|
scene.cycles.use_optix_denoising = True
|
||||||
print(f" Enabled OptiX denoising (if OptiX available)")
|
print(f" Enabled OptiX denoising (if OptiX available)")
|
||||||
print(f" CUDA ray tracing active")
|
print(f" CUDA ray tracing active")
|
||||||
elif best_device_type == 'METAL':
|
|
||||||
# MetalRT for Apple Silicon (if available)
|
|
||||||
if hasattr(scene.cycles, 'use_metalrt'):
|
|
||||||
scene.cycles.use_metalrt = True
|
|
||||||
print(f" Enabled MetalRT (Metal Ray Tracing) for faster rendering")
|
|
||||||
elif hasattr(cycles_prefs, 'use_metalrt'):
|
|
||||||
cycles_prefs.use_metalrt = True
|
|
||||||
print(f" Enabled MetalRT (Metal Ray Tracing) for faster rendering")
|
|
||||||
else:
|
|
||||||
print(f" MetalRT not available")
|
|
||||||
elif best_device_type == 'ONEAPI':
|
elif best_device_type == 'ONEAPI':
|
||||||
# Intel oneAPI - Embree might be available
|
# Intel oneAPI - Embree might be available
|
||||||
if hasattr(scene.cycles, 'use_embree'):
|
if hasattr(scene.cycles, 'use_embree'):
|
||||||
|
|||||||
19
pkg/scripts/scripts_test.go
Normal file
19
pkg/scripts/scripts_test.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package scripts
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEmbeddedScripts_ArePresent(t *testing.T) {
|
||||||
|
if strings.TrimSpace(ExtractMetadata) == "" {
|
||||||
|
t.Fatal("ExtractMetadata script should not be empty")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(UnhideObjects) == "" {
|
||||||
|
t.Fatal("UnhideObjects script should not be empty")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(RenderBlenderTemplate) == "" {
|
||||||
|
t.Fatal("RenderBlenderTemplate should not be empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
49
pkg/types/types_test.go
Normal file
49
pkg/types/types_test.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestJobJSON_RoundTrip(t *testing.T) {
|
||||||
|
now := time.Now().UTC().Truncate(time.Second)
|
||||||
|
frameStart, frameEnd := 1, 10
|
||||||
|
format := "PNG"
|
||||||
|
job := Job{
|
||||||
|
ID: 42,
|
||||||
|
UserID: 7,
|
||||||
|
JobType: JobTypeRender,
|
||||||
|
Name: "demo",
|
||||||
|
Status: JobStatusPending,
|
||||||
|
Progress: 12.5,
|
||||||
|
FrameStart: &frameStart,
|
||||||
|
FrameEnd: &frameEnd,
|
||||||
|
OutputFormat: &format,
|
||||||
|
BlendMetadata: &BlendMetadata{
|
||||||
|
FrameStart: 1,
|
||||||
|
FrameEnd: 10,
|
||||||
|
RenderSettings: RenderSettings{
|
||||||
|
ResolutionX: 1920,
|
||||||
|
ResolutionY: 1080,
|
||||||
|
FrameRate: 24.0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
CreatedAt: now,
|
||||||
|
}
|
||||||
|
|
||||||
|
raw, err := json.Marshal(job)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var out Job
|
||||||
|
if err := json.Unmarshal(raw, &out); err != nil {
|
||||||
|
t.Fatalf("unmarshal failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if out.ID != job.ID || out.JobType != JobTypeRender || out.BlendMetadata == nil {
|
||||||
|
t.Fatalf("unexpected roundtrip result: %+v", out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
19
version/version_test.go
Normal file
19
version/version_test.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package version
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInitDefaults_AreSet(t *testing.T) {
|
||||||
|
if Version == "" {
|
||||||
|
t.Fatal("Version should be initialized")
|
||||||
|
}
|
||||||
|
if Date == "" {
|
||||||
|
t.Fatal("Date should be initialized")
|
||||||
|
}
|
||||||
|
if !strings.Contains(Version, ".") {
|
||||||
|
t.Fatalf("Version should look semantic-ish, got %q", Version)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
30
web/embed_test.go
Normal file
30
web/embed_test.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package web
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetStaticFileSystem_NonNil(t *testing.T) {
|
||||||
|
fs := GetStaticFileSystem()
|
||||||
|
if fs == nil {
|
||||||
|
t.Fatal("static filesystem should not be nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStaticHandler_ServesWithoutPanic(t *testing.T) {
|
||||||
|
h := StaticHandler()
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest("GET", "/assets/does-not-exist.txt", nil)
|
||||||
|
h.ServeHTTP(rr, req)
|
||||||
|
if rr.Code == 0 {
|
||||||
|
t.Fatal("handler should write a status code")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetTemplateFS_NonNil(t *testing.T) {
|
||||||
|
if GetTemplateFS() == nil {
|
||||||
|
t.Fatal("template fs should not be nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -169,11 +169,27 @@
|
|||||||
});
|
});
|
||||||
const data = await res.json().catch(() => ({}));
|
const data = await res.json().catch(() => ({}));
|
||||||
if (!res.ok) {
|
if (!res.ok) {
|
||||||
throw new Error(data.error || "Job creation failed");
|
const err = new Error(data.error || "Job creation failed");
|
||||||
|
if (data && typeof data.code === "string") {
|
||||||
|
err.code = data.code;
|
||||||
|
}
|
||||||
|
throw err;
|
||||||
}
|
}
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function resetToUploadStep(message) {
|
||||||
|
sessionID = "";
|
||||||
|
clearInterval(pollTimer);
|
||||||
|
setUploadBusy(false);
|
||||||
|
mainBlendWrapper.classList.add("hidden");
|
||||||
|
metadataPreview.innerHTML = "";
|
||||||
|
configSection.classList.add("hidden");
|
||||||
|
setStep(1);
|
||||||
|
showStatus("Please upload the file again.");
|
||||||
|
showError(message);
|
||||||
|
}
|
||||||
|
|
||||||
async function runSubmission(mainBlendFile) {
|
async function runSubmission(mainBlendFile) {
|
||||||
showError("");
|
showError("");
|
||||||
setStep(1);
|
setStep(1);
|
||||||
@@ -277,6 +293,14 @@
|
|||||||
showError("");
|
showError("");
|
||||||
await submitJobConfig();
|
await submitJobConfig();
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
|
if (err && err.code === "UPLOAD_SESSION_EXPIRED") {
|
||||||
|
resetToUploadStep(err.message || "Upload session expired. Please upload the file again.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (err && err.code === "UPLOAD_SESSION_NOT_READY") {
|
||||||
|
showError(err.message || "Upload session is still processing. Please wait and try again.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
showError(err.message || "Failed to create job");
|
showError(err.message || "Failed to create job");
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|||||||
Reference in New Issue
Block a user