its a bit broken

This commit is contained in:
2025-11-25 03:48:28 -06:00
parent a53ea4dce7
commit 690e6b13f8
16 changed files with 1542 additions and 861 deletions

View File

@@ -10,75 +10,115 @@ import (
"jiggablend/pkg/types"
)
// handleGenerateRegistrationToken generates a new registration token
func (s *Server) handleGenerateRegistrationToken(w http.ResponseWriter, r *http.Request) {
// handleGenerateRunnerAPIKey generates a new runner API key
func (s *Server) handleGenerateRunnerAPIKey(w http.ResponseWriter, r *http.Request) {
userID, err := getUserID(r)
if err != nil {
s.respondError(w, http.StatusUnauthorized, err.Error())
return
}
// Default expiration: 24 hours
expiresIn := 24 * time.Hour
var req struct {
ExpiresInHours int `json:"expires_in_hours,omitempty"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
Scope string `json:"scope,omitempty"` // 'manager' or 'user'
}
if r.Body != nil && r.ContentLength > 0 {
if err := json.NewDecoder(r.Body).Decode(&req); err == nil {
if req.ExpiresInHours == 0 {
// 0 hours means infinite expiration
expiresIn = 0
} else if req.ExpiresInHours > 0 {
expiresIn = time.Duration(req.ExpiresInHours) * time.Hour
}
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
return
}
token, err := s.secrets.GenerateRegistrationToken(userID, expiresIn)
if req.Name == "" {
s.respondError(w, http.StatusBadRequest, "API key name is required")
return
}
// Default scope to 'user' if not specified
scope := req.Scope
if scope == "" {
scope = "user"
}
if scope != "manager" && scope != "user" {
s.respondError(w, http.StatusBadRequest, "Scope must be 'manager' or 'user'")
return
}
keyInfo, err := s.secrets.GenerateRunnerAPIKey(userID, req.Name, req.Description, scope)
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to generate token: %v", err))
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to generate API key: %v", err))
return
}
response := map[string]interface{}{
"token": token,
}
if expiresIn == 0 {
response["expires_in"] = "infinite"
response["expires_at"] = nil
} else {
response["expires_in"] = expiresIn.String()
response["expires_at"] = time.Now().Add(expiresIn)
"id": keyInfo.ID,
"key": keyInfo.Key,
"name": keyInfo.Name,
"description": keyInfo.Description,
"is_active": keyInfo.IsActive,
"created_at": keyInfo.CreatedAt,
}
s.respondJSON(w, http.StatusCreated, response)
}
// handleListRegistrationTokens lists all registration tokens
func (s *Server) handleListRegistrationTokens(w http.ResponseWriter, r *http.Request) {
tokens, err := s.secrets.ListRegistrationTokens()
// handleListRunnerAPIKeys lists all runner API keys
func (s *Server) handleListRunnerAPIKeys(w http.ResponseWriter, r *http.Request) {
keys, err := s.secrets.ListRunnerAPIKeys()
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to list tokens: %v", err))
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to list API keys: %v", err))
return
}
s.respondJSON(w, http.StatusOK, tokens)
// Convert to response format (hide sensitive hash data)
var response []map[string]interface{}
for _, key := range keys {
item := map[string]interface{}{
"id": key.ID,
"key_prefix": key.Key, // Only show prefix, not full key
"name": key.Name,
"is_active": key.IsActive,
"created_at": key.CreatedAt,
"created_by": key.CreatedBy,
}
if key.Description != nil {
item["description"] = *key.Description
}
response = append(response, item)
}
s.respondJSON(w, http.StatusOK, response)
}
// handleRevokeRegistrationToken revokes a registration token
func (s *Server) handleRevokeRegistrationToken(w http.ResponseWriter, r *http.Request) {
tokenID, err := parseID(r, "id")
// handleRevokeRunnerAPIKey revokes a runner API key
func (s *Server) handleRevokeRunnerAPIKey(w http.ResponseWriter, r *http.Request) {
keyID, err := parseID(r, "id")
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
if err := s.secrets.RevokeRegistrationToken(tokenID); err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to revoke token: %v", err))
if err := s.secrets.RevokeRunnerAPIKey(keyID); err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to revoke API key: %v", err))
return
}
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Token revoked"})
s.respondJSON(w, http.StatusOK, map[string]string{"message": "API key revoked"})
}
// handleDeleteRunnerAPIKey deletes a runner API key
func (s *Server) handleDeleteRunnerAPIKey(w http.ResponseWriter, r *http.Request) {
keyID, err := parseID(r, "id")
if err != nil {
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
if err := s.secrets.DeleteRunnerAPIKey(keyID); err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to delete API key: %v", err))
return
}
s.respondJSON(w, http.StatusOK, map[string]string{"message": "API key deleted"})
}
// handleVerifyRunner manually verifies a runner
@@ -136,8 +176,8 @@ func (s *Server) handleDeleteRunner(w http.ResponseWriter, r *http.Request) {
// handleListRunnersAdmin lists all runners with admin details
func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request) {
rows, err := s.db.Query(
`SELECT id, name, hostname, status, last_heartbeat, capabilities,
registration_token, verified, priority, created_at
`SELECT id, name, hostname, status, last_heartbeat, capabilities,
api_key_id, api_key_scope, priority, created_at
FROM runners ORDER BY created_at DESC`,
)
if err != nil {
@@ -149,13 +189,13 @@ func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request)
runners := []map[string]interface{}{}
for rows.Next() {
var runner types.Runner
var registrationToken sql.NullString
var verified bool
var apiKeyID sql.NullInt64
var apiKeyScope string
err := rows.Scan(
&runner.ID, &runner.Name, &runner.Hostname,
&runner.Status, &runner.LastHeartbeat, &runner.Capabilities,
&registrationToken, &verified, &runner.Priority, &runner.CreatedAt,
&apiKeyID, &apiKeyScope, &runner.Priority, &runner.CreatedAt,
)
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan runner: %v", err))
@@ -163,16 +203,16 @@ func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request)
}
runners = append(runners, map[string]interface{}{
"id": runner.ID,
"name": runner.Name,
"hostname": runner.Hostname,
"status": runner.Status,
"last_heartbeat": runner.LastHeartbeat,
"capabilities": runner.Capabilities,
"registration_token": registrationToken.String,
"verified": verified,
"priority": runner.Priority,
"created_at": runner.CreatedAt,
"id": runner.ID,
"name": runner.Name,
"hostname": runner.Hostname,
"status": runner.Status,
"last_heartbeat": runner.LastHeartbeat,
"capabilities": runner.Capabilities,
"api_key_id": apiKeyID.Int64,
"api_key_scope": apiKeyScope,
"priority": runner.Priority,
"created_at": runner.CreatedAt,
})
}
@@ -335,7 +375,7 @@ func (s *Server) handleSetRegistrationEnabled(w http.ResponseWriter, r *http.Req
Enabled bool `json:"enabled"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
return
}
@@ -359,7 +399,7 @@ func (s *Server) handleSetUserAdminStatus(w http.ResponseWriter, r *http.Request
IsAdmin bool `json:"is_admin"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
return
}

View File

@@ -25,7 +25,6 @@ import (
authpkg "jiggablend/internal/auth"
"jiggablend/pkg/types"
"github.com/go-chi/chi/v5"
"github.com/gorilla/websocket"
"jiggablend/pkg/scripts"
@@ -62,7 +61,7 @@ func (s *Server) handleCreateJob(w http.ResponseWriter, r *http.Request) {
var req types.CreateJobRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
return
}
@@ -83,7 +82,7 @@ func (s *Server) handleCreateJob(w http.ResponseWriter, r *http.Request) {
s.respondError(w, http.StatusBadRequest, "frame_start and frame_end are required for render jobs")
return
}
if *req.FrameStart < 0 || *req.FrameEnd < *req.FrameStart {
if *req.FrameEnd < *req.FrameStart {
s.respondError(w, http.StatusBadRequest, "Invalid frame range")
return
}
@@ -671,7 +670,7 @@ func (s *Server) handleBatchGetJobs(w http.ResponseWriter, r *http.Request) {
JobIDs []int64 `json:"job_ids"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
return
}
@@ -1677,7 +1676,7 @@ func (s *Server) extractMetadataFromTempContext(contextPath string) (*types.Blen
})
if err != nil || blendFile == "" {
return nil, fmt.Errorf("no .blend file found in context")
return nil, fmt.Errorf("no .blend file found in context - the uploaded context archive must contain at least one .blend file to render")
}
// Use the same extraction script and process as extractMetadataFromContext
@@ -1894,7 +1893,7 @@ func (s *Server) createContextFromDir(sourceDir, destPath string, excludeFiles .
}
if blendFilesAtRoot == 0 {
return "", fmt.Errorf("no .blend file found at root level in context archive")
return "", fmt.Errorf("no .blend file found at root level in context archive - .blend files must be at the root level of the uploaded archive, not in subdirectories")
}
if blendFilesAtRoot > 1 {
return "", fmt.Errorf("multiple .blend files found at root level in context archive (found %d, expected 1)", blendFilesAtRoot)
@@ -2958,7 +2957,7 @@ func (s *Server) handleBatchGetTasks(w http.ResponseWriter, r *http.Request) {
TaskIDs []int64 `json:"task_ids"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
return
}
@@ -3057,10 +3056,9 @@ func (s *Server) handleGetTaskLogs(w http.ResponseWriter, r *http.Request) {
return
}
taskIDStr := chi.URLParam(r, "taskId")
taskID, err := strconv.ParseInt(taskIDStr, 10, 64)
taskID, err := parseID(r, "taskId")
if err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid task ID")
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
@@ -3196,10 +3194,9 @@ func (s *Server) handleGetTaskSteps(w http.ResponseWriter, r *http.Request) {
return
}
taskIDStr := chi.URLParam(r, "taskId")
taskID, err := strconv.ParseInt(taskIDStr, 10, 64)
taskID, err := parseID(r, "taskId")
if err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid task ID")
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
@@ -3304,10 +3301,9 @@ func (s *Server) handleRetryTask(w http.ResponseWriter, r *http.Request) {
return
}
taskIDStr := chi.URLParam(r, "taskId")
taskID, err := strconv.ParseInt(taskIDStr, 10, 64)
taskID, err := parseID(r, "taskId")
if err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid task ID")
s.respondError(w, http.StatusBadRequest, err.Error())
return
}
@@ -3396,10 +3392,9 @@ func (s *Server) handleStreamTaskLogsWebSocket(w http.ResponseWriter, r *http.Re
return
}
taskIDStr := chi.URLParam(r, "taskId")
taskID, err := strconv.ParseInt(taskIDStr, 10, 64)
taskID, err := parseID(r, "taskId")
if err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid task ID")
s.respondError(w, http.StatusBadRequest, err.Error())
return
}

View File

@@ -38,7 +38,7 @@ func (s *Server) handleSubmitMetadata(w http.ResponseWriter, r *http.Request) {
var metadata types.BlendMetadata
if err := json.NewDecoder(r.Body).Decode(&metadata); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid metadata JSON")
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid metadata JSON: %v", err))
return
}
@@ -230,7 +230,7 @@ func (s *Server) extractMetadataFromContext(jobID int64) (*types.BlendMetadata,
}
if blendFile == "" {
return nil, fmt.Errorf("no .blend file found in context")
return nil, fmt.Errorf("no .blend file found in context - the uploaded context archive must contain at least one .blend file for metadata extraction")
}
// Use embedded Python script

View File

@@ -28,41 +28,75 @@ type contextKey string
const runnerIDContextKey contextKey = "runner_id"
// runnerAuthMiddleware verifies runner requests using shared secret header
// runnerAuthMiddleware verifies runner requests using API key
func (s *Server) runnerAuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Get runner ID from query string
// Get API key from header
apiKey := r.Header.Get("Authorization")
if apiKey == "" {
// Try alternative header
apiKey = r.Header.Get("X-API-Key")
}
if apiKey == "" {
s.respondError(w, http.StatusUnauthorized, "API key required")
return
}
// Remove "Bearer " prefix if present
if strings.HasPrefix(apiKey, "Bearer ") {
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
}
// Validate API key and get its ID
apiKeyID, _, err := s.secrets.ValidateRunnerAPIKey(apiKey)
if err != nil {
log.Printf("API key validation failed: %v", err)
s.respondError(w, http.StatusUnauthorized, "invalid API key")
return
}
// Get runner ID from query string or find runner by API key
runnerIDStr := r.URL.Query().Get("runner_id")
if runnerIDStr == "" {
s.respondError(w, http.StatusBadRequest, "runner_id required in query string")
return
}
var runnerID int64
_, err := fmt.Sscanf(runnerIDStr, "%d", &runnerID)
if err != nil {
s.respondError(w, http.StatusBadRequest, "invalid runner_id")
return
}
// Get runner secret
runnerSecret, err := s.secrets.GetRunnerSecret(runnerID)
if err != nil {
log.Printf("Failed to get runner secret for runner %d: %v", runnerID, err)
s.respondError(w, http.StatusUnauthorized, "runner not found or not verified")
return
}
if runnerIDStr != "" {
// Runner ID provided - verify it belongs to this API key
_, err := fmt.Sscanf(runnerIDStr, "%d", &runnerID)
if err != nil {
s.respondError(w, http.StatusBadRequest, "invalid runner_id")
return
}
// Verify shared secret from header
providedSecret := r.Header.Get("X-Runner-Secret")
if providedSecret == "" {
s.respondError(w, http.StatusUnauthorized, "missing secret")
return
}
if providedSecret != runnerSecret {
s.respondError(w, http.StatusUnauthorized, "invalid secret")
return
// For fixed API keys, skip database verification
if apiKeyID != -1 {
// Verify runner exists and uses this API key
var dbAPIKeyID sql.NullInt64
err = s.db.QueryRow("SELECT api_key_id FROM runners WHERE id = ?", runnerID).Scan(&dbAPIKeyID)
if err == sql.ErrNoRows {
s.respondError(w, http.StatusNotFound, "runner not found")
return
}
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to query runner API key: %v", err))
return
}
if !dbAPIKeyID.Valid || dbAPIKeyID.Int64 != apiKeyID {
s.respondError(w, http.StatusForbidden, "runner does not belong to this API key")
return
}
}
} else {
// No runner ID provided - find the runner for this API key
// For simplicity, assume each API key has one runner
err = s.db.QueryRow("SELECT id FROM runners WHERE api_key_id = ?", apiKeyID).Scan(&runnerID)
if err == sql.ErrNoRows {
s.respondError(w, http.StatusNotFound, "no runner found for this API key")
return
}
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to query runner by API key: %v", err))
return
}
}
// Add runner ID to context
@@ -72,66 +106,40 @@ func (s *Server) runnerAuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
}
}
// handleRegisterRunner registers a new runner
// Note: Token expiration only affects whether the token can be used for registration.
// Once a runner is registered, it receives its own runner_secret and manager_secret
// and operates independently. The token expiration does not affect registered runners.
// handleRegisterRunner registers a new runner using an API key
func (s *Server) handleRegisterRunner(w http.ResponseWriter, r *http.Request) {
var req struct {
types.RegisterRunnerRequest
RegistrationToken string `json:"registration_token"`
APIKey string `json:"api_key"`
Fingerprint string `json:"fingerprint,omitempty"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
return
}
// Lock to prevent concurrent registrations that could create duplicate runners
s.secrets.RegistrationMu.Lock()
defer s.secrets.RegistrationMu.Unlock()
if req.Name == "" {
s.respondError(w, http.StatusBadRequest, "Runner name is required")
return
}
if req.RegistrationToken == "" {
s.respondError(w, http.StatusBadRequest, "Registration token is required")
if req.APIKey == "" {
s.respondError(w, http.StatusBadRequest, "API key is required")
return
}
// Validate registration token (expiration only affects token usability, not registered runners)
result, err := s.secrets.ValidateRegistrationTokenDetailed(req.RegistrationToken)
// Validate API key
apiKeyID, apiKeyScope, err := s.secrets.ValidateRunnerAPIKey(req.APIKey)
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to validate token: %v", err))
return
}
if !result.Valid {
var errorMsg string
switch result.Reason {
case "already_used":
errorMsg = "Registration token has already been used"
case "expired":
errorMsg = "Registration token has expired"
case "not_found":
errorMsg = "Invalid registration token"
default:
errorMsg = "Invalid or expired registration token"
}
s.respondError(w, http.StatusUnauthorized, errorMsg)
return
}
// Get manager secret
managerSecret, err := s.secrets.GetManagerSecret()
if err != nil {
s.respondError(w, http.StatusInternalServerError, "Failed to get manager secret")
return
}
// Generate runner secret (runner will use this for all future authentication, independent of token)
runnerSecret, err := s.secrets.GenerateRunnerSecret()
if err != nil {
s.respondError(w, http.StatusInternalServerError, "Failed to generate runner secret")
s.respondError(w, http.StatusUnauthorized, fmt.Sprintf("Invalid API key: %v", err))
return
}
// For fixed API keys (keyID = -1), skip fingerprint checking
// Set default priority if not provided
priority := 100
if req.Priority != nil {
@@ -140,28 +148,110 @@ func (s *Server) handleRegisterRunner(w http.ResponseWriter, r *http.Request) {
// Register runner
var runnerID int64
// For fixed API keys, don't store api_key_id in database
var dbAPIKeyID interface{}
if apiKeyID == -1 {
dbAPIKeyID = nil // NULL for fixed API keys
} else {
dbAPIKeyID = apiKeyID
}
// Determine fingerprint value
fingerprint := req.Fingerprint
if apiKeyID == -1 || fingerprint == "" {
// For fixed API keys or when no fingerprint provided, generate a unique fingerprint
// to avoid conflicts while still maintaining some uniqueness
fingerprint = fmt.Sprintf("fixed-%s-%d", req.Name, time.Now().UnixNano())
}
// Check fingerprint uniqueness only for non-fixed API keys
if apiKeyID != -1 && req.Fingerprint != "" {
var existingRunnerID int64
var existingAPIKeyID sql.NullInt64
err = s.db.QueryRow(
"SELECT id, api_key_id FROM runners WHERE fingerprint = ?",
req.Fingerprint,
).Scan(&existingRunnerID, &existingAPIKeyID)
if err == nil {
// Runner already exists with this fingerprint
if existingAPIKeyID.Valid && existingAPIKeyID.Int64 == apiKeyID {
// Same API key - update and return existing runner
log.Printf("Runner with fingerprint %s already exists (ID: %d), updating info", req.Fingerprint, existingRunnerID)
_, err = s.db.Exec(
`UPDATE runners SET name = ?, hostname = ?, capabilities = ?, status = ?, last_heartbeat = ? WHERE id = ?`,
req.Name, req.Hostname, req.Capabilities, types.RunnerStatusOnline, time.Now(), existingRunnerID,
)
if err != nil {
log.Printf("Warning: Failed to update existing runner info: %v", err)
}
s.respondJSON(w, http.StatusOK, map[string]interface{}{
"id": existingRunnerID,
"name": req.Name,
"hostname": req.Hostname,
"status": types.RunnerStatusOnline,
"reused": true, // Indicates this was a re-registration
})
return
} else {
// Different API key - reject registration
s.respondError(w, http.StatusConflict, "Runner with this fingerprint already registered with different API key")
return
}
}
// If err is not nil, it means no existing runner with this fingerprint - proceed with new registration
}
// Insert runner
err = s.db.QueryRow(
`INSERT INTO runners (name, hostname, ip_address, status, last_heartbeat, capabilities,
registration_token, runner_secret, manager_secret, verified, priority)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`INSERT INTO runners (name, hostname, ip_address, status, last_heartbeat, capabilities,
api_key_id, api_key_scope, priority, fingerprint)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
RETURNING id`,
req.Name, req.Hostname, "", types.RunnerStatusOnline, time.Now(), req.Capabilities,
req.RegistrationToken, runnerSecret, managerSecret, true, priority,
dbAPIKeyID, apiKeyScope, priority, fingerprint,
).Scan(&runnerID)
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to register runner: %v", err))
return
}
// Return runner info with secrets
log.Printf("Registered new runner %s (ID: %d) with API key ID: %d", req.Name, runnerID, apiKeyID)
// Return runner info
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
"id": runnerID,
"name": req.Name,
"hostname": req.Hostname,
"status": types.RunnerStatusOnline,
"runner_secret": runnerSecret,
"manager_secret": managerSecret,
"verified": true,
"id": runnerID,
"name": req.Name,
"hostname": req.Hostname,
"status": types.RunnerStatusOnline,
})
}
// handleRunnerPing allows runners to validate their secrets and connection
func (s *Server) handleRunnerPing(w http.ResponseWriter, r *http.Request) {
// This endpoint uses runnerAuthMiddleware, so if we get here, secrets are valid
// Get runner ID from context (set by runnerAuthMiddleware)
runnerID, ok := r.Context().Value(runnerIDContextKey).(int64)
if !ok {
s.respondError(w, http.StatusUnauthorized, "runner_id not found in context")
return
}
// Update last heartbeat
_, err := s.db.Exec(
`UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`,
time.Now(), types.RunnerStatusOnline, runnerID,
)
if err != nil {
log.Printf("Warning: Failed to update runner heartbeat: %v", err)
}
s.respondJSON(w, http.StatusOK, map[string]interface{}{
"status": "ok",
"runner_id": runnerID,
"timestamp": time.Now().Unix(),
})
}
@@ -177,7 +267,7 @@ func (s *Server) handleUpdateTaskProgress(w http.ResponseWriter, r *http.Request
Progress float64 `json:"progress"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
return
}
@@ -207,7 +297,7 @@ func (s *Server) handleUpdateTaskStep(w http.ResponseWriter, r *http.Request) {
ErrorMessage string `json:"error_message,omitempty"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
return
}
@@ -354,7 +444,7 @@ func (s *Server) handleUploadFileFromRunner(w http.ResponseWriter, r *http.Reque
err = r.ParseMultipartForm(50 << 30) // 50 GB (for large output files)
if err != nil {
s.respondError(w, http.StatusBadRequest, "Failed to parse form")
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Failed to parse multipart form: %v", err))
return
}
@@ -530,7 +620,7 @@ func (s *Server) handleGetJobMetadataForRunner(w http.ResponseWriter, r *http.Re
var metadata types.BlendMetadata
if err := json.Unmarshal([]byte(blendMetadataJSON.String), &metadata); err != nil {
s.respondError(w, http.StatusInternalServerError, "Failed to parse metadata")
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to parse metadata JSON: %v", err))
return
}
@@ -645,33 +735,66 @@ type WSTaskUpdate struct {
// handleRunnerWebSocket handles WebSocket connections from runners
func (s *Server) handleRunnerWebSocket(w http.ResponseWriter, r *http.Request) {
// Get runner ID and secret from query params
// Get API key from query params or headers
apiKey := r.URL.Query().Get("api_key")
if apiKey == "" {
apiKey = r.Header.Get("Authorization")
if strings.HasPrefix(apiKey, "Bearer ") {
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
}
}
if apiKey == "" {
s.respondError(w, http.StatusBadRequest, "API key required")
return
}
// Validate API key
apiKeyID, _, err := s.secrets.ValidateRunnerAPIKey(apiKey)
if err != nil {
s.respondError(w, http.StatusUnauthorized, fmt.Sprintf("Invalid API key: %v", err))
return
}
// Get runner ID from query params or find by API key
runnerIDStr := r.URL.Query().Get("runner_id")
providedSecret := r.URL.Query().Get("secret")
if runnerIDStr == "" || providedSecret == "" {
s.respondError(w, http.StatusBadRequest, "runner_id and secret required")
return
}
var runnerID int64
_, err := fmt.Sscanf(runnerIDStr, "%d", &runnerID)
if err != nil {
s.respondError(w, http.StatusBadRequest, "invalid runner_id")
return
}
// Get runner secret
runnerSecret, err := s.secrets.GetRunnerSecret(runnerID)
if err != nil {
s.respondError(w, http.StatusUnauthorized, "runner not found or not verified")
return
}
if runnerIDStr != "" {
// Runner ID provided - verify it belongs to this API key
_, err := fmt.Sscanf(runnerIDStr, "%d", &runnerID)
if err != nil {
s.respondError(w, http.StatusBadRequest, "invalid runner_id")
return
}
// Verify shared secret
if providedSecret != runnerSecret {
s.respondError(w, http.StatusUnauthorized, "invalid secret")
return
// For fixed API keys, skip database verification
if apiKeyID != -1 {
var dbAPIKeyID sql.NullInt64
err = s.db.QueryRow("SELECT api_key_id FROM runners WHERE id = ?", runnerID).Scan(&dbAPIKeyID)
if err == sql.ErrNoRows {
s.respondError(w, http.StatusNotFound, "runner not found")
return
}
if err != nil {
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to query runner API key: %v", err))
return
}
if !dbAPIKeyID.Valid || dbAPIKeyID.Int64 != apiKeyID {
s.respondError(w, http.StatusForbidden, "runner does not belong to this API key")
return
}
}
} else {
// No runner ID provided - find the runner for this API key
err = s.db.QueryRow("SELECT id FROM runners WHERE api_key_id = ?", apiKeyID).Scan(&runnerID)
if err == sql.ErrNoRows {
s.respondError(w, http.StatusNotFound, "no runner found for this API key")
return
}
if err != nil {
s.respondError(w, http.StatusInternalServerError, "database error")
return
}
}
// Upgrade to WebSocket
@@ -685,18 +808,25 @@ func (s *Server) handleRunnerWebSocket(w http.ResponseWriter, r *http.Request) {
// Register connection (must be done before any distribution checks)
// Close old connection outside lock to avoid blocking
var oldConn *websocket.Conn
var hadExistingConnection bool
s.runnerConnsMu.Lock()
if existingConn, exists := s.runnerConns[runnerID]; exists {
oldConn = existingConn
hadExistingConnection = true
}
s.runnerConns[runnerID] = conn
s.runnerConnsMu.Unlock()
// Close old connection outside lock (if it existed)
if oldConn != nil {
log.Printf("Runner %d: closing existing WebSocket connection (reconnection)", runnerID)
oldConn.Close()
} else if hadExistingConnection {
log.Printf("Runner %d: replacing existing WebSocket connection", runnerID)
}
log.Printf("Runner %d: WebSocket connection established successfully", runnerID)
// Create a write mutex for this connection
s.runnerConnsWriteMuMu.Lock()
s.runnerConnsWriteMu[runnerID] = &sync.Mutex{}
@@ -717,20 +847,31 @@ func (s *Server) handleRunnerWebSocket(w http.ResponseWriter, r *http.Request) {
// Cleanup on disconnect
defer func() {
log.Printf("Runner %d: WebSocket connection cleanup started", runnerID)
// Update database status first
_, err := s.db.Exec(
`UPDATE runners SET status = ?, last_heartbeat = ? WHERE id = ?`,
types.RunnerStatusOffline, time.Now(), runnerID,
)
if err != nil {
log.Printf("Warning: Failed to update runner %d status to offline: %v", runnerID, err)
}
// Clean up connection maps
s.runnerConnsMu.Lock()
delete(s.runnerConns, runnerID)
s.runnerConnsMu.Unlock()
s.runnerConnsWriteMuMu.Lock()
delete(s.runnerConnsWriteMu, runnerID)
s.runnerConnsWriteMuMu.Unlock()
_, _ = s.db.Exec(
`UPDATE runners SET status = ? WHERE id = ?`,
types.RunnerStatusOffline, runnerID,
)
// Immediately redistribute tasks that were assigned to this runner
log.Printf("Runner %d disconnected, redistributing its tasks", runnerID)
log.Printf("Runner %d: WebSocket disconnected, redistributing tasks", runnerID)
s.redistributeRunnerTasks(runnerID)
log.Printf("Runner %d: WebSocket connection cleanup completed", runnerID)
}()
// Set pong handler to update heartbeat when we receive pong responses from runner
@@ -1341,7 +1482,7 @@ func (s *Server) distributeTasksToRunners() {
// Get all pending tasks
rows, err := s.db.Query(
`SELECT t.id, t.job_id, t.frame_start, t.frame_end, t.task_type, j.allow_parallel_runners, j.status as job_status, j.name as job_name
`SELECT t.id, t.job_id, t.frame_start, t.frame_end, t.task_type, j.allow_parallel_runners, j.status as job_status, j.name as job_name, j.user_id
FROM tasks t
JOIN jobs j ON t.job_id = j.id
WHERE t.status = ? AND j.status != ?
@@ -1363,6 +1504,7 @@ func (s *Server) distributeTasksToRunners() {
AllowParallelRunners bool
JobName string
JobStatus string
JobUserID int64
}
for rows.Next() {
@@ -1375,9 +1517,10 @@ func (s *Server) distributeTasksToRunners() {
AllowParallelRunners bool
JobName string
JobStatus string
JobUserID int64
}
var allowParallel sql.NullBool
err := rows.Scan(&t.TaskID, &t.JobID, &t.FrameStart, &t.FrameEnd, &t.TaskType, &allowParallel, &t.JobStatus, &t.JobName)
err := rows.Scan(&t.TaskID, &t.JobID, &t.FrameStart, &t.FrameEnd, &t.TaskType, &allowParallel, &t.JobStatus, &t.JobName, &t.JobUserID)
if err != nil {
log.Printf("Failed to scan pending task: %v", err)
continue
@@ -1411,19 +1554,22 @@ func (s *Server) distributeTasksToRunners() {
}
s.runnerConnsMu.RUnlock()
// Get runner priorities and capabilities for all connected runners
// Get runner priorities, capabilities, and API key scopes for all connected runners
runnerPriorities := make(map[int64]int)
runnerCapabilities := make(map[int64]map[string]interface{})
runnerScopes := make(map[int64]string)
for _, runnerID := range connectedRunners {
var priority int
var capabilitiesJSON string
err := s.db.QueryRow("SELECT priority, capabilities FROM runners WHERE id = ?", runnerID).Scan(&priority, &capabilitiesJSON)
var scope string
err := s.db.QueryRow("SELECT priority, capabilities, api_key_scope FROM runners WHERE id = ?", runnerID).Scan(&priority, &capabilitiesJSON, &scope)
if err != nil {
// Default to 100 if priority not found
priority = 100
capabilitiesJSON = "{}"
}
runnerPriorities[runnerID] = priority
runnerScopes[runnerID] = scope
// Parse capabilities JSON (can contain both bools and numbers)
var capabilities map[string]interface{}
@@ -1512,6 +1658,30 @@ func (s *Server) distributeTasksToRunners() {
// Try to find the best runner for this task
for _, runnerID := range connectedRunners {
// Check if runner's API key scope allows working on this job
runnerScope := runnerScopes[runnerID]
if runnerScope == "user" && task.JobUserID != 0 {
// User-scoped runner - check if they can work on jobs from this user
// For now, user-scoped runners can only work on jobs from the same user who created their API key
var apiKeyCreatedBy int64
if runnerScope == "user" {
// Get the user who created this runner's API key
var apiKeyID sql.NullInt64
err := s.db.QueryRow("SELECT api_key_id FROM runners WHERE id = ?", runnerID).Scan(&apiKeyID)
if err == nil && apiKeyID.Valid {
err = s.db.QueryRow("SELECT created_by FROM runner_api_keys WHERE id = ?", apiKeyID.Int64).Scan(&apiKeyCreatedBy)
if err != nil {
continue // Skip this runner if we can't determine API key ownership
}
// Only allow if the job owner matches the API key creator
if apiKeyCreatedBy != task.JobUserID {
continue // This user-scoped runner cannot work on this job
}
}
}
// Manager-scoped runners can work on any job
}
// Check if runner has required capability
capabilities := runnerCapabilities[runnerID]
hasRequired := false
@@ -1891,9 +2061,11 @@ func (s *Server) assignTaskToRunner(runnerID int64, taskID int64) error {
// redistributeRunnerTasks resets tasks assigned to a disconnected/dead runner and redistributes them
func (s *Server) redistributeRunnerTasks(runnerID int64) {
// Get tasks assigned to this runner
log.Printf("Starting task redistribution for disconnected runner %d", runnerID)
// Get tasks assigned to this runner that are still running
taskRows, err := s.db.Query(
`SELECT id, retry_count, max_retries FROM tasks
`SELECT id, retry_count, max_retries, job_id FROM tasks
WHERE runner_id = ? AND status = ?`,
runnerID, types.TaskStatusRunning,
)
@@ -1907,6 +2079,7 @@ func (s *Server) redistributeRunnerTasks(runnerID int64) {
ID int64
RetryCount int
MaxRetries int
JobID int64
}
for taskRows.Next() {
@@ -1914,51 +2087,78 @@ func (s *Server) redistributeRunnerTasks(runnerID int64) {
ID int64
RetryCount int
MaxRetries int
JobID int64
}
if err := taskRows.Scan(&t.ID, &t.RetryCount, &t.MaxRetries); err == nil {
tasksToReset = append(tasksToReset, t)
if err := taskRows.Scan(&t.ID, &t.RetryCount, &t.MaxRetries, &t.JobID); err != nil {
log.Printf("Failed to scan task for runner %d: %v", runnerID, err)
continue
}
tasksToReset = append(tasksToReset, t)
}
if len(tasksToReset) == 0 {
return // No tasks to redistribute
log.Printf("No running tasks found for runner %d to redistribute", runnerID)
return
}
log.Printf("Redistributing %d tasks from runner %d", len(tasksToReset), runnerID)
log.Printf("Redistributing %d running tasks from disconnected runner %d", len(tasksToReset), runnerID)
// Reset or fail tasks
resetCount := 0
failedCount := 0
for _, task := range tasksToReset {
if task.RetryCount >= task.MaxRetries {
// Mark as failed
_, err = s.db.Exec(
`UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL
WHERE id = ?`,
types.TaskStatusFailed, "Runner died, max retries exceeded", task.ID,
`UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL, completed_at = ?
WHERE id = ? AND runner_id = ?`,
types.TaskStatusFailed, "Runner disconnected, max retries exceeded", time.Now(), task.ID, runnerID,
)
if err != nil {
log.Printf("Failed to mark task %d as failed: %v", task.ID, err)
} else {
failedCount++
// Log task failure
s.logTaskEvent(task.ID, &runnerID, types.LogLevelError, fmt.Sprintf("Task failed - runner %d disconnected, max retries (%d) exceeded", runnerID, task.MaxRetries), "")
s.logTaskEvent(task.ID, &runnerID, types.LogLevelError,
fmt.Sprintf("Task failed - runner %d disconnected, max retries (%d) exceeded", runnerID, task.MaxRetries), "")
}
} else {
// Reset to pending so it can be redistributed
_, err = s.db.Exec(
`UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL,
retry_count = retry_count + 1 WHERE id = ?`,
types.TaskStatusPending, task.ID,
retry_count = retry_count + 1, started_at = NULL WHERE id = ? AND runner_id = ?`,
types.TaskStatusPending, task.ID, runnerID,
)
if err != nil {
log.Printf("Failed to reset task %d: %v", task.ID, err)
} else {
resetCount++
// Log task reset for redistribution
s.logTaskEvent(task.ID, &runnerID, types.LogLevelWarn, fmt.Sprintf("Runner %d disconnected, task reset for redistribution (retry %d/%d)", runnerID, task.RetryCount+1, task.MaxRetries), "")
s.logTaskEvent(task.ID, &runnerID, types.LogLevelWarn,
fmt.Sprintf("Runner %d disconnected, task reset for redistribution (retry %d/%d)", runnerID, task.RetryCount+1, task.MaxRetries), "")
}
}
}
log.Printf("Task redistribution complete for runner %d: %d tasks reset, %d tasks failed", runnerID, resetCount, failedCount)
// Update job statuses for affected jobs
jobIDs := make(map[int64]bool)
for _, task := range tasksToReset {
jobIDs[task.JobID] = true
}
for jobID := range jobIDs {
// Update job status based on remaining tasks
go s.updateJobStatusFromTasks(jobID)
}
// Immediately redistribute the reset tasks
s.triggerTaskDistribution()
if resetCount > 0 {
log.Printf("Triggering task distribution for %d reset tasks from runner %d", resetCount, runnerID)
s.triggerTaskDistribution()
}
}
// logTaskEvent logs an event to a task's log (manager-side logging)
@@ -1986,3 +2186,4 @@ func (s *Server) logTaskEvent(taskID int64, runnerID *int64, logLevel types.LogL
StepName: stepName,
})
}

View File

@@ -215,14 +215,20 @@ func (s *Server) setupRoutes() {
// WebSocket routes for real-time updates
r.With(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Remove timeout middleware for WebSocket
next.ServeHTTP(w, r)
// Apply authentication middleware first
s.auth.Middleware(func(w http.ResponseWriter, r *http.Request) {
// Remove timeout middleware for WebSocket
next.ServeHTTP(w, r)
})(w, r)
})
}).Get("/ws", s.handleJobsWebSocket)
r.With(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Remove timeout middleware for WebSocket
next.ServeHTTP(w, r)
// Apply authentication middleware first
s.auth.Middleware(func(w http.ResponseWriter, r *http.Request) {
// Remove timeout middleware for WebSocket
next.ServeHTTP(w, r)
})(w, r)
})
}).Get("/{id}/ws", s.handleJobWebSocket)
})
@@ -233,10 +239,11 @@ func (s *Server) setupRoutes() {
return http.HandlerFunc(s.auth.AdminMiddleware(next.ServeHTTP))
})
r.Route("/runners", func(r chi.Router) {
r.Route("/tokens", func(r chi.Router) {
r.Post("/", s.handleGenerateRegistrationToken)
r.Get("/", s.handleListRegistrationTokens)
r.Delete("/{id}", s.handleRevokeRegistrationToken)
r.Route("/api-keys", func(r chi.Router) {
r.Post("/", s.handleGenerateRunnerAPIKey)
r.Get("/", s.handleListRunnerAPIKeys)
r.Patch("/{id}/revoke", s.handleRevokeRunnerAPIKey)
r.Delete("/{id}", s.handleDeleteRunnerAPIKey)
})
r.Get("/", s.handleListRunnersAdmin)
r.Post("/{id}/verify", s.handleVerifyRunner)
@@ -266,6 +273,7 @@ func (s *Server) setupRoutes() {
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(s.runnerAuthMiddleware(next.ServeHTTP))
})
r.Get("/ping", s.handleRunnerPing)
r.Post("/tasks/{id}/progress", s.handleUpdateTaskProgress)
r.Post("/tasks/{id}/steps", s.handleUpdateTaskStep)
r.Get("/jobs/{jobId}/context.tar", s.handleDownloadJobContext)
@@ -441,7 +449,7 @@ func (s *Server) handleLocalRegister(w http.ResponseWriter, r *http.Request) {
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
return
}
@@ -489,7 +497,7 @@ func (s *Server) handleLocalLogin(w http.ResponseWriter, r *http.Request) {
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
return
}
@@ -540,7 +548,7 @@ func (s *Server) handleChangePassword(w http.ResponseWriter, r *http.Request) {
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondError(w, http.StatusBadRequest, "Invalid request body")
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
return
}