its a bit broken
This commit is contained in:
@@ -28,41 +28,75 @@ type contextKey string
|
||||
|
||||
const runnerIDContextKey contextKey = "runner_id"
|
||||
|
||||
// runnerAuthMiddleware verifies runner requests using shared secret header
|
||||
// runnerAuthMiddleware verifies runner requests using API key
|
||||
func (s *Server) runnerAuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get runner ID from query string
|
||||
// Get API key from header
|
||||
apiKey := r.Header.Get("Authorization")
|
||||
if apiKey == "" {
|
||||
// Try alternative header
|
||||
apiKey = r.Header.Get("X-API-Key")
|
||||
}
|
||||
if apiKey == "" {
|
||||
s.respondError(w, http.StatusUnauthorized, "API key required")
|
||||
return
|
||||
}
|
||||
|
||||
// Remove "Bearer " prefix if present
|
||||
if strings.HasPrefix(apiKey, "Bearer ") {
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
}
|
||||
|
||||
// Validate API key and get its ID
|
||||
apiKeyID, _, err := s.secrets.ValidateRunnerAPIKey(apiKey)
|
||||
if err != nil {
|
||||
log.Printf("API key validation failed: %v", err)
|
||||
s.respondError(w, http.StatusUnauthorized, "invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
// Get runner ID from query string or find runner by API key
|
||||
runnerIDStr := r.URL.Query().Get("runner_id")
|
||||
if runnerIDStr == "" {
|
||||
s.respondError(w, http.StatusBadRequest, "runner_id required in query string")
|
||||
return
|
||||
}
|
||||
|
||||
var runnerID int64
|
||||
_, err := fmt.Sscanf(runnerIDStr, "%d", &runnerID)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "invalid runner_id")
|
||||
return
|
||||
}
|
||||
|
||||
// Get runner secret
|
||||
runnerSecret, err := s.secrets.GetRunnerSecret(runnerID)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get runner secret for runner %d: %v", runnerID, err)
|
||||
s.respondError(w, http.StatusUnauthorized, "runner not found or not verified")
|
||||
return
|
||||
}
|
||||
if runnerIDStr != "" {
|
||||
// Runner ID provided - verify it belongs to this API key
|
||||
_, err := fmt.Sscanf(runnerIDStr, "%d", &runnerID)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "invalid runner_id")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify shared secret from header
|
||||
providedSecret := r.Header.Get("X-Runner-Secret")
|
||||
if providedSecret == "" {
|
||||
s.respondError(w, http.StatusUnauthorized, "missing secret")
|
||||
return
|
||||
}
|
||||
|
||||
if providedSecret != runnerSecret {
|
||||
s.respondError(w, http.StatusUnauthorized, "invalid secret")
|
||||
return
|
||||
// For fixed API keys, skip database verification
|
||||
if apiKeyID != -1 {
|
||||
// Verify runner exists and uses this API key
|
||||
var dbAPIKeyID sql.NullInt64
|
||||
err = s.db.QueryRow("SELECT api_key_id FROM runners WHERE id = ?", runnerID).Scan(&dbAPIKeyID)
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "runner not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to query runner API key: %v", err))
|
||||
return
|
||||
}
|
||||
if !dbAPIKeyID.Valid || dbAPIKeyID.Int64 != apiKeyID {
|
||||
s.respondError(w, http.StatusForbidden, "runner does not belong to this API key")
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No runner ID provided - find the runner for this API key
|
||||
// For simplicity, assume each API key has one runner
|
||||
err = s.db.QueryRow("SELECT id FROM runners WHERE api_key_id = ?", apiKeyID).Scan(&runnerID)
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "no runner found for this API key")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to query runner by API key: %v", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Add runner ID to context
|
||||
@@ -72,66 +106,40 @@ func (s *Server) runnerAuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// handleRegisterRunner registers a new runner
|
||||
// Note: Token expiration only affects whether the token can be used for registration.
|
||||
// Once a runner is registered, it receives its own runner_secret and manager_secret
|
||||
// and operates independently. The token expiration does not affect registered runners.
|
||||
// handleRegisterRunner registers a new runner using an API key
|
||||
func (s *Server) handleRegisterRunner(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
types.RegisterRunnerRequest
|
||||
RegistrationToken string `json:"registration_token"`
|
||||
APIKey string `json:"api_key"`
|
||||
Fingerprint string `json:"fingerprint,omitempty"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Lock to prevent concurrent registrations that could create duplicate runners
|
||||
s.secrets.RegistrationMu.Lock()
|
||||
defer s.secrets.RegistrationMu.Unlock()
|
||||
|
||||
if req.Name == "" {
|
||||
s.respondError(w, http.StatusBadRequest, "Runner name is required")
|
||||
return
|
||||
}
|
||||
|
||||
if req.RegistrationToken == "" {
|
||||
s.respondError(w, http.StatusBadRequest, "Registration token is required")
|
||||
if req.APIKey == "" {
|
||||
s.respondError(w, http.StatusBadRequest, "API key is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate registration token (expiration only affects token usability, not registered runners)
|
||||
result, err := s.secrets.ValidateRegistrationTokenDetailed(req.RegistrationToken)
|
||||
// Validate API key
|
||||
apiKeyID, apiKeyScope, err := s.secrets.ValidateRunnerAPIKey(req.APIKey)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to validate token: %v", err))
|
||||
return
|
||||
}
|
||||
if !result.Valid {
|
||||
var errorMsg string
|
||||
switch result.Reason {
|
||||
case "already_used":
|
||||
errorMsg = "Registration token has already been used"
|
||||
case "expired":
|
||||
errorMsg = "Registration token has expired"
|
||||
case "not_found":
|
||||
errorMsg = "Invalid registration token"
|
||||
default:
|
||||
errorMsg = "Invalid or expired registration token"
|
||||
}
|
||||
s.respondError(w, http.StatusUnauthorized, errorMsg)
|
||||
return
|
||||
}
|
||||
|
||||
// Get manager secret
|
||||
managerSecret, err := s.secrets.GetManagerSecret()
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, "Failed to get manager secret")
|
||||
return
|
||||
}
|
||||
|
||||
// Generate runner secret (runner will use this for all future authentication, independent of token)
|
||||
runnerSecret, err := s.secrets.GenerateRunnerSecret()
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, "Failed to generate runner secret")
|
||||
s.respondError(w, http.StatusUnauthorized, fmt.Sprintf("Invalid API key: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// For fixed API keys (keyID = -1), skip fingerprint checking
|
||||
// Set default priority if not provided
|
||||
priority := 100
|
||||
if req.Priority != nil {
|
||||
@@ -140,28 +148,110 @@ func (s *Server) handleRegisterRunner(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Register runner
|
||||
var runnerID int64
|
||||
// For fixed API keys, don't store api_key_id in database
|
||||
var dbAPIKeyID interface{}
|
||||
if apiKeyID == -1 {
|
||||
dbAPIKeyID = nil // NULL for fixed API keys
|
||||
} else {
|
||||
dbAPIKeyID = apiKeyID
|
||||
}
|
||||
|
||||
// Determine fingerprint value
|
||||
fingerprint := req.Fingerprint
|
||||
if apiKeyID == -1 || fingerprint == "" {
|
||||
// For fixed API keys or when no fingerprint provided, generate a unique fingerprint
|
||||
// to avoid conflicts while still maintaining some uniqueness
|
||||
fingerprint = fmt.Sprintf("fixed-%s-%d", req.Name, time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// Check fingerprint uniqueness only for non-fixed API keys
|
||||
if apiKeyID != -1 && req.Fingerprint != "" {
|
||||
var existingRunnerID int64
|
||||
var existingAPIKeyID sql.NullInt64
|
||||
err = s.db.QueryRow(
|
||||
"SELECT id, api_key_id FROM runners WHERE fingerprint = ?",
|
||||
req.Fingerprint,
|
||||
).Scan(&existingRunnerID, &existingAPIKeyID)
|
||||
|
||||
if err == nil {
|
||||
// Runner already exists with this fingerprint
|
||||
if existingAPIKeyID.Valid && existingAPIKeyID.Int64 == apiKeyID {
|
||||
// Same API key - update and return existing runner
|
||||
log.Printf("Runner with fingerprint %s already exists (ID: %d), updating info", req.Fingerprint, existingRunnerID)
|
||||
|
||||
_, err = s.db.Exec(
|
||||
`UPDATE runners SET name = ?, hostname = ?, capabilities = ?, status = ?, last_heartbeat = ? WHERE id = ?`,
|
||||
req.Name, req.Hostname, req.Capabilities, types.RunnerStatusOnline, time.Now(), existingRunnerID,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to update existing runner info: %v", err)
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"id": existingRunnerID,
|
||||
"name": req.Name,
|
||||
"hostname": req.Hostname,
|
||||
"status": types.RunnerStatusOnline,
|
||||
"reused": true, // Indicates this was a re-registration
|
||||
})
|
||||
return
|
||||
} else {
|
||||
// Different API key - reject registration
|
||||
s.respondError(w, http.StatusConflict, "Runner with this fingerprint already registered with different API key")
|
||||
return
|
||||
}
|
||||
}
|
||||
// If err is not nil, it means no existing runner with this fingerprint - proceed with new registration
|
||||
}
|
||||
|
||||
// Insert runner
|
||||
err = s.db.QueryRow(
|
||||
`INSERT INTO runners (name, hostname, ip_address, status, last_heartbeat, capabilities,
|
||||
registration_token, runner_secret, manager_secret, verified, priority)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`INSERT INTO runners (name, hostname, ip_address, status, last_heartbeat, capabilities,
|
||||
api_key_id, api_key_scope, priority, fingerprint)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
RETURNING id`,
|
||||
req.Name, req.Hostname, "", types.RunnerStatusOnline, time.Now(), req.Capabilities,
|
||||
req.RegistrationToken, runnerSecret, managerSecret, true, priority,
|
||||
dbAPIKeyID, apiKeyScope, priority, fingerprint,
|
||||
).Scan(&runnerID)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to register runner: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Return runner info with secrets
|
||||
log.Printf("Registered new runner %s (ID: %d) with API key ID: %d", req.Name, runnerID, apiKeyID)
|
||||
|
||||
// Return runner info
|
||||
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
|
||||
"id": runnerID,
|
||||
"name": req.Name,
|
||||
"hostname": req.Hostname,
|
||||
"status": types.RunnerStatusOnline,
|
||||
"runner_secret": runnerSecret,
|
||||
"manager_secret": managerSecret,
|
||||
"verified": true,
|
||||
"id": runnerID,
|
||||
"name": req.Name,
|
||||
"hostname": req.Hostname,
|
||||
"status": types.RunnerStatusOnline,
|
||||
})
|
||||
}
|
||||
|
||||
// handleRunnerPing allows runners to validate their secrets and connection
|
||||
func (s *Server) handleRunnerPing(w http.ResponseWriter, r *http.Request) {
|
||||
// This endpoint uses runnerAuthMiddleware, so if we get here, secrets are valid
|
||||
// Get runner ID from context (set by runnerAuthMiddleware)
|
||||
runnerID, ok := r.Context().Value(runnerIDContextKey).(int64)
|
||||
if !ok {
|
||||
s.respondError(w, http.StatusUnauthorized, "runner_id not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
// Update last heartbeat
|
||||
_, err := s.db.Exec(
|
||||
`UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`,
|
||||
time.Now(), types.RunnerStatusOnline, runnerID,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to update runner heartbeat: %v", err)
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"status": "ok",
|
||||
"runner_id": runnerID,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -177,7 +267,7 @@ func (s *Server) handleUpdateTaskProgress(w http.ResponseWriter, r *http.Request
|
||||
Progress float64 `json:"progress"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -207,7 +297,7 @@ func (s *Server) handleUpdateTaskStep(w http.ResponseWriter, r *http.Request) {
|
||||
ErrorMessage string `json:"error_message,omitempty"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request body: expected valid JSON - %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -354,7 +444,7 @@ func (s *Server) handleUploadFileFromRunner(w http.ResponseWriter, r *http.Reque
|
||||
|
||||
err = r.ParseMultipartForm(50 << 30) // 50 GB (for large output files)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Failed to parse form")
|
||||
s.respondError(w, http.StatusBadRequest, fmt.Sprintf("Failed to parse multipart form: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -530,7 +620,7 @@ func (s *Server) handleGetJobMetadataForRunner(w http.ResponseWriter, r *http.Re
|
||||
|
||||
var metadata types.BlendMetadata
|
||||
if err := json.Unmarshal([]byte(blendMetadataJSON.String), &metadata); err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, "Failed to parse metadata")
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to parse metadata JSON: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -645,33 +735,66 @@ type WSTaskUpdate struct {
|
||||
|
||||
// handleRunnerWebSocket handles WebSocket connections from runners
|
||||
func (s *Server) handleRunnerWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
// Get runner ID and secret from query params
|
||||
// Get API key from query params or headers
|
||||
apiKey := r.URL.Query().Get("api_key")
|
||||
if apiKey == "" {
|
||||
apiKey = r.Header.Get("Authorization")
|
||||
if strings.HasPrefix(apiKey, "Bearer ") {
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
}
|
||||
}
|
||||
if apiKey == "" {
|
||||
s.respondError(w, http.StatusBadRequest, "API key required")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate API key
|
||||
apiKeyID, _, err := s.secrets.ValidateRunnerAPIKey(apiKey)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusUnauthorized, fmt.Sprintf("Invalid API key: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Get runner ID from query params or find by API key
|
||||
runnerIDStr := r.URL.Query().Get("runner_id")
|
||||
providedSecret := r.URL.Query().Get("secret")
|
||||
|
||||
if runnerIDStr == "" || providedSecret == "" {
|
||||
s.respondError(w, http.StatusBadRequest, "runner_id and secret required")
|
||||
return
|
||||
}
|
||||
|
||||
var runnerID int64
|
||||
_, err := fmt.Sscanf(runnerIDStr, "%d", &runnerID)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "invalid runner_id")
|
||||
return
|
||||
}
|
||||
|
||||
// Get runner secret
|
||||
runnerSecret, err := s.secrets.GetRunnerSecret(runnerID)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusUnauthorized, "runner not found or not verified")
|
||||
return
|
||||
}
|
||||
if runnerIDStr != "" {
|
||||
// Runner ID provided - verify it belongs to this API key
|
||||
_, err := fmt.Sscanf(runnerIDStr, "%d", &runnerID)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "invalid runner_id")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify shared secret
|
||||
if providedSecret != runnerSecret {
|
||||
s.respondError(w, http.StatusUnauthorized, "invalid secret")
|
||||
return
|
||||
// For fixed API keys, skip database verification
|
||||
if apiKeyID != -1 {
|
||||
var dbAPIKeyID sql.NullInt64
|
||||
err = s.db.QueryRow("SELECT api_key_id FROM runners WHERE id = ?", runnerID).Scan(&dbAPIKeyID)
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "runner not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to query runner API key: %v", err))
|
||||
return
|
||||
}
|
||||
if !dbAPIKeyID.Valid || dbAPIKeyID.Int64 != apiKeyID {
|
||||
s.respondError(w, http.StatusForbidden, "runner does not belong to this API key")
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No runner ID provided - find the runner for this API key
|
||||
err = s.db.QueryRow("SELECT id FROM runners WHERE api_key_id = ?", apiKeyID).Scan(&runnerID)
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "no runner found for this API key")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, "database error")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Upgrade to WebSocket
|
||||
@@ -685,18 +808,25 @@ func (s *Server) handleRunnerWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
// Register connection (must be done before any distribution checks)
|
||||
// Close old connection outside lock to avoid blocking
|
||||
var oldConn *websocket.Conn
|
||||
var hadExistingConnection bool
|
||||
s.runnerConnsMu.Lock()
|
||||
if existingConn, exists := s.runnerConns[runnerID]; exists {
|
||||
oldConn = existingConn
|
||||
hadExistingConnection = true
|
||||
}
|
||||
s.runnerConns[runnerID] = conn
|
||||
s.runnerConnsMu.Unlock()
|
||||
|
||||
// Close old connection outside lock (if it existed)
|
||||
if oldConn != nil {
|
||||
log.Printf("Runner %d: closing existing WebSocket connection (reconnection)", runnerID)
|
||||
oldConn.Close()
|
||||
} else if hadExistingConnection {
|
||||
log.Printf("Runner %d: replacing existing WebSocket connection", runnerID)
|
||||
}
|
||||
|
||||
log.Printf("Runner %d: WebSocket connection established successfully", runnerID)
|
||||
|
||||
// Create a write mutex for this connection
|
||||
s.runnerConnsWriteMuMu.Lock()
|
||||
s.runnerConnsWriteMu[runnerID] = &sync.Mutex{}
|
||||
@@ -717,20 +847,31 @@ func (s *Server) handleRunnerWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Cleanup on disconnect
|
||||
defer func() {
|
||||
log.Printf("Runner %d: WebSocket connection cleanup started", runnerID)
|
||||
|
||||
// Update database status first
|
||||
_, err := s.db.Exec(
|
||||
`UPDATE runners SET status = ?, last_heartbeat = ? WHERE id = ?`,
|
||||
types.RunnerStatusOffline, time.Now(), runnerID,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to update runner %d status to offline: %v", runnerID, err)
|
||||
}
|
||||
|
||||
// Clean up connection maps
|
||||
s.runnerConnsMu.Lock()
|
||||
delete(s.runnerConns, runnerID)
|
||||
s.runnerConnsMu.Unlock()
|
||||
|
||||
s.runnerConnsWriteMuMu.Lock()
|
||||
delete(s.runnerConnsWriteMu, runnerID)
|
||||
s.runnerConnsWriteMuMu.Unlock()
|
||||
_, _ = s.db.Exec(
|
||||
`UPDATE runners SET status = ? WHERE id = ?`,
|
||||
types.RunnerStatusOffline, runnerID,
|
||||
)
|
||||
|
||||
// Immediately redistribute tasks that were assigned to this runner
|
||||
log.Printf("Runner %d disconnected, redistributing its tasks", runnerID)
|
||||
log.Printf("Runner %d: WebSocket disconnected, redistributing tasks", runnerID)
|
||||
s.redistributeRunnerTasks(runnerID)
|
||||
|
||||
log.Printf("Runner %d: WebSocket connection cleanup completed", runnerID)
|
||||
}()
|
||||
|
||||
// Set pong handler to update heartbeat when we receive pong responses from runner
|
||||
@@ -1341,7 +1482,7 @@ func (s *Server) distributeTasksToRunners() {
|
||||
|
||||
// Get all pending tasks
|
||||
rows, err := s.db.Query(
|
||||
`SELECT t.id, t.job_id, t.frame_start, t.frame_end, t.task_type, j.allow_parallel_runners, j.status as job_status, j.name as job_name
|
||||
`SELECT t.id, t.job_id, t.frame_start, t.frame_end, t.task_type, j.allow_parallel_runners, j.status as job_status, j.name as job_name, j.user_id
|
||||
FROM tasks t
|
||||
JOIN jobs j ON t.job_id = j.id
|
||||
WHERE t.status = ? AND j.status != ?
|
||||
@@ -1363,6 +1504,7 @@ func (s *Server) distributeTasksToRunners() {
|
||||
AllowParallelRunners bool
|
||||
JobName string
|
||||
JobStatus string
|
||||
JobUserID int64
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
@@ -1375,9 +1517,10 @@ func (s *Server) distributeTasksToRunners() {
|
||||
AllowParallelRunners bool
|
||||
JobName string
|
||||
JobStatus string
|
||||
JobUserID int64
|
||||
}
|
||||
var allowParallel sql.NullBool
|
||||
err := rows.Scan(&t.TaskID, &t.JobID, &t.FrameStart, &t.FrameEnd, &t.TaskType, &allowParallel, &t.JobStatus, &t.JobName)
|
||||
err := rows.Scan(&t.TaskID, &t.JobID, &t.FrameStart, &t.FrameEnd, &t.TaskType, &allowParallel, &t.JobStatus, &t.JobName, &t.JobUserID)
|
||||
if err != nil {
|
||||
log.Printf("Failed to scan pending task: %v", err)
|
||||
continue
|
||||
@@ -1411,19 +1554,22 @@ func (s *Server) distributeTasksToRunners() {
|
||||
}
|
||||
s.runnerConnsMu.RUnlock()
|
||||
|
||||
// Get runner priorities and capabilities for all connected runners
|
||||
// Get runner priorities, capabilities, and API key scopes for all connected runners
|
||||
runnerPriorities := make(map[int64]int)
|
||||
runnerCapabilities := make(map[int64]map[string]interface{})
|
||||
runnerScopes := make(map[int64]string)
|
||||
for _, runnerID := range connectedRunners {
|
||||
var priority int
|
||||
var capabilitiesJSON string
|
||||
err := s.db.QueryRow("SELECT priority, capabilities FROM runners WHERE id = ?", runnerID).Scan(&priority, &capabilitiesJSON)
|
||||
var scope string
|
||||
err := s.db.QueryRow("SELECT priority, capabilities, api_key_scope FROM runners WHERE id = ?", runnerID).Scan(&priority, &capabilitiesJSON, &scope)
|
||||
if err != nil {
|
||||
// Default to 100 if priority not found
|
||||
priority = 100
|
||||
capabilitiesJSON = "{}"
|
||||
}
|
||||
runnerPriorities[runnerID] = priority
|
||||
runnerScopes[runnerID] = scope
|
||||
|
||||
// Parse capabilities JSON (can contain both bools and numbers)
|
||||
var capabilities map[string]interface{}
|
||||
@@ -1512,6 +1658,30 @@ func (s *Server) distributeTasksToRunners() {
|
||||
|
||||
// Try to find the best runner for this task
|
||||
for _, runnerID := range connectedRunners {
|
||||
// Check if runner's API key scope allows working on this job
|
||||
runnerScope := runnerScopes[runnerID]
|
||||
if runnerScope == "user" && task.JobUserID != 0 {
|
||||
// User-scoped runner - check if they can work on jobs from this user
|
||||
// For now, user-scoped runners can only work on jobs from the same user who created their API key
|
||||
var apiKeyCreatedBy int64
|
||||
if runnerScope == "user" {
|
||||
// Get the user who created this runner's API key
|
||||
var apiKeyID sql.NullInt64
|
||||
err := s.db.QueryRow("SELECT api_key_id FROM runners WHERE id = ?", runnerID).Scan(&apiKeyID)
|
||||
if err == nil && apiKeyID.Valid {
|
||||
err = s.db.QueryRow("SELECT created_by FROM runner_api_keys WHERE id = ?", apiKeyID.Int64).Scan(&apiKeyCreatedBy)
|
||||
if err != nil {
|
||||
continue // Skip this runner if we can't determine API key ownership
|
||||
}
|
||||
// Only allow if the job owner matches the API key creator
|
||||
if apiKeyCreatedBy != task.JobUserID {
|
||||
continue // This user-scoped runner cannot work on this job
|
||||
}
|
||||
}
|
||||
}
|
||||
// Manager-scoped runners can work on any job
|
||||
}
|
||||
|
||||
// Check if runner has required capability
|
||||
capabilities := runnerCapabilities[runnerID]
|
||||
hasRequired := false
|
||||
@@ -1891,9 +2061,11 @@ func (s *Server) assignTaskToRunner(runnerID int64, taskID int64) error {
|
||||
|
||||
// redistributeRunnerTasks resets tasks assigned to a disconnected/dead runner and redistributes them
|
||||
func (s *Server) redistributeRunnerTasks(runnerID int64) {
|
||||
// Get tasks assigned to this runner
|
||||
log.Printf("Starting task redistribution for disconnected runner %d", runnerID)
|
||||
|
||||
// Get tasks assigned to this runner that are still running
|
||||
taskRows, err := s.db.Query(
|
||||
`SELECT id, retry_count, max_retries FROM tasks
|
||||
`SELECT id, retry_count, max_retries, job_id FROM tasks
|
||||
WHERE runner_id = ? AND status = ?`,
|
||||
runnerID, types.TaskStatusRunning,
|
||||
)
|
||||
@@ -1907,6 +2079,7 @@ func (s *Server) redistributeRunnerTasks(runnerID int64) {
|
||||
ID int64
|
||||
RetryCount int
|
||||
MaxRetries int
|
||||
JobID int64
|
||||
}
|
||||
|
||||
for taskRows.Next() {
|
||||
@@ -1914,51 +2087,78 @@ func (s *Server) redistributeRunnerTasks(runnerID int64) {
|
||||
ID int64
|
||||
RetryCount int
|
||||
MaxRetries int
|
||||
JobID int64
|
||||
}
|
||||
if err := taskRows.Scan(&t.ID, &t.RetryCount, &t.MaxRetries); err == nil {
|
||||
tasksToReset = append(tasksToReset, t)
|
||||
if err := taskRows.Scan(&t.ID, &t.RetryCount, &t.MaxRetries, &t.JobID); err != nil {
|
||||
log.Printf("Failed to scan task for runner %d: %v", runnerID, err)
|
||||
continue
|
||||
}
|
||||
tasksToReset = append(tasksToReset, t)
|
||||
}
|
||||
|
||||
if len(tasksToReset) == 0 {
|
||||
return // No tasks to redistribute
|
||||
log.Printf("No running tasks found for runner %d to redistribute", runnerID)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Redistributing %d tasks from runner %d", len(tasksToReset), runnerID)
|
||||
log.Printf("Redistributing %d running tasks from disconnected runner %d", len(tasksToReset), runnerID)
|
||||
|
||||
// Reset or fail tasks
|
||||
resetCount := 0
|
||||
failedCount := 0
|
||||
|
||||
for _, task := range tasksToReset {
|
||||
if task.RetryCount >= task.MaxRetries {
|
||||
// Mark as failed
|
||||
_, err = s.db.Exec(
|
||||
`UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL
|
||||
WHERE id = ?`,
|
||||
types.TaskStatusFailed, "Runner died, max retries exceeded", task.ID,
|
||||
`UPDATE tasks SET status = ?, error_message = ?, runner_id = NULL, completed_at = ?
|
||||
WHERE id = ? AND runner_id = ?`,
|
||||
types.TaskStatusFailed, "Runner disconnected, max retries exceeded", time.Now(), task.ID, runnerID,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Failed to mark task %d as failed: %v", task.ID, err)
|
||||
} else {
|
||||
failedCount++
|
||||
// Log task failure
|
||||
s.logTaskEvent(task.ID, &runnerID, types.LogLevelError, fmt.Sprintf("Task failed - runner %d disconnected, max retries (%d) exceeded", runnerID, task.MaxRetries), "")
|
||||
s.logTaskEvent(task.ID, &runnerID, types.LogLevelError,
|
||||
fmt.Sprintf("Task failed - runner %d disconnected, max retries (%d) exceeded", runnerID, task.MaxRetries), "")
|
||||
}
|
||||
} else {
|
||||
// Reset to pending so it can be redistributed
|
||||
_, err = s.db.Exec(
|
||||
`UPDATE tasks SET status = ?, runner_id = NULL, current_step = NULL,
|
||||
retry_count = retry_count + 1 WHERE id = ?`,
|
||||
types.TaskStatusPending, task.ID,
|
||||
retry_count = retry_count + 1, started_at = NULL WHERE id = ? AND runner_id = ?`,
|
||||
types.TaskStatusPending, task.ID, runnerID,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Failed to reset task %d: %v", task.ID, err)
|
||||
} else {
|
||||
resetCount++
|
||||
// Log task reset for redistribution
|
||||
s.logTaskEvent(task.ID, &runnerID, types.LogLevelWarn, fmt.Sprintf("Runner %d disconnected, task reset for redistribution (retry %d/%d)", runnerID, task.RetryCount+1, task.MaxRetries), "")
|
||||
s.logTaskEvent(task.ID, &runnerID, types.LogLevelWarn,
|
||||
fmt.Sprintf("Runner %d disconnected, task reset for redistribution (retry %d/%d)", runnerID, task.RetryCount+1, task.MaxRetries), "")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("Task redistribution complete for runner %d: %d tasks reset, %d tasks failed", runnerID, resetCount, failedCount)
|
||||
|
||||
// Update job statuses for affected jobs
|
||||
jobIDs := make(map[int64]bool)
|
||||
for _, task := range tasksToReset {
|
||||
jobIDs[task.JobID] = true
|
||||
}
|
||||
|
||||
for jobID := range jobIDs {
|
||||
// Update job status based on remaining tasks
|
||||
go s.updateJobStatusFromTasks(jobID)
|
||||
}
|
||||
|
||||
// Immediately redistribute the reset tasks
|
||||
s.triggerTaskDistribution()
|
||||
if resetCount > 0 {
|
||||
log.Printf("Triggering task distribution for %d reset tasks from runner %d", resetCount, runnerID)
|
||||
s.triggerTaskDistribution()
|
||||
}
|
||||
}
|
||||
|
||||
// logTaskEvent logs an event to a task's log (manager-side logging)
|
||||
@@ -1986,3 +2186,4 @@ func (s *Server) logTaskEvent(taskID int64, runnerID *int64, logLevel types.LogL
|
||||
StepName: stepName,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user