initial commit
This commit is contained in:
172
internal/api/admin.go
Normal file
172
internal/api/admin.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"fuego/internal/auth"
|
||||
"fuego/pkg/types"
|
||||
)
|
||||
|
||||
// handleGenerateRegistrationToken generates a new registration token
|
||||
func (s *Server) handleGenerateRegistrationToken(w http.ResponseWriter, r *http.Request) {
|
||||
userID, err := getUserID(r)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Default expiration: 24 hours
|
||||
expiresIn := 24 * time.Hour
|
||||
|
||||
var req struct {
|
||||
ExpiresInHours int `json:"expires_in_hours,omitempty"`
|
||||
}
|
||||
if r.Body != nil && r.ContentLength > 0 {
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err == nil && req.ExpiresInHours > 0 {
|
||||
expiresIn = time.Duration(req.ExpiresInHours) * time.Hour
|
||||
}
|
||||
}
|
||||
|
||||
token, err := s.secrets.GenerateRegistrationToken(userID, expiresIn)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to generate token: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
|
||||
"token": token,
|
||||
"expires_in": expiresIn.String(),
|
||||
"expires_at": time.Now().Add(expiresIn),
|
||||
})
|
||||
}
|
||||
|
||||
// handleListRegistrationTokens lists all registration tokens
|
||||
func (s *Server) handleListRegistrationTokens(w http.ResponseWriter, r *http.Request) {
|
||||
tokens, err := s.secrets.ListRegistrationTokens()
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to list tokens: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, tokens)
|
||||
}
|
||||
|
||||
// handleRevokeRegistrationToken revokes a registration token
|
||||
func (s *Server) handleRevokeRegistrationToken(w http.ResponseWriter, r *http.Request) {
|
||||
tokenID, err := parseID(r, "id")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.secrets.RevokeRegistrationToken(tokenID); err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to revoke token: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Token revoked"})
|
||||
}
|
||||
|
||||
// handleVerifyRunner manually verifies a runner
|
||||
func (s *Server) handleVerifyRunner(w http.ResponseWriter, r *http.Request) {
|
||||
runnerID, err := parseID(r, "id")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Check if runner exists
|
||||
var exists bool
|
||||
err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM runners WHERE id = ?)", runnerID).Scan(&exists)
|
||||
if err != nil || !exists {
|
||||
s.respondError(w, http.StatusNotFound, "Runner not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Mark runner as verified
|
||||
_, err = s.db.Exec("UPDATE runners SET verified = 1 WHERE id = ?", runnerID)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify runner: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Runner verified"})
|
||||
}
|
||||
|
||||
// handleDeleteRunner removes a runner
|
||||
func (s *Server) handleDeleteRunner(w http.ResponseWriter, r *http.Request) {
|
||||
runnerID, err := parseID(r, "id")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Check if runner exists
|
||||
var exists bool
|
||||
err = s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM runners WHERE id = ?)", runnerID).Scan(&exists)
|
||||
if err != nil || !exists {
|
||||
s.respondError(w, http.StatusNotFound, "Runner not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Delete runner
|
||||
_, err = s.db.Exec("DELETE FROM runners WHERE id = ?", runnerID)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to delete runner: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Runner deleted"})
|
||||
}
|
||||
|
||||
// handleListRunnersAdmin lists all runners with admin details
|
||||
func (s *Server) handleListRunnersAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, name, hostname, ip_address, status, last_heartbeat, capabilities,
|
||||
registration_token, verified, created_at
|
||||
FROM runners ORDER BY created_at DESC`,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query runners: %v", err))
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
runners := []map[string]interface{}{}
|
||||
for rows.Next() {
|
||||
var runner types.Runner
|
||||
var registrationToken sql.NullString
|
||||
var verified bool
|
||||
|
||||
err := rows.Scan(
|
||||
&runner.ID, &runner.Name, &runner.Hostname, &runner.IPAddress,
|
||||
&runner.Status, &runner.LastHeartbeat, &runner.Capabilities,
|
||||
®istrationToken, &verified, &runner.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan runner: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
runners = append(runners, map[string]interface{}{
|
||||
"id": runner.ID,
|
||||
"name": runner.Name,
|
||||
"hostname": runner.Hostname,
|
||||
"ip_address": runner.IPAddress,
|
||||
"status": runner.Status,
|
||||
"last_heartbeat": runner.LastHeartbeat,
|
||||
"capabilities": runner.Capabilities,
|
||||
"registration_token": registrationToken.String,
|
||||
"verified": verified,
|
||||
"created_at": runner.CreatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, runners)
|
||||
}
|
||||
|
||||
498
internal/api/jobs.go
Normal file
498
internal/api/jobs.go
Normal file
@@ -0,0 +1,498 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"fuego/pkg/types"
|
||||
)
|
||||
|
||||
// handleCreateJob creates a new job
|
||||
func (s *Server) handleCreateJob(w http.ResponseWriter, r *http.Request) {
|
||||
userID, err := getUserID(r)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var req types.CreateJobRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
s.respondError(w, http.StatusBadRequest, "Job name is required")
|
||||
return
|
||||
}
|
||||
|
||||
if req.FrameStart < 0 || req.FrameEnd < req.FrameStart {
|
||||
s.respondError(w, http.StatusBadRequest, "Invalid frame range")
|
||||
return
|
||||
}
|
||||
|
||||
if req.OutputFormat == "" {
|
||||
req.OutputFormat = "PNG"
|
||||
}
|
||||
|
||||
result, err := s.db.Exec(
|
||||
`INSERT INTO jobs (user_id, name, status, progress, frame_start, frame_end, output_format)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||
userID, req.Name, types.JobStatusPending, 0.0, req.FrameStart, req.FrameEnd, req.OutputFormat,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create job: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
jobID, _ := result.LastInsertId()
|
||||
|
||||
// Create tasks for the job (one task per frame for simplicity, could be batched)
|
||||
for frame := req.FrameStart; frame <= req.FrameEnd; frame++ {
|
||||
_, err = s.db.Exec(
|
||||
`INSERT INTO tasks (job_id, frame_start, frame_end, status) VALUES (?, ?, ?, ?)`,
|
||||
jobID, frame, frame, types.TaskStatusPending,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create tasks: %v", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
job := types.Job{
|
||||
ID: jobID,
|
||||
UserID: userID,
|
||||
Name: req.Name,
|
||||
Status: types.JobStatusPending,
|
||||
Progress: 0.0,
|
||||
FrameStart: req.FrameStart,
|
||||
FrameEnd: req.FrameEnd,
|
||||
OutputFormat: req.OutputFormat,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusCreated, job)
|
||||
}
|
||||
|
||||
// handleListJobs lists jobs for the current user
|
||||
func (s *Server) handleListJobs(w http.ResponseWriter, r *http.Request) {
|
||||
userID, err := getUserID(r)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, user_id, name, status, progress, frame_start, frame_end, output_format,
|
||||
created_at, started_at, completed_at, error_message
|
||||
FROM jobs WHERE user_id = ? ORDER BY created_at DESC`,
|
||||
userID,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query jobs: %v", err))
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
jobs := []types.Job{}
|
||||
for rows.Next() {
|
||||
var job types.Job
|
||||
var startedAt, completedAt sql.NullTime
|
||||
|
||||
err := rows.Scan(
|
||||
&job.ID, &job.UserID, &job.Name, &job.Status, &job.Progress,
|
||||
&job.FrameStart, &job.FrameEnd, &job.OutputFormat,
|
||||
&job.CreatedAt, &startedAt, &completedAt, &job.ErrorMessage,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan job: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if startedAt.Valid {
|
||||
job.StartedAt = &startedAt.Time
|
||||
}
|
||||
if completedAt.Valid {
|
||||
job.CompletedAt = &completedAt.Time
|
||||
}
|
||||
|
||||
jobs = append(jobs, job)
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, jobs)
|
||||
}
|
||||
|
||||
// handleGetJob gets a specific job
|
||||
func (s *Server) handleGetJob(w http.ResponseWriter, r *http.Request) {
|
||||
userID, err := getUserID(r)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jobID, err := parseID(r, "id")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var job types.Job
|
||||
var startedAt, completedAt sql.NullTime
|
||||
|
||||
err = s.db.QueryRow(
|
||||
`SELECT id, user_id, name, status, progress, frame_start, frame_end, output_format,
|
||||
created_at, started_at, completed_at, error_message
|
||||
FROM jobs WHERE id = ? AND user_id = ?`,
|
||||
jobID, userID,
|
||||
).Scan(
|
||||
&job.ID, &job.UserID, &job.Name, &job.Status, &job.Progress,
|
||||
&job.FrameStart, &job.FrameEnd, &job.OutputFormat,
|
||||
&job.CreatedAt, &startedAt, &completedAt, &job.ErrorMessage,
|
||||
)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "Job not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query job: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if startedAt.Valid {
|
||||
job.StartedAt = &startedAt.Time
|
||||
}
|
||||
if completedAt.Valid {
|
||||
job.CompletedAt = &completedAt.Time
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, job)
|
||||
}
|
||||
|
||||
// handleCancelJob cancels a job
|
||||
func (s *Server) handleCancelJob(w http.ResponseWriter, r *http.Request) {
|
||||
userID, err := getUserID(r)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jobID, err := parseID(r, "id")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
result, err := s.db.Exec(
|
||||
`UPDATE jobs SET status = ? WHERE id = ? AND user_id = ?`,
|
||||
types.JobStatusCancelled, jobID, userID,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to cancel job: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
rowsAffected, _ := result.RowsAffected()
|
||||
if rowsAffected == 0 {
|
||||
s.respondError(w, http.StatusNotFound, "Job not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Cancel pending tasks
|
||||
_, err = s.db.Exec(
|
||||
`UPDATE tasks SET status = ? WHERE job_id = ? AND status = ?`,
|
||||
types.TaskStatusFailed, jobID, types.TaskStatusPending,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to cancel tasks: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Job cancelled"})
|
||||
}
|
||||
|
||||
// handleUploadJobFile handles file upload for a job
|
||||
func (s *Server) handleUploadJobFile(w http.ResponseWriter, r *http.Request) {
|
||||
userID, err := getUserID(r)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jobID, err := parseID(r, "id")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Verify job belongs to user
|
||||
var jobUserID int64
|
||||
err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID)
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "Job not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to verify job: %v", err))
|
||||
return
|
||||
}
|
||||
if jobUserID != userID {
|
||||
s.respondError(w, http.StatusForbidden, "Access denied")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse multipart form
|
||||
err = r.ParseMultipartForm(100 << 20) // 100 MB
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Failed to parse form")
|
||||
return
|
||||
}
|
||||
|
||||
file, header, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "No file provided")
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Save file
|
||||
filePath, err := s.storage.SaveUpload(jobID, header.Filename, file)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to save file: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Record in database
|
||||
result, err := s.db.Exec(
|
||||
`INSERT INTO job_files (job_id, file_type, file_path, file_name, file_size)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
jobID, types.JobFileTypeInput, filePath, header.Filename, header.Size,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to record file: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
fileID, _ := result.LastInsertId()
|
||||
|
||||
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
|
||||
"id": fileID,
|
||||
"file_name": header.Filename,
|
||||
"file_path": filePath,
|
||||
"file_size": header.Size,
|
||||
})
|
||||
}
|
||||
|
||||
// handleListJobFiles lists files for a job
|
||||
func (s *Server) handleListJobFiles(w http.ResponseWriter, r *http.Request) {
|
||||
userID, err := getUserID(r)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jobID, err := parseID(r, "id")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Verify job belongs to user
|
||||
var jobUserID int64
|
||||
err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID)
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "Job not found")
|
||||
return
|
||||
}
|
||||
if jobUserID != userID {
|
||||
s.respondError(w, http.StatusForbidden, "Access denied")
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, job_id, file_type, file_path, file_name, file_size, created_at
|
||||
FROM job_files WHERE job_id = ? ORDER BY created_at DESC`,
|
||||
jobID,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query files: %v", err))
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
files := []types.JobFile{}
|
||||
for rows.Next() {
|
||||
var file types.JobFile
|
||||
err := rows.Scan(
|
||||
&file.ID, &file.JobID, &file.FileType, &file.FilePath,
|
||||
&file.FileName, &file.FileSize, &file.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan file: %v", err))
|
||||
return
|
||||
}
|
||||
files = append(files, file)
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, files)
|
||||
}
|
||||
|
||||
// handleDownloadJobFile downloads a job file
|
||||
func (s *Server) handleDownloadJobFile(w http.ResponseWriter, r *http.Request) {
|
||||
userID, err := getUserID(r)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jobID, err := parseID(r, "id")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
fileID, err := parseID(r, "fileId")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Verify job belongs to user
|
||||
var jobUserID int64
|
||||
err = s.db.QueryRow("SELECT user_id FROM jobs WHERE id = ?", jobID).Scan(&jobUserID)
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "Job not found")
|
||||
return
|
||||
}
|
||||
if jobUserID != userID {
|
||||
s.respondError(w, http.StatusForbidden, "Access denied")
|
||||
return
|
||||
}
|
||||
|
||||
// Get file info
|
||||
var filePath, fileName string
|
||||
err = s.db.QueryRow(
|
||||
`SELECT file_path, file_name FROM job_files WHERE id = ? AND job_id = ?`,
|
||||
fileID, jobID,
|
||||
).Scan(&filePath, &fileName)
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "File not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query file: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Open file
|
||||
file, err := s.storage.GetFile(filePath)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusNotFound, "File not found on disk")
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Set headers
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", fileName))
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
|
||||
// Stream file
|
||||
io.Copy(w, file)
|
||||
}
|
||||
|
||||
// handleStreamVideo streams MP4 video file with range support
|
||||
func (s *Server) handleStreamVideo(w http.ResponseWriter, r *http.Request) {
|
||||
userID, err := getUserID(r)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jobID, err := parseID(r, "id")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Verify job belongs to user
|
||||
var jobUserID int64
|
||||
var outputFormat string
|
||||
err = s.db.QueryRow("SELECT user_id, output_format FROM jobs WHERE id = ?", jobID).Scan(&jobUserID, &outputFormat)
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "Job not found")
|
||||
return
|
||||
}
|
||||
if jobUserID != userID {
|
||||
s.respondError(w, http.StatusForbidden, "Access denied")
|
||||
return
|
||||
}
|
||||
|
||||
// Find MP4 file
|
||||
var filePath, fileName string
|
||||
err = s.db.QueryRow(
|
||||
`SELECT file_path, file_name FROM job_files
|
||||
WHERE job_id = ? AND file_type = ? AND file_name LIKE '%.mp4'
|
||||
ORDER BY created_at DESC LIMIT 1`,
|
||||
jobID, types.JobFileTypeOutput,
|
||||
).Scan(&filePath, &fileName)
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "Video file not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query file: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Open file
|
||||
file, err := s.storage.GetFile(filePath)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusNotFound, "File not found on disk")
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Get file info
|
||||
fileInfo, err := file.Stat()
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, "Failed to get file info")
|
||||
return
|
||||
}
|
||||
|
||||
fileSize := fileInfo.Size()
|
||||
|
||||
// Handle range requests for video seeking
|
||||
rangeHeader := r.Header.Get("Range")
|
||||
if rangeHeader != "" {
|
||||
// Parse range header
|
||||
var start, end int64
|
||||
fmt.Sscanf(rangeHeader, "bytes=%d-%d", &start, &end)
|
||||
if end == 0 {
|
||||
end = fileSize - 1
|
||||
}
|
||||
|
||||
// Seek to start position
|
||||
file.Seek(start, 0)
|
||||
|
||||
// Set headers for partial content
|
||||
w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, fileSize))
|
||||
w.Header().Set("Accept-Ranges", "bytes")
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", end-start+1))
|
||||
w.Header().Set("Content-Type", "video/mp4")
|
||||
w.WriteHeader(http.StatusPartialContent)
|
||||
|
||||
// Copy partial content
|
||||
io.CopyN(w, file, end-start+1)
|
||||
} else {
|
||||
// Full file
|
||||
w.Header().Set("Content-Type", "video/mp4")
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", fileSize))
|
||||
w.Header().Set("Accept-Ranges", "bytes")
|
||||
io.Copy(w, file)
|
||||
}
|
||||
}
|
||||
|
||||
582
internal/api/runners.go
Normal file
582
internal/api/runners.go
Normal file
@@ -0,0 +1,582 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"fuego/internal/auth"
|
||||
"fuego/pkg/types"
|
||||
)
|
||||
|
||||
// handleListRunners lists all runners
|
||||
func (s *Server) handleListRunners(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := getUserID(r)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, name, hostname, ip_address, status, last_heartbeat, capabilities, created_at
|
||||
FROM runners ORDER BY created_at DESC`,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query runners: %v", err))
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
runners := []types.Runner{}
|
||||
for rows.Next() {
|
||||
var runner types.Runner
|
||||
err := rows.Scan(
|
||||
&runner.ID, &runner.Name, &runner.Hostname, &runner.IPAddress,
|
||||
&runner.Status, &runner.LastHeartbeat, &runner.Capabilities, &runner.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan runner: %v", err))
|
||||
return
|
||||
}
|
||||
runners = append(runners, runner)
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, runners)
|
||||
}
|
||||
|
||||
// runnerAuthMiddleware verifies runner requests using HMAC signatures
|
||||
func (s *Server) runnerAuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get runner ID from query string
|
||||
runnerIDStr := r.URL.Query().Get("runner_id")
|
||||
if runnerIDStr == "" {
|
||||
s.respondError(w, http.StatusBadRequest, "runner_id required in query string")
|
||||
return
|
||||
}
|
||||
|
||||
var runnerID int64
|
||||
_, err := fmt.Sscanf(runnerIDStr, "%d", &runnerID)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "invalid runner_id")
|
||||
return
|
||||
}
|
||||
|
||||
// Get runner secret
|
||||
runnerSecret, err := s.secrets.GetRunnerSecret(runnerID)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusUnauthorized, "runner not found or not verified")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify request signature
|
||||
valid, err := auth.VerifyRequest(r, runnerSecret, 5*time.Minute)
|
||||
if err != nil || !valid {
|
||||
s.respondError(w, http.StatusUnauthorized, "invalid signature")
|
||||
return
|
||||
}
|
||||
|
||||
// Add runner ID to context
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, "runner_id", runnerID)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// handleRegisterRunner registers a new runner
|
||||
func (s *Server) handleRegisterRunner(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
types.RegisterRunnerRequest
|
||||
RegistrationToken string `json:"registration_token"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
s.respondError(w, http.StatusBadRequest, "Runner name is required")
|
||||
return
|
||||
}
|
||||
|
||||
if req.RegistrationToken == "" {
|
||||
s.respondError(w, http.StatusBadRequest, "Registration token is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate registration token
|
||||
valid, err := s.secrets.ValidateRegistrationToken(req.RegistrationToken)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to validate token: %v", err))
|
||||
return
|
||||
}
|
||||
if !valid {
|
||||
s.respondError(w, http.StatusUnauthorized, "Invalid or expired registration token")
|
||||
return
|
||||
}
|
||||
|
||||
// Get manager secret
|
||||
managerSecret, err := s.secrets.GetManagerSecret()
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, "Failed to get manager secret")
|
||||
return
|
||||
}
|
||||
|
||||
// Generate runner secret
|
||||
runnerSecret, err := s.secrets.GenerateRunnerSecret()
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, "Failed to generate runner secret")
|
||||
return
|
||||
}
|
||||
|
||||
// Register runner
|
||||
result, err := s.db.Exec(
|
||||
`INSERT INTO runners (name, hostname, ip_address, status, last_heartbeat, capabilities,
|
||||
registration_token, runner_secret, manager_secret, verified)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
req.Name, req.Hostname, req.IPAddress, types.RunnerStatusOnline, time.Now(), req.Capabilities,
|
||||
req.RegistrationToken, runnerSecret, managerSecret, true,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to register runner: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
runnerID, _ := result.LastInsertId()
|
||||
|
||||
// Return runner info with secrets
|
||||
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
|
||||
"id": runnerID,
|
||||
"name": req.Name,
|
||||
"hostname": req.Hostname,
|
||||
"ip_address": req.IPAddress,
|
||||
"status": types.RunnerStatusOnline,
|
||||
"runner_secret": runnerSecret,
|
||||
"manager_secret": managerSecret,
|
||||
"verified": true,
|
||||
})
|
||||
}
|
||||
|
||||
// handleRunnerHeartbeat updates runner heartbeat
|
||||
func (s *Server) handleRunnerHeartbeat(w http.ResponseWriter, r *http.Request) {
|
||||
runnerID, ok := r.Context().Value("runner_id").(int64)
|
||||
if !ok {
|
||||
s.respondError(w, http.StatusBadRequest, "runner_id not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
_, err := s.db.Exec(
|
||||
`UPDATE runners SET last_heartbeat = ?, status = ? WHERE id = ?`,
|
||||
time.Now(), types.RunnerStatusOnline, runnerID,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to update heartbeat: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Heartbeat updated"})
|
||||
}
|
||||
|
||||
// handleGetRunnerTasks gets pending tasks for a runner
|
||||
func (s *Server) handleGetRunnerTasks(w http.ResponseWriter, r *http.Request) {
|
||||
runnerID, ok := r.Context().Value("runner_id").(int64)
|
||||
if !ok {
|
||||
s.respondError(w, http.StatusBadRequest, "runner_id not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
// Get pending tasks
|
||||
rows, err := s.db.Query(
|
||||
`SELECT t.id, t.job_id, t.runner_id, t.frame_start, t.frame_end, t.status, t.output_path,
|
||||
t.created_at, t.started_at, t.completed_at, t.error_message,
|
||||
j.name as job_name, j.output_format
|
||||
FROM tasks t
|
||||
JOIN jobs j ON t.job_id = j.id
|
||||
WHERE t.status = ? AND j.status != ?
|
||||
ORDER BY t.created_at ASC
|
||||
LIMIT 10`,
|
||||
types.TaskStatusPending, types.JobStatusCancelled,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query tasks: %v", err))
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
tasks := []map[string]interface{}{}
|
||||
for rows.Next() {
|
||||
var task types.Task
|
||||
var runnerID sql.NullInt64
|
||||
var startedAt, completedAt sql.NullTime
|
||||
var jobName, outputFormat string
|
||||
|
||||
err := rows.Scan(
|
||||
&task.ID, &task.JobID, &runnerID, &task.FrameStart, &task.FrameEnd,
|
||||
&task.Status, &task.OutputPath, &task.CreatedAt,
|
||||
&startedAt, &completedAt, &task.ErrorMessage,
|
||||
&jobName, &outputFormat,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan task: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if runnerID.Valid {
|
||||
task.RunnerID = &runnerID.Int64
|
||||
}
|
||||
if startedAt.Valid {
|
||||
task.StartedAt = &startedAt.Time
|
||||
}
|
||||
if completedAt.Valid {
|
||||
task.CompletedAt = &completedAt.Time
|
||||
}
|
||||
|
||||
// Get input files for the job
|
||||
var inputFiles []string
|
||||
fileRows, err := s.db.Query(
|
||||
`SELECT file_path FROM job_files WHERE job_id = ? AND file_type = ?`,
|
||||
task.JobID, types.JobFileTypeInput,
|
||||
)
|
||||
if err == nil {
|
||||
for fileRows.Next() {
|
||||
var filePath string
|
||||
if err := fileRows.Scan(&filePath); err == nil {
|
||||
inputFiles = append(inputFiles, filePath)
|
||||
}
|
||||
}
|
||||
fileRows.Close()
|
||||
}
|
||||
|
||||
tasks = append(tasks, map[string]interface{}{
|
||||
"task": task,
|
||||
"job_name": jobName,
|
||||
"output_format": outputFormat,
|
||||
"input_files": inputFiles,
|
||||
})
|
||||
|
||||
// Assign task to runner
|
||||
_, err = s.db.Exec(
|
||||
`UPDATE tasks SET runner_id = ?, status = ? WHERE id = ?`,
|
||||
runnerID, types.TaskStatusRunning, task.ID,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to assign task: %v", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, tasks)
|
||||
}
|
||||
|
||||
// handleCompleteTask marks a task as completed
|
||||
func (s *Server) handleCompleteTask(w http.ResponseWriter, r *http.Request) {
|
||||
taskID, err := parseID(r, "id")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
OutputPath string `json:"output_path"`
|
||||
Success bool `json:"success"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
status := types.TaskStatusCompleted
|
||||
if !req.Success {
|
||||
status = types.TaskStatusFailed
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
_, err = s.db.Exec(
|
||||
`UPDATE tasks SET status = ?, output_path = ?, completed_at = ?, error_message = ? WHERE id = ?`,
|
||||
status, req.OutputPath, now, req.Error, taskID,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to update task: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Update job progress
|
||||
var jobID int64
|
||||
var frameStart, frameEnd int
|
||||
err = s.db.QueryRow(
|
||||
`SELECT job_id, frame_start, frame_end FROM tasks WHERE id = ?`,
|
||||
taskID,
|
||||
).Scan(&jobID, &frameStart, &frameEnd)
|
||||
if err == nil {
|
||||
// Count completed tasks
|
||||
var totalTasks, completedTasks int
|
||||
s.db.QueryRow(
|
||||
`SELECT COUNT(*) FROM tasks WHERE job_id = ?`,
|
||||
jobID,
|
||||
).Scan(&totalTasks)
|
||||
s.db.QueryRow(
|
||||
`SELECT COUNT(*) FROM tasks WHERE job_id = ? AND status = ?`,
|
||||
jobID, types.TaskStatusCompleted,
|
||||
).Scan(&completedTasks)
|
||||
|
||||
progress := float64(completedTasks) / float64(totalTasks) * 100.0
|
||||
|
||||
// Update job status
|
||||
var jobStatus string
|
||||
var outputFormat string
|
||||
s.db.QueryRow(`SELECT output_format FROM jobs WHERE id = ?`, jobID).Scan(&outputFormat)
|
||||
|
||||
if completedTasks == totalTasks {
|
||||
jobStatus = string(types.JobStatusCompleted)
|
||||
now := time.Now()
|
||||
s.db.Exec(
|
||||
`UPDATE jobs SET status = ?, progress = ?, completed_at = ? WHERE id = ?`,
|
||||
jobStatus, progress, now, jobID,
|
||||
)
|
||||
|
||||
// For MP4 jobs, create a video generation task
|
||||
if outputFormat == "MP4" {
|
||||
go s.generateMP4Video(jobID)
|
||||
}
|
||||
} else {
|
||||
jobStatus = string(types.JobStatusRunning)
|
||||
var startedAt sql.NullTime
|
||||
s.db.QueryRow(`SELECT started_at FROM jobs WHERE id = ?`, jobID).Scan(&startedAt)
|
||||
if !startedAt.Valid {
|
||||
now := time.Now()
|
||||
s.db.Exec(`UPDATE jobs SET started_at = ? WHERE id = ?`, now, jobID)
|
||||
}
|
||||
s.db.Exec(
|
||||
`UPDATE jobs SET status = ?, progress = ? WHERE id = ?`,
|
||||
jobStatus, progress, jobID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Task completed"})
|
||||
}
|
||||
|
||||
// handleUpdateTaskProgress updates task progress
|
||||
func (s *Server) handleUpdateTaskProgress(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := parseID(r, "id")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Progress float64 `json:"progress"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
// This is mainly for logging/debugging, actual progress is calculated from completed tasks
|
||||
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Progress updated"})
|
||||
}
|
||||
|
||||
// handleDownloadFileForRunner allows runners to download job files
|
||||
func (s *Server) handleDownloadFileForRunner(w http.ResponseWriter, r *http.Request) {
|
||||
jobID, err := parseID(r, "jobId")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
fileName := chi.URLParam(r, "fileName")
|
||||
|
||||
// Find the file in the database
|
||||
var filePath string
|
||||
err = s.db.QueryRow(
|
||||
`SELECT file_path FROM job_files WHERE job_id = ? AND file_name = ?`,
|
||||
jobID, fileName,
|
||||
).Scan(&filePath)
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "File not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query file: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Open and serve file
|
||||
file, err := s.storage.GetFile(filePath)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusNotFound, "File not found on disk")
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", fileName))
|
||||
io.Copy(w, file)
|
||||
}
|
||||
|
||||
// handleUploadFileFromRunner allows runners to upload output files
|
||||
func (s *Server) handleUploadFileFromRunner(w http.ResponseWriter, r *http.Request) {
|
||||
jobID, err := parseID(r, "jobId")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err = r.ParseMultipartForm(100 << 20) // 100 MB
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "Failed to parse form")
|
||||
return
|
||||
}
|
||||
|
||||
file, header, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, "No file provided")
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Save file
|
||||
filePath, err := s.storage.SaveOutput(jobID, header.Filename, file)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to save file: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Record in database
|
||||
_, err = s.db.Exec(
|
||||
`INSERT INTO job_files (job_id, file_type, file_path, file_name, file_size)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
jobID, types.JobFileTypeOutput, filePath, header.Filename, header.Size,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to record file: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusCreated, map[string]interface{}{
|
||||
"file_path": filePath,
|
||||
"file_name": header.Filename,
|
||||
})
|
||||
}
|
||||
|
||||
// generateMP4Video generates MP4 video from PNG frames for a completed job
|
||||
func (s *Server) generateMP4Video(jobID int64) {
|
||||
// This would be called by a runner or external process
|
||||
// For now, we'll create a special task that runners can pick up
|
||||
// In a production system, you might want to use a job queue or have a dedicated video processor
|
||||
|
||||
// Get all PNG output files for this job
|
||||
rows, err := s.db.Query(
|
||||
`SELECT file_path, file_name FROM job_files
|
||||
WHERE job_id = ? AND file_type = ? AND file_name LIKE '%.png'
|
||||
ORDER BY file_name`,
|
||||
jobID, types.JobFileTypeOutput,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Failed to query PNG files for job %d: %v", jobID, err)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var pngFiles []string
|
||||
for rows.Next() {
|
||||
var filePath, fileName string
|
||||
if err := rows.Scan(&filePath, &fileName); err == nil {
|
||||
pngFiles = append(pngFiles, filePath)
|
||||
}
|
||||
}
|
||||
|
||||
if len(pngFiles) == 0 {
|
||||
log.Printf("No PNG files found for job %d", jobID)
|
||||
return
|
||||
}
|
||||
|
||||
// Note: Video generation will be handled by runners when they complete tasks
|
||||
// Runners can check job status and generate MP4 when all frames are complete
|
||||
log.Printf("Job %d completed with %d PNG frames - ready for MP4 generation", jobID, len(pngFiles))
|
||||
}
|
||||
|
||||
// handleGetJobStatusForRunner allows runners to check job status
|
||||
func (s *Server) handleGetJobStatusForRunner(w http.ResponseWriter, r *http.Request) {
|
||||
jobID, err := parseID(r, "jobId")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var job types.Job
|
||||
var startedAt, completedAt sql.NullTime
|
||||
|
||||
err = s.db.QueryRow(
|
||||
`SELECT id, user_id, name, status, progress, frame_start, frame_end, output_format,
|
||||
created_at, started_at, completed_at, error_message
|
||||
FROM jobs WHERE id = ?`,
|
||||
jobID,
|
||||
).Scan(
|
||||
&job.ID, &job.UserID, &job.Name, &job.Status, &job.Progress,
|
||||
&job.FrameStart, &job.FrameEnd, &job.OutputFormat,
|
||||
&job.CreatedAt, &startedAt, &completedAt, &job.ErrorMessage,
|
||||
)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
s.respondError(w, http.StatusNotFound, "Job not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query job: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if startedAt.Valid {
|
||||
job.StartedAt = &startedAt.Time
|
||||
}
|
||||
if completedAt.Valid {
|
||||
job.CompletedAt = &completedAt.Time
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, job)
|
||||
}
|
||||
|
||||
// handleGetJobFilesForRunner allows runners to get job files
|
||||
func (s *Server) handleGetJobFilesForRunner(w http.ResponseWriter, r *http.Request) {
|
||||
jobID, err := parseID(r, "jobId")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, job_id, file_type, file_path, file_name, file_size, created_at
|
||||
FROM job_files WHERE job_id = ? ORDER BY file_name`,
|
||||
jobID,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to query files: %v", err))
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
files := []types.JobFile{}
|
||||
for rows.Next() {
|
||||
var file types.JobFile
|
||||
err := rows.Scan(
|
||||
&file.ID, &file.JobID, &file.FileType, &file.FilePath,
|
||||
&file.FileName, &file.FileSize, &file.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to scan file: %v", err))
|
||||
return
|
||||
}
|
||||
files = append(files, file)
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, files)
|
||||
}
|
||||
|
||||
284
internal/api/server.go
Normal file
284
internal/api/server.go
Normal file
@@ -0,0 +1,284 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/go-chi/cors"
|
||||
"fuego/internal/auth"
|
||||
"fuego/internal/database"
|
||||
"fuego/internal/storage"
|
||||
)
|
||||
|
||||
// Server represents the API server
|
||||
type Server struct {
|
||||
db *database.DB
|
||||
auth *auth.Auth
|
||||
secrets *auth.Secrets
|
||||
storage *storage.Storage
|
||||
router *chi.Mux
|
||||
}
|
||||
|
||||
// NewServer creates a new API server
|
||||
func NewServer(db *database.DB, auth *auth.Auth, storage *storage.Storage) (*Server, error) {
|
||||
secrets, err := auth.NewSecrets(db.DB)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize secrets: %w", err)
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
db: db,
|
||||
auth: auth,
|
||||
secrets: secrets,
|
||||
storage: storage,
|
||||
router: chi.NewRouter(),
|
||||
}
|
||||
|
||||
s.setupMiddleware()
|
||||
s.setupRoutes()
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// setupMiddleware configures middleware
|
||||
func (s *Server) setupMiddleware() {
|
||||
s.router.Use(middleware.Logger)
|
||||
s.router.Use(middleware.Recoverer)
|
||||
s.router.Use(middleware.Timeout(60 * time.Second))
|
||||
|
||||
s.router.Use(cors.Handler(cors.Options{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
||||
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type"},
|
||||
ExposedHeaders: []string{"Link"},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 300,
|
||||
}))
|
||||
}
|
||||
|
||||
// setupRoutes configures routes
|
||||
func (s *Server) setupRoutes() {
|
||||
// Public routes
|
||||
s.router.Route("/api/auth", func(r chi.Router) {
|
||||
r.Get("/google/login", s.handleGoogleLogin)
|
||||
r.Get("/google/callback", s.handleGoogleCallback)
|
||||
r.Get("/discord/login", s.handleDiscordLogin)
|
||||
r.Get("/discord/callback", s.handleDiscordCallback)
|
||||
r.Post("/logout", s.handleLogout)
|
||||
r.Get("/me", s.handleGetMe)
|
||||
})
|
||||
|
||||
// Protected routes
|
||||
s.router.Route("/api/jobs", func(r chi.Router) {
|
||||
r.Use(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(s.auth.Middleware(next.ServeHTTP))
|
||||
})
|
||||
r.Post("/", s.handleCreateJob)
|
||||
r.Get("/", s.handleListJobs)
|
||||
r.Get("/{id}", s.handleGetJob)
|
||||
r.Delete("/{id}", s.handleCancelJob)
|
||||
r.Post("/{id}/upload", s.handleUploadJobFile)
|
||||
r.Get("/{id}/files", s.handleListJobFiles)
|
||||
r.Get("/{id}/files/{fileId}/download", s.handleDownloadJobFile)
|
||||
r.Get("/{id}/video", s.handleStreamVideo)
|
||||
})
|
||||
|
||||
s.router.Route("/api/runners", func(r chi.Router) {
|
||||
r.Use(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(s.auth.Middleware(next.ServeHTTP))
|
||||
})
|
||||
r.Get("/", s.handleListRunners)
|
||||
})
|
||||
|
||||
// Admin routes
|
||||
s.router.Route("/api/admin", func(r chi.Router) {
|
||||
r.Use(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(s.auth.AdminMiddleware(next.ServeHTTP))
|
||||
})
|
||||
r.Route("/runners", func(r chi.Router) {
|
||||
r.Route("/tokens", func(r chi.Router) {
|
||||
r.Post("/", s.handleGenerateRegistrationToken)
|
||||
r.Get("/", s.handleListRegistrationTokens)
|
||||
r.Delete("/{id}", s.handleRevokeRegistrationToken)
|
||||
})
|
||||
r.Get("/", s.handleListRunnersAdmin)
|
||||
r.Post("/{id}/verify", s.handleVerifyRunner)
|
||||
r.Delete("/{id}", s.handleDeleteRunner)
|
||||
})
|
||||
})
|
||||
|
||||
// Runner API
|
||||
s.router.Route("/api/runner", func(r chi.Router) {
|
||||
// Registration doesn't require auth (uses token)
|
||||
r.Post("/register", s.handleRegisterRunner)
|
||||
|
||||
// All other endpoints require runner authentication
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(s.runnerAuthMiddleware(next.ServeHTTP))
|
||||
})
|
||||
r.Post("/heartbeat", s.handleRunnerHeartbeat)
|
||||
r.Get("/tasks", s.handleGetRunnerTasks)
|
||||
r.Post("/tasks/{id}/complete", s.handleCompleteTask)
|
||||
r.Post("/tasks/{id}/progress", s.handleUpdateTaskProgress)
|
||||
r.Get("/files/{jobId}/{fileName}", s.handleDownloadFileForRunner)
|
||||
r.Post("/files/{jobId}/upload", s.handleUploadFileFromRunner)
|
||||
r.Get("/jobs/{jobId}/status", s.handleGetJobStatusForRunner)
|
||||
r.Get("/jobs/{jobId}/files", s.handleGetJobFilesForRunner)
|
||||
})
|
||||
})
|
||||
|
||||
// Serve static files (built React app)
|
||||
s.router.Handle("/*", http.FileServer(http.Dir("./web/dist")))
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
s.router.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// JSON response helpers
|
||||
func (s *Server) respondJSON(w http.ResponseWriter, status int, data interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
if err := json.NewEncoder(w).Encode(data); err != nil {
|
||||
log.Printf("Failed to encode JSON response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) respondError(w http.ResponseWriter, status int, message string) {
|
||||
s.respondJSON(w, status, map[string]string{"error": message})
|
||||
}
|
||||
|
||||
// Auth handlers
|
||||
func (s *Server) handleGoogleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
url, err := s.auth.GoogleLoginURL()
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, url, http.StatusFound)
|
||||
}
|
||||
|
||||
func (s *Server) handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
s.respondError(w, http.StatusBadRequest, "Missing code parameter")
|
||||
return
|
||||
}
|
||||
|
||||
session, err := s.auth.GoogleCallback(r.Context(), code)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sessionID := s.auth.CreateSession(session)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_id",
|
||||
Value: sessionID,
|
||||
Path: "/",
|
||||
MaxAge: 86400,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
}
|
||||
|
||||
func (s *Server) handleDiscordLogin(w http.ResponseWriter, r *http.Request) {
|
||||
url, err := s.auth.DiscordLoginURL()
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, url, http.StatusFound)
|
||||
}
|
||||
|
||||
func (s *Server) handleDiscordCallback(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
s.respondError(w, http.StatusBadRequest, "Missing code parameter")
|
||||
return
|
||||
}
|
||||
|
||||
session, err := s.auth.DiscordCallback(r.Context(), code)
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sessionID := s.auth.CreateSession(session)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_id",
|
||||
Value: sessionID,
|
||||
Path: "/",
|
||||
MaxAge: 86400,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
}
|
||||
|
||||
func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("session_id")
|
||||
if err == nil {
|
||||
s.auth.DeleteSession(cookie.Value)
|
||||
}
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_id",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
})
|
||||
s.respondJSON(w, http.StatusOK, map[string]string{"message": "Logged out"})
|
||||
}
|
||||
|
||||
func (s *Server) handleGetMe(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("session_id")
|
||||
if err != nil {
|
||||
s.respondError(w, http.StatusUnauthorized, "Not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
session, ok := s.auth.GetSession(cookie.Value)
|
||||
if !ok {
|
||||
s.respondError(w, http.StatusUnauthorized, "Invalid session")
|
||||
return
|
||||
}
|
||||
|
||||
s.respondJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"id": session.UserID,
|
||||
"email": session.Email,
|
||||
"name": session.Name,
|
||||
"is_admin": session.IsAdmin,
|
||||
})
|
||||
}
|
||||
|
||||
// Helper to get user ID from context
|
||||
func getUserID(r *http.Request) (int64, error) {
|
||||
userID, ok := auth.GetUserID(r.Context())
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("user ID not found in context")
|
||||
}
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
// Helper to parse ID from URL
|
||||
func parseID(r *http.Request, param string) (int64, error) {
|
||||
idStr := chi.URLParam(r, param)
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid ID: %s", idStr)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
302
internal/auth/auth.go
Normal file
302
internal/auth/auth.go
Normal file
@@ -0,0 +1,302 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
// Auth handles authentication
|
||||
type Auth struct {
|
||||
db *sql.DB
|
||||
googleConfig *oauth2.Config
|
||||
discordConfig *oauth2.Config
|
||||
sessionStore map[string]*Session
|
||||
}
|
||||
|
||||
// Session represents a user session
|
||||
type Session struct {
|
||||
UserID int64
|
||||
Email string
|
||||
Name string
|
||||
IsAdmin bool
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// NewAuth creates a new auth instance
|
||||
func NewAuth(db *sql.DB) (*Auth, error) {
|
||||
auth := &Auth{
|
||||
db: db,
|
||||
sessionStore: make(map[string]*Session),
|
||||
}
|
||||
|
||||
// Initialize Google OAuth
|
||||
googleClientID := os.Getenv("GOOGLE_CLIENT_ID")
|
||||
googleClientSecret := os.Getenv("GOOGLE_CLIENT_SECRET")
|
||||
if googleClientID != "" && googleClientSecret != "" {
|
||||
auth.googleConfig = &oauth2.Config{
|
||||
ClientID: googleClientID,
|
||||
ClientSecret: googleClientSecret,
|
||||
RedirectURL: os.Getenv("GOOGLE_REDIRECT_URL"),
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
Endpoint: google.Endpoint,
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize Discord OAuth
|
||||
discordClientID := os.Getenv("DISCORD_CLIENT_ID")
|
||||
discordClientSecret := os.Getenv("DISCORD_CLIENT_SECRET")
|
||||
if discordClientID != "" && discordClientSecret != "" {
|
||||
auth.discordConfig = &oauth2.Config{
|
||||
ClientID: discordClientID,
|
||||
ClientSecret: discordClientSecret,
|
||||
RedirectURL: os.Getenv("DISCORD_REDIRECT_URL"),
|
||||
Scopes: []string{"identify", "email"},
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: "https://discord.com/api/oauth2/authorize",
|
||||
TokenURL: "https://discord.com/api/oauth2/token",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// GoogleLoginURL returns the Google OAuth login URL
|
||||
func (a *Auth) GoogleLoginURL() (string, error) {
|
||||
if a.googleConfig == nil {
|
||||
return "", fmt.Errorf("Google OAuth not configured")
|
||||
}
|
||||
state := uuid.New().String()
|
||||
return a.googleConfig.AuthCodeURL(state), nil
|
||||
}
|
||||
|
||||
// DiscordLoginURL returns the Discord OAuth login URL
|
||||
func (a *Auth) DiscordLoginURL() (string, error) {
|
||||
if a.discordConfig == nil {
|
||||
return "", fmt.Errorf("Discord OAuth not configured")
|
||||
}
|
||||
state := uuid.New().String()
|
||||
return a.discordConfig.AuthCodeURL(state), nil
|
||||
}
|
||||
|
||||
// GoogleCallback handles Google OAuth callback
|
||||
func (a *Auth) GoogleCallback(ctx context.Context, code string) (*Session, error) {
|
||||
if a.googleConfig == nil {
|
||||
return nil, fmt.Errorf("Google OAuth not configured")
|
||||
}
|
||||
|
||||
token, err := a.googleConfig.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange token: %w", err)
|
||||
}
|
||||
|
||||
client := a.googleConfig.Client(ctx, token)
|
||||
resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user info: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var userInfo struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode user info: %w", err)
|
||||
}
|
||||
|
||||
return a.getOrCreateUser("google", userInfo.ID, userInfo.Email, userInfo.Name)
|
||||
}
|
||||
|
||||
// DiscordCallback handles Discord OAuth callback
|
||||
func (a *Auth) DiscordCallback(ctx context.Context, code string) (*Session, error) {
|
||||
if a.discordConfig == nil {
|
||||
return nil, fmt.Errorf("Discord OAuth not configured")
|
||||
}
|
||||
|
||||
token, err := a.discordConfig.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange token: %w", err)
|
||||
}
|
||||
|
||||
client := a.discordConfig.Client(ctx, token)
|
||||
resp, err := client.Get("https://discord.com/api/users/@me")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user info: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var userInfo struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode user info: %w", err)
|
||||
}
|
||||
|
||||
return a.getOrCreateUser("discord", userInfo.ID, userInfo.Email, userInfo.Username)
|
||||
}
|
||||
|
||||
// getOrCreateUser gets or creates a user in the database
|
||||
func (a *Auth) getOrCreateUser(provider, oauthID, email, name string) (*Session, error) {
|
||||
var userID int64
|
||||
var dbEmail, dbName string
|
||||
var isAdmin bool
|
||||
|
||||
err := a.db.QueryRow(
|
||||
"SELECT id, email, name, is_admin FROM users WHERE oauth_provider = ? AND oauth_id = ?",
|
||||
provider, oauthID,
|
||||
).Scan(&userID, &dbEmail, &dbName, &isAdmin)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
// Check if this is the first user
|
||||
var userCount int
|
||||
a.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
||||
isAdmin = userCount == 0
|
||||
|
||||
// Create new user
|
||||
result, err := a.db.Exec(
|
||||
"INSERT INTO users (email, name, oauth_provider, oauth_id, is_admin) VALUES (?, ?, ?, ?, ?)",
|
||||
email, name, provider, oauthID, isAdmin,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
userID, _ = result.LastInsertId()
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("failed to query user: %w", err)
|
||||
} else {
|
||||
// Update user info if changed
|
||||
if dbEmail != email || dbName != name {
|
||||
_, err = a.db.Exec(
|
||||
"UPDATE users SET email = ?, name = ? WHERE id = ?",
|
||||
email, name, userID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update user: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
session := &Session{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
Name: name,
|
||||
IsAdmin: isAdmin,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// CreateSession creates a new session and returns a session ID
|
||||
func (a *Auth) CreateSession(session *Session) string {
|
||||
sessionID := uuid.New().String()
|
||||
a.sessionStore[sessionID] = session
|
||||
return sessionID
|
||||
}
|
||||
|
||||
// GetSession retrieves a session by ID
|
||||
func (a *Auth) GetSession(sessionID string) (*Session, bool) {
|
||||
session, ok := a.sessionStore[sessionID]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
delete(a.sessionStore, sessionID)
|
||||
return nil, false
|
||||
}
|
||||
// Refresh admin status from database
|
||||
var isAdmin bool
|
||||
err := a.db.QueryRow("SELECT is_admin FROM users WHERE id = ?", session.UserID).Scan(&isAdmin)
|
||||
if err == nil {
|
||||
session.IsAdmin = isAdmin
|
||||
}
|
||||
return session, true
|
||||
}
|
||||
|
||||
// DeleteSession deletes a session
|
||||
func (a *Auth) DeleteSession(sessionID string) {
|
||||
delete(a.sessionStore, sessionID)
|
||||
}
|
||||
|
||||
// Middleware creates an authentication middleware
|
||||
func (a *Auth) Middleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("session_id")
|
||||
if err != nil {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
session, ok := a.GetSession(cookie.Value)
|
||||
if !ok {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Add user info to request context
|
||||
ctx := context.WithValue(r.Context(), "user_id", session.UserID)
|
||||
ctx = context.WithValue(ctx, "user_email", session.Email)
|
||||
ctx = context.WithValue(ctx, "user_name", session.Name)
|
||||
ctx = context.WithValue(ctx, "is_admin", session.IsAdmin)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserID gets the user ID from context
|
||||
func GetUserID(ctx context.Context) (int64, bool) {
|
||||
userID, ok := ctx.Value("user_id").(int64)
|
||||
return userID, ok
|
||||
}
|
||||
|
||||
// IsAdmin checks if the user in context is an admin
|
||||
func IsAdmin(ctx context.Context) bool {
|
||||
isAdmin, ok := ctx.Value("is_admin").(bool)
|
||||
return ok && isAdmin
|
||||
}
|
||||
|
||||
// AdminMiddleware creates an admin-only middleware
|
||||
func (a *Auth) AdminMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// First check authentication
|
||||
cookie, err := r.Cookie("session_id")
|
||||
if err != nil {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
session, ok := a.GetSession(cookie.Value)
|
||||
if !ok {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Then check admin status
|
||||
if !session.IsAdmin {
|
||||
http.Error(w, "Forbidden: Admin access required", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Add user info to request context
|
||||
ctx := context.WithValue(r.Context(), "user_id", session.UserID)
|
||||
ctx = context.WithValue(ctx, "user_email", session.Email)
|
||||
ctx = context.WithValue(ctx, "user_name", session.Name)
|
||||
ctx = context.WithValue(ctx, "is_admin", session.IsAdmin)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
244
internal/auth/secrets.go
Normal file
244
internal/auth/secrets.go
Normal file
@@ -0,0 +1,244 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Secrets handles secret and token management
|
||||
type Secrets struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewSecrets creates a new secrets manager
|
||||
func NewSecrets(db *sql.DB) (*Secrets, error) {
|
||||
s := &Secrets{db: db}
|
||||
|
||||
// Ensure manager secret exists
|
||||
if err := s.ensureManagerSecret(); err != nil {
|
||||
return nil, fmt.Errorf("failed to ensure manager secret: %w", err)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// ensureManagerSecret ensures a manager secret exists in the database
|
||||
func (s *Secrets) ensureManagerSecret() error {
|
||||
var count int
|
||||
err := s.db.QueryRow("SELECT COUNT(*) FROM manager_secrets").Scan(&count)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check manager secrets: %w", err)
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
// Generate new manager secret
|
||||
secret, err := generateSecret(32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate manager secret: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.db.Exec("INSERT INTO manager_secrets (secret) VALUES (?)", secret)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store manager secret: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetManagerSecret retrieves the current manager secret
|
||||
func (s *Secrets) GetManagerSecret() (string, error) {
|
||||
var secret string
|
||||
err := s.db.QueryRow("SELECT secret FROM manager_secrets ORDER BY created_at DESC LIMIT 1").Scan(&secret)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get manager secret: %w", err)
|
||||
}
|
||||
return secret, nil
|
||||
}
|
||||
|
||||
// GenerateRegistrationToken generates a new registration token
|
||||
func (s *Secrets) GenerateRegistrationToken(createdBy int64, expiresIn time.Duration) (string, error) {
|
||||
token, err := generateSecret(32)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate token: %w", err)
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(expiresIn)
|
||||
|
||||
_, err = s.db.Exec(
|
||||
"INSERT INTO registration_tokens (token, expires_at, created_by) VALUES (?, ?, ?)",
|
||||
token, expiresAt, createdBy,
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to store registration token: %w", err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// ValidateRegistrationToken validates a registration token
|
||||
func (s *Secrets) ValidateRegistrationToken(token string) (bool, error) {
|
||||
var used bool
|
||||
var expiresAt time.Time
|
||||
var id int64
|
||||
|
||||
err := s.db.QueryRow(
|
||||
"SELECT id, expires_at, used FROM registration_tokens WHERE token = ?",
|
||||
token,
|
||||
).Scan(&id, &expiresAt, &used)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to query token: %w", err)
|
||||
}
|
||||
|
||||
if used {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if time.Now().After(expiresAt) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Mark token as used
|
||||
_, err = s.db.Exec("UPDATE registration_tokens SET used = 1 WHERE id = ?", id)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to mark token as used: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// ListRegistrationTokens lists all registration tokens
|
||||
func (s *Secrets) ListRegistrationTokens() ([]map[string]interface{}, error) {
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, token, expires_at, used, created_at, created_by
|
||||
FROM registration_tokens
|
||||
ORDER BY created_at DESC`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query tokens: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tokens []map[string]interface{}
|
||||
for rows.Next() {
|
||||
var id, createdBy sql.NullInt64
|
||||
var token string
|
||||
var expiresAt, createdAt time.Time
|
||||
var used bool
|
||||
|
||||
err := rows.Scan(&id, &token, &expiresAt, &used, &createdAt, &createdBy)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
tokens = append(tokens, map[string]interface{}{
|
||||
"id": id.Int64,
|
||||
"token": token,
|
||||
"expires_at": expiresAt,
|
||||
"used": used,
|
||||
"created_at": createdAt,
|
||||
"created_by": createdBy.Int64,
|
||||
})
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// RevokeRegistrationToken revokes a registration token
|
||||
func (s *Secrets) RevokeRegistrationToken(tokenID int64) error {
|
||||
_, err := s.db.Exec("UPDATE registration_tokens SET used = 1 WHERE id = ?", tokenID)
|
||||
return err
|
||||
}
|
||||
|
||||
// GenerateRunnerSecret generates a unique secret for a runner
|
||||
func (s *Secrets) GenerateRunnerSecret() (string, error) {
|
||||
return generateSecret(32)
|
||||
}
|
||||
|
||||
// SignRequest signs a request with the given secret
|
||||
func SignRequest(method, path, body, secret string, timestamp time.Time) string {
|
||||
message := fmt.Sprintf("%s\n%s\n%s\n%d", method, path, body, timestamp.Unix())
|
||||
h := hmac.New(sha256.New, []byte(secret))
|
||||
h.Write([]byte(message))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// VerifyRequest verifies a signed request
|
||||
func VerifyRequest(r *http.Request, secret string, maxAge time.Duration) (bool, error) {
|
||||
signature := r.Header.Get("X-Runner-Signature")
|
||||
if signature == "" {
|
||||
return false, fmt.Errorf("missing signature")
|
||||
}
|
||||
|
||||
timestampStr := r.Header.Get("X-Runner-Timestamp")
|
||||
if timestampStr == "" {
|
||||
return false, fmt.Errorf("missing timestamp")
|
||||
}
|
||||
|
||||
var timestamp time.Time
|
||||
_, err := fmt.Sscanf(timestampStr, "%d", ×tamp)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("invalid timestamp: %w", err)
|
||||
}
|
||||
|
||||
// Check timestamp is not too old
|
||||
if time.Since(timestamp) > maxAge {
|
||||
return false, fmt.Errorf("request too old")
|
||||
}
|
||||
|
||||
// Check timestamp is not in the future (allow 1 minute clock skew)
|
||||
if timestamp.After(time.Now().Add(1 * time.Minute)) {
|
||||
return false, fmt.Errorf("timestamp in future")
|
||||
}
|
||||
|
||||
// Read body
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to read body: %w", err)
|
||||
}
|
||||
// Restore body for handler
|
||||
r.Body = io.NopCloser(strings.NewReader(string(bodyBytes)))
|
||||
|
||||
// Verify signature
|
||||
expectedSig := SignRequest(r.Method, r.URL.Path, string(bodyBytes), secret, timestamp)
|
||||
|
||||
return hmac.Equal([]byte(signature), []byte(expectedSig)), nil
|
||||
}
|
||||
|
||||
// GetRunnerSecret retrieves the runner secret for a runner ID
|
||||
func (s *Secrets) GetRunnerSecret(runnerID int64) (string, error) {
|
||||
var secret string
|
||||
err := s.db.QueryRow("SELECT runner_secret FROM runners WHERE id = ?", runnerID).Scan(&secret)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", fmt.Errorf("runner not found")
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get runner secret: %w", err)
|
||||
}
|
||||
if secret == "" {
|
||||
return "", fmt.Errorf("runner not verified")
|
||||
}
|
||||
return secret, nil
|
||||
}
|
||||
|
||||
// generateSecret generates a random secret of the given length
|
||||
func generateSecret(length int) (string, error) {
|
||||
bytes := make([]byte, length)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
159
internal/database/schema.go
Normal file
159
internal/database/schema.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// DB wraps the database connection
|
||||
type DB struct {
|
||||
*sql.DB
|
||||
}
|
||||
|
||||
// NewDB creates a new database connection
|
||||
func NewDB(dbPath string) (*DB, error) {
|
||||
db, err := sql.Open("sqlite3", dbPath+"?_foreign_keys=1")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
database := &DB{DB: db}
|
||||
if err := database.migrate(); err != nil {
|
||||
return nil, fmt.Errorf("failed to migrate database: %w", err)
|
||||
}
|
||||
|
||||
return database, nil
|
||||
}
|
||||
|
||||
// migrate runs database migrations
|
||||
func (db *DB) migrate() error {
|
||||
schema := `
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
email TEXT UNIQUE NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
oauth_provider TEXT NOT NULL,
|
||||
oauth_id TEXT NOT NULL,
|
||||
is_admin BOOLEAN NOT NULL DEFAULT 0,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(oauth_provider, oauth_id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS jobs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
progress REAL NOT NULL DEFAULT 0.0,
|
||||
frame_start INTEGER NOT NULL,
|
||||
frame_end INTEGER NOT NULL,
|
||||
output_format TEXT NOT NULL DEFAULT 'PNG',
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
started_at DATETIME,
|
||||
completed_at DATETIME,
|
||||
error_message TEXT,
|
||||
FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS runners (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
hostname TEXT NOT NULL,
|
||||
ip_address TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'offline',
|
||||
last_heartbeat DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
capabilities TEXT,
|
||||
registration_token TEXT,
|
||||
runner_secret TEXT,
|
||||
manager_secret TEXT,
|
||||
verified BOOLEAN NOT NULL DEFAULT 0,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS tasks (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
job_id INTEGER NOT NULL,
|
||||
runner_id INTEGER,
|
||||
frame_start INTEGER NOT NULL,
|
||||
frame_end INTEGER NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
output_path TEXT,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
started_at DATETIME,
|
||||
completed_at DATETIME,
|
||||
error_message TEXT,
|
||||
FOREIGN KEY (job_id) REFERENCES jobs(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (runner_id) REFERENCES runners(id) ON DELETE SET NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS job_files (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
job_id INTEGER NOT NULL,
|
||||
file_type TEXT NOT NULL,
|
||||
file_path TEXT NOT NULL,
|
||||
file_name TEXT NOT NULL,
|
||||
file_size INTEGER NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (job_id) REFERENCES jobs(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS manager_secrets (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
secret TEXT UNIQUE NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS registration_tokens (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
token TEXT UNIQUE NOT NULL,
|
||||
expires_at DATETIME NOT NULL,
|
||||
used BOOLEAN NOT NULL DEFAULT 0,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
created_by INTEGER,
|
||||
FOREIGN KEY (created_by) REFERENCES users(id) ON DELETE SET NULL
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_jobs_user_id ON jobs(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_tasks_job_id ON tasks(job_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_tasks_runner_id ON tasks(runner_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_job_files_job_id ON job_files(job_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_registration_tokens_token ON registration_tokens(token);
|
||||
CREATE INDEX IF NOT EXISTS idx_registration_tokens_expires_at ON registration_tokens(expires_at);
|
||||
`
|
||||
|
||||
if _, err := db.Exec(schema); err != nil {
|
||||
return fmt.Errorf("failed to create schema: %w", err)
|
||||
}
|
||||
|
||||
// Migrate existing tables to add new columns
|
||||
migrations := []string{
|
||||
// Add is_admin to users if it doesn't exist
|
||||
`ALTER TABLE users ADD COLUMN is_admin BOOLEAN NOT NULL DEFAULT 0`,
|
||||
// Add new columns to runners if they don't exist
|
||||
`ALTER TABLE runners ADD COLUMN registration_token TEXT`,
|
||||
`ALTER TABLE runners ADD COLUMN runner_secret TEXT`,
|
||||
`ALTER TABLE runners ADD COLUMN manager_secret TEXT`,
|
||||
`ALTER TABLE runners ADD COLUMN verified BOOLEAN NOT NULL DEFAULT 0`,
|
||||
}
|
||||
|
||||
for _, migration := range migrations {
|
||||
// SQLite doesn't support IF NOT EXISTS for ALTER TABLE, so we ignore errors
|
||||
db.Exec(migration)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the database connection
|
||||
func (db *DB) Close() error {
|
||||
return db.DB.Close()
|
||||
}
|
||||
|
||||
627
internal/runner/client.go
Normal file
627
internal/runner/client.go
Normal file
@@ -0,0 +1,627 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Client represents a runner client
|
||||
type Client struct {
|
||||
managerURL string
|
||||
name string
|
||||
hostname string
|
||||
ipAddress string
|
||||
httpClient *http.Client
|
||||
runnerID int64
|
||||
runnerSecret string
|
||||
managerSecret string
|
||||
}
|
||||
|
||||
// NewClient creates a new runner client
|
||||
func NewClient(managerURL, name, hostname, ipAddress string) *Client {
|
||||
return &Client{
|
||||
managerURL: managerURL,
|
||||
name: name,
|
||||
hostname: hostname,
|
||||
ipAddress: ipAddress,
|
||||
httpClient: &http.Client{Timeout: 30 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// SetSecrets sets the runner and manager secrets
|
||||
func (c *Client) SetSecrets(runnerID int64, runnerSecret, managerSecret string) {
|
||||
c.runnerID = runnerID
|
||||
c.runnerSecret = runnerSecret
|
||||
c.managerSecret = managerSecret
|
||||
}
|
||||
|
||||
// Register registers the runner with the manager using a registration token
|
||||
func (c *Client) Register(registrationToken string) (int64, string, string, error) {
|
||||
req := map[string]interface{}{
|
||||
"name": c.name,
|
||||
"hostname": c.hostname,
|
||||
"ip_address": c.ipAddress,
|
||||
"capabilities": "blender,ffmpeg",
|
||||
"registration_token": registrationToken,
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(req)
|
||||
resp, err := c.httpClient.Post(
|
||||
fmt.Sprintf("%s/api/runner/register", c.managerURL),
|
||||
"application/json",
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
if err != nil {
|
||||
return 0, "", "", fmt.Errorf("failed to register: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return 0, "", "", fmt.Errorf("registration failed: %s", string(body))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
ID int64 `json:"id"`
|
||||
RunnerSecret string `json:"runner_secret"`
|
||||
ManagerSecret string `json:"manager_secret"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return 0, "", "", fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
c.runnerID = result.ID
|
||||
c.runnerSecret = result.RunnerSecret
|
||||
c.managerSecret = result.ManagerSecret
|
||||
|
||||
return result.ID, result.RunnerSecret, result.ManagerSecret, nil
|
||||
}
|
||||
|
||||
// signRequest signs a request with the runner secret
|
||||
func (c *Client) signRequest(method, path string, body []byte) (string, time.Time) {
|
||||
timestamp := time.Now()
|
||||
message := fmt.Sprintf("%s\n%s\n%s\n%d", method, path, string(body), timestamp.Unix())
|
||||
h := hmac.New(sha256.New, []byte(c.runnerSecret))
|
||||
h.Write([]byte(message))
|
||||
signature := hex.EncodeToString(h.Sum(nil))
|
||||
return signature, timestamp
|
||||
}
|
||||
|
||||
// doSignedRequest performs a signed HTTP request
|
||||
func (c *Client) doSignedRequest(method, path string, body []byte) (*http.Response, error) {
|
||||
if c.runnerSecret == "" {
|
||||
return nil, fmt.Errorf("runner not authenticated")
|
||||
}
|
||||
|
||||
signature, timestamp := c.signRequest(method, path, body)
|
||||
|
||||
req, err := http.NewRequest(method, fmt.Sprintf("%s%s", c.managerURL, path), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Runner-Signature", signature)
|
||||
req.Header.Set("X-Runner-Timestamp", fmt.Sprintf("%d", timestamp.Unix()))
|
||||
|
||||
return c.httpClient.Do(req)
|
||||
}
|
||||
|
||||
// HeartbeatLoop sends periodic heartbeats to the manager
|
||||
func (c *Client) HeartbeatLoop() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
req := map[string]interface{}{}
|
||||
body, _ := json.Marshal(req)
|
||||
|
||||
resp, err := c.doSignedRequest("POST", "/api/runner/heartbeat?runner_id="+fmt.Sprintf("%d", c.runnerID), body)
|
||||
if err != nil {
|
||||
log.Printf("Heartbeat failed: %v", err)
|
||||
continue
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessTasks polls for tasks and processes them
|
||||
func (c *Client) ProcessTasks() {
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
tasks, err := c.getTasks()
|
||||
if err != nil {
|
||||
log.Printf("Failed to get tasks: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, taskData := range tasks {
|
||||
taskMap, ok := taskData["task"].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
jobName, _ := taskData["job_name"].(string)
|
||||
outputFormat, _ := taskData["output_format"].(string)
|
||||
inputFilesRaw, _ := taskData["input_files"].([]interface{})
|
||||
|
||||
if len(inputFilesRaw) == 0 {
|
||||
log.Printf("No input files for task %v", taskMap["id"])
|
||||
continue
|
||||
}
|
||||
|
||||
// Process the task
|
||||
if err := c.processTask(taskMap, jobName, outputFormat, inputFilesRaw); err != nil {
|
||||
taskID, _ := taskMap["id"].(float64)
|
||||
log.Printf("Failed to process task %v: %v", taskID, err)
|
||||
c.completeTask(int64(taskID), "", false, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getTasks fetches tasks from the manager
|
||||
func (c *Client) getTasks() ([]map[string]interface{}, error) {
|
||||
path := fmt.Sprintf("/api/runner/tasks?runner_id=%d", c.runnerID)
|
||||
resp, err := c.doSignedRequest("GET", path, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("failed to get tasks: %s", string(body))
|
||||
}
|
||||
|
||||
var tasks []map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tasks); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tasks, nil
|
||||
}
|
||||
|
||||
// processTask processes a single task
|
||||
func (c *Client) processTask(task map[string]interface{}, jobName, outputFormat string, inputFiles []interface{}) error {
|
||||
taskID := int64(task["id"].(float64))
|
||||
jobID := int64(task["job_id"].(float64))
|
||||
frameStart := int(task["frame_start"].(float64))
|
||||
frameEnd := int(task["frame_end"].(float64))
|
||||
|
||||
log.Printf("Processing task %d: job %d, frames %d-%d, format: %s", taskID, jobID, frameStart, frameEnd, outputFormat)
|
||||
|
||||
// Create work directory
|
||||
workDir := filepath.Join(os.TempDir(), fmt.Sprintf("fuego-task-%d", taskID))
|
||||
if err := os.MkdirAll(workDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create work directory: %w", err)
|
||||
}
|
||||
defer os.RemoveAll(workDir)
|
||||
|
||||
// Download input files
|
||||
blendFile := ""
|
||||
for _, filePath := range inputFiles {
|
||||
filePathStr := filePath.(string)
|
||||
if err := c.downloadFile(filePathStr, workDir); err != nil {
|
||||
return fmt.Errorf("failed to download file %s: %w", filePathStr, err)
|
||||
}
|
||||
if filepath.Ext(filePathStr) == ".blend" {
|
||||
blendFile = filepath.Join(workDir, filepath.Base(filePathStr))
|
||||
}
|
||||
}
|
||||
|
||||
if blendFile == "" {
|
||||
return fmt.Errorf("no .blend file found in input files")
|
||||
}
|
||||
|
||||
// Render frames
|
||||
outputDir := filepath.Join(workDir, "output")
|
||||
if err := os.MkdirAll(outputDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create output directory: %w", err)
|
||||
}
|
||||
|
||||
// For MP4, render as PNG first, then combine into video
|
||||
renderFormat := outputFormat
|
||||
if outputFormat == "MP4" {
|
||||
renderFormat = "PNG"
|
||||
}
|
||||
|
||||
outputPattern := filepath.Join(outputDir, fmt.Sprintf("frame_%%04d.%s", strings.ToLower(renderFormat)))
|
||||
|
||||
// Execute Blender
|
||||
cmd := exec.Command("blender", "-b", blendFile, "-o", outputPattern, "-f", fmt.Sprintf("%d", frameStart))
|
||||
cmd.Dir = workDir
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("blender failed: %w\nOutput: %s", err, string(output))
|
||||
}
|
||||
|
||||
// Find rendered output file
|
||||
outputFile := filepath.Join(outputDir, fmt.Sprintf("frame_%04d.%s", frameStart, strings.ToLower(renderFormat)))
|
||||
if _, err := os.Stat(outputFile); os.IsNotExist(err) {
|
||||
return fmt.Errorf("output file not found: %s", outputFile)
|
||||
}
|
||||
|
||||
// Upload frame file
|
||||
outputPath, err := c.uploadFile(jobID, outputFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upload output: %w", err)
|
||||
}
|
||||
|
||||
// Mark task as complete
|
||||
if err := c.completeTask(taskID, outputPath, true, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// For MP4 format, check if all frames are done and generate video
|
||||
if outputFormat == "MP4" {
|
||||
if err := c.checkAndGenerateMP4(jobID); err != nil {
|
||||
log.Printf("Failed to generate MP4 for job %d: %v", jobID, err)
|
||||
// Don't fail the task if video generation fails - frames are already uploaded
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkAndGenerateMP4 checks if all frames are complete and generates MP4 if so
|
||||
func (c *Client) checkAndGenerateMP4(jobID int64) error {
|
||||
// Check job status
|
||||
job, err := c.getJobStatus(jobID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get job status: %w", err)
|
||||
}
|
||||
|
||||
if job["status"] != "completed" {
|
||||
log.Printf("Job %d not yet complete (%v), skipping MP4 generation", jobID, job["status"])
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get all output files for this job
|
||||
files, err := c.getJobFiles(jobID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get job files: %w", err)
|
||||
}
|
||||
|
||||
// Find all PNG frame files
|
||||
var pngFiles []map[string]interface{}
|
||||
for _, file := range files {
|
||||
fileType, _ := file["file_type"].(string)
|
||||
fileName, _ := file["file_name"].(string)
|
||||
if fileType == "output" && strings.HasSuffix(fileName, ".png") {
|
||||
pngFiles = append(pngFiles, file)
|
||||
}
|
||||
}
|
||||
|
||||
if len(pngFiles) == 0 {
|
||||
return fmt.Errorf("no PNG frame files found for MP4 generation")
|
||||
}
|
||||
|
||||
log.Printf("Generating MP4 for job %d from %d PNG frames", jobID, len(pngFiles))
|
||||
|
||||
// Create work directory for video generation
|
||||
workDir := filepath.Join(os.TempDir(), fmt.Sprintf("fuego-video-%d", jobID))
|
||||
if err := os.MkdirAll(workDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create work directory: %w", err)
|
||||
}
|
||||
defer os.RemoveAll(workDir)
|
||||
|
||||
// Download all PNG frames
|
||||
var frameFiles []string
|
||||
for _, file := range pngFiles {
|
||||
fileName, _ := file["file_name"].(string)
|
||||
framePath := filepath.Join(workDir, fileName)
|
||||
if err := c.downloadFrameFile(jobID, fileName, framePath); err != nil {
|
||||
log.Printf("Failed to download frame %s: %v", fileName, err)
|
||||
continue
|
||||
}
|
||||
frameFiles = append(frameFiles, framePath)
|
||||
}
|
||||
|
||||
if len(frameFiles) == 0 {
|
||||
return fmt.Errorf("failed to download any frame files")
|
||||
}
|
||||
|
||||
// Sort frame files by name to ensure correct order
|
||||
sort.Strings(frameFiles)
|
||||
|
||||
// Generate MP4 using ffmpeg
|
||||
outputMP4 := filepath.Join(workDir, fmt.Sprintf("output_%d.mp4", jobID))
|
||||
|
||||
// Use ffmpeg to combine frames into MP4
|
||||
// Method 1: Using image sequence input (more reliable)
|
||||
firstFrame := frameFiles[0]
|
||||
// Extract frame number pattern (e.g., frame_0001.png -> frame_%04d.png)
|
||||
baseName := filepath.Base(firstFrame)
|
||||
pattern := strings.Replace(baseName, fmt.Sprintf("%04d", extractFrameNumber(baseName)), "%04d", 1)
|
||||
patternPath := filepath.Join(workDir, pattern)
|
||||
|
||||
// Run ffmpeg to combine frames into MP4 at 24 fps
|
||||
cmd := exec.Command("ffmpeg", "-y", "-framerate", "24", "-i", patternPath,
|
||||
"-c:v", "libx264", "-pix_fmt", "yuv420p", "-r", "24", outputMP4)
|
||||
cmd.Dir = workDir
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
// Try alternative method with concat demuxer
|
||||
log.Printf("First ffmpeg attempt failed, trying concat method: %s", string(output))
|
||||
return c.generateMP4WithConcat(frameFiles, outputMP4, workDir)
|
||||
}
|
||||
|
||||
// Check if MP4 was created
|
||||
if _, err := os.Stat(outputMP4); os.IsNotExist(err) {
|
||||
return fmt.Errorf("MP4 file not created: %s", outputMP4)
|
||||
}
|
||||
|
||||
// Upload MP4 file
|
||||
mp4Path, err := c.uploadFile(jobID, outputMP4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upload MP4: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Successfully generated and uploaded MP4 for job %d: %s", jobID, mp4Path)
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateMP4WithConcat uses ffmpeg concat demuxer as fallback
|
||||
func (c *Client) generateMP4WithConcat(frameFiles []string, outputMP4, workDir string) error {
|
||||
// Create file list for ffmpeg concat demuxer
|
||||
listFile := filepath.Join(workDir, "frames.txt")
|
||||
listFileHandle, err := os.Create(listFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create list file: %w", err)
|
||||
}
|
||||
|
||||
for _, frameFile := range frameFiles {
|
||||
absPath, _ := filepath.Abs(frameFile)
|
||||
fmt.Fprintf(listFileHandle, "file '%s'\n", absPath)
|
||||
}
|
||||
listFileHandle.Close()
|
||||
|
||||
// Run ffmpeg with concat demuxer
|
||||
cmd := exec.Command("ffmpeg", "-f", "concat", "-safe", "0", "-i", listFile,
|
||||
"-c:v", "libx264", "-pix_fmt", "yuv420p", "-r", "24", "-y", outputMP4)
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("ffmpeg concat failed: %w\nOutput: %s", err, string(output))
|
||||
}
|
||||
|
||||
if _, err := os.Stat(outputMP4); os.IsNotExist(err) {
|
||||
return fmt.Errorf("MP4 file not created: %s", outputMP4)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractFrameNumber extracts frame number from filename like "frame_0001.png"
|
||||
func extractFrameNumber(filename string) int {
|
||||
parts := strings.Split(filepath.Base(filename), "_")
|
||||
if len(parts) < 2 {
|
||||
return 0
|
||||
}
|
||||
framePart := strings.Split(parts[1], ".")[0]
|
||||
var frameNum int
|
||||
fmt.Sscanf(framePart, "%d", &frameNum)
|
||||
return frameNum
|
||||
}
|
||||
|
||||
// getJobStatus gets job status from manager
|
||||
func (c *Client) getJobStatus(jobID int64) (map[string]interface{}, error) {
|
||||
path := fmt.Sprintf("/api/runner/jobs/%d/status?runner_id=%d", jobID, c.runnerID)
|
||||
resp, err := c.doSignedRequest("GET", path, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("failed to get job status: %s", string(body))
|
||||
}
|
||||
|
||||
var job map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&job); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return job, nil
|
||||
}
|
||||
|
||||
// getJobFiles gets job files from manager
|
||||
func (c *Client) getJobFiles(jobID int64) ([]map[string]interface{}, error) {
|
||||
path := fmt.Sprintf("/api/runner/jobs/%d/files?runner_id=%d", jobID, c.runnerID)
|
||||
resp, err := c.doSignedRequest("GET", path, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("failed to get job files: %s", string(body))
|
||||
}
|
||||
|
||||
var files []map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&files); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return files, nil
|
||||
}
|
||||
|
||||
// downloadFrameFile downloads a frame file for MP4 generation
|
||||
func (c *Client) downloadFrameFile(jobID int64, fileName, destPath string) error {
|
||||
path := fmt.Sprintf("/api/runner/files/%d/%s?runner_id=%d", jobID, fileName, c.runnerID)
|
||||
resp, err := c.doSignedRequest("GET", path, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("download failed: %s", string(body))
|
||||
}
|
||||
|
||||
file, err := os.Create(destPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
_, err = io.Copy(file, resp.Body)
|
||||
return err
|
||||
}
|
||||
|
||||
// downloadFile downloads a file from the manager
|
||||
func (c *Client) downloadFile(filePath, destDir string) error {
|
||||
// Extract job ID and filename from path
|
||||
// Path format: storage/jobs/{jobID}/{filename}
|
||||
parts := filepath.SplitList(filePath)
|
||||
if len(parts) < 3 {
|
||||
return fmt.Errorf("invalid file path format: %s", filePath)
|
||||
}
|
||||
|
||||
// Find job ID in path (look for "jobs" directory)
|
||||
jobID := ""
|
||||
fileName := filepath.Base(filePath)
|
||||
for i, part := range parts {
|
||||
if part == "jobs" && i+1 < len(parts) {
|
||||
jobID = parts[i+1]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if jobID == "" {
|
||||
return fmt.Errorf("could not extract job ID from path: %s", filePath)
|
||||
}
|
||||
|
||||
// Download via HTTP
|
||||
path := fmt.Sprintf("/api/runner/files/%s/%s?runner_id=%d", jobID, fileName, c.runnerID)
|
||||
resp, err := c.doSignedRequest("GET", path, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download file: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("download failed: %s", string(body))
|
||||
}
|
||||
|
||||
destPath := filepath.Join(destDir, fileName)
|
||||
file, err := os.Create(destPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create destination file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
_, err = io.Copy(file, resp.Body)
|
||||
return err
|
||||
}
|
||||
|
||||
// uploadFile uploads a file to the manager
|
||||
func (c *Client) uploadFile(jobID int64, filePath string) (string, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Create multipart form
|
||||
var buf bytes.Buffer
|
||||
formWriter := multipart.NewWriter(&buf)
|
||||
|
||||
part, err := formWriter.CreateFormFile("file", filepath.Base(filePath))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create form file: %w", err)
|
||||
}
|
||||
|
||||
_, err = io.Copy(part, file)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to copy file data: %w", err)
|
||||
}
|
||||
|
||||
formWriter.Close()
|
||||
|
||||
// Upload file with signature
|
||||
path := fmt.Sprintf("/api/runner/files/%d/upload?runner_id=%d", jobID, c.runnerID)
|
||||
timestamp := time.Now()
|
||||
message := fmt.Sprintf("POST\n%s\n%s\n%d", path, buf.String(), timestamp.Unix())
|
||||
h := hmac.New(sha256.New, []byte(c.runnerSecret))
|
||||
h.Write([]byte(message))
|
||||
signature := hex.EncodeToString(h.Sum(nil))
|
||||
|
||||
url := fmt.Sprintf("%s%s", c.managerURL, path)
|
||||
req, err := http.NewRequest("POST", url, &buf)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", formWriter.FormDataContentType())
|
||||
req.Header.Set("X-Runner-Signature", signature)
|
||||
req.Header.Set("X-Runner-Timestamp", fmt.Sprintf("%d", timestamp.Unix()))
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to upload file: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("upload failed: %s", string(body))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
FilePath string `json:"file_path"`
|
||||
FileName string `json:"file_name"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
return result.FilePath, nil
|
||||
}
|
||||
|
||||
// completeTask marks a task as complete
|
||||
func (c *Client) completeTask(taskID int64, outputPath string, success bool, errorMsg string) error {
|
||||
req := map[string]interface{}{
|
||||
"output_path": outputPath,
|
||||
"success": success,
|
||||
}
|
||||
if !success {
|
||||
req["error"] = errorMsg
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(req)
|
||||
path := fmt.Sprintf("/api/runner/tasks/%d/complete?runner_id=%d", taskID, c.runnerID)
|
||||
resp, err := c.doSignedRequest("POST", path, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("failed to complete task: %s", string(body))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
137
internal/storage/storage.go
Normal file
137
internal/storage/storage.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// Storage handles file storage operations
|
||||
type Storage struct {
|
||||
basePath string
|
||||
}
|
||||
|
||||
// NewStorage creates a new storage instance
|
||||
func NewStorage(basePath string) (*Storage, error) {
|
||||
s := &Storage{basePath: basePath}
|
||||
if err := s.init(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// init creates necessary directories
|
||||
func (s *Storage) init() error {
|
||||
dirs := []string{
|
||||
s.basePath,
|
||||
s.uploadsPath(),
|
||||
s.outputsPath(),
|
||||
}
|
||||
|
||||
for _, dir := range dirs {
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create directory %s: %w", dir, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// uploadsPath returns the path for uploads
|
||||
func (s *Storage) uploadsPath() string {
|
||||
return filepath.Join(s.basePath, "uploads")
|
||||
}
|
||||
|
||||
// outputsPath returns the path for outputs
|
||||
func (s *Storage) outputsPath() string {
|
||||
return filepath.Join(s.basePath, "outputs")
|
||||
}
|
||||
|
||||
// JobPath returns the path for a specific job's files
|
||||
func (s *Storage) JobPath(jobID int64) string {
|
||||
return filepath.Join(s.basePath, "jobs", fmt.Sprintf("%d", jobID))
|
||||
}
|
||||
|
||||
// SaveUpload saves an uploaded file
|
||||
func (s *Storage) SaveUpload(jobID int64, filename string, reader io.Reader) (string, error) {
|
||||
jobPath := s.JobPath(jobID)
|
||||
if err := os.MkdirAll(jobPath, 0755); err != nil {
|
||||
return "", fmt.Errorf("failed to create job directory: %w", err)
|
||||
}
|
||||
|
||||
filePath := filepath.Join(jobPath, filename)
|
||||
file, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
if _, err := io.Copy(file, reader); err != nil {
|
||||
return "", fmt.Errorf("failed to write file: %w", err)
|
||||
}
|
||||
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
// SaveOutput saves an output file
|
||||
func (s *Storage) SaveOutput(jobID int64, filename string, reader io.Reader) (string, error) {
|
||||
outputPath := filepath.Join(s.outputsPath(), fmt.Sprintf("%d", jobID))
|
||||
if err := os.MkdirAll(outputPath, 0755); err != nil {
|
||||
return "", fmt.Errorf("failed to create output directory: %w", err)
|
||||
}
|
||||
|
||||
filePath := filepath.Join(outputPath, filename)
|
||||
file, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
if _, err := io.Copy(file, reader); err != nil {
|
||||
return "", fmt.Errorf("failed to write file: %w", err)
|
||||
}
|
||||
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
// GetFile returns a file reader for the given path
|
||||
func (s *Storage) GetFile(filePath string) (*os.File, error) {
|
||||
return os.Open(filePath)
|
||||
}
|
||||
|
||||
// DeleteFile deletes a file
|
||||
func (s *Storage) DeleteFile(filePath string) error {
|
||||
return os.Remove(filePath)
|
||||
}
|
||||
|
||||
// DeleteJobFiles deletes all files for a job
|
||||
func (s *Storage) DeleteJobFiles(jobID int64) error {
|
||||
jobPath := s.JobPath(jobID)
|
||||
if err := os.RemoveAll(jobPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to delete job files: %w", err)
|
||||
}
|
||||
|
||||
outputPath := filepath.Join(s.outputsPath(), fmt.Sprintf("%d", jobID))
|
||||
if err := os.RemoveAll(outputPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to delete output files: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FileExists checks if a file exists
|
||||
func (s *Storage) FileExists(filePath string) bool {
|
||||
_, err := os.Stat(filePath)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// GetFileSize returns the size of a file
|
||||
func (s *Storage) GetFileSize(filePath string) (int64, error) {
|
||||
info, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return info.Size(), nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user