Files
jiggablend/pkg/executils/exec.go
2025-11-27 00:46:48 -06:00

367 lines
8.8 KiB
Go

package executils
import (
	"bufio"
	"errors"
	"fmt"
	"io"
	"os"
	"os/exec"
	"sync"
	"time"

	"jiggablend/pkg/types"
)
// DefaultTracker is the package-wide process tracker used by the Run,
// RunStreaming, Kill, KillAll, Track, Untrack and GetTrackedCount helpers.
// Register processes here that should be terminated on shutdown.
var DefaultTracker = NewProcessTracker()

// gracefulKillDelay is how long tracked processes are given to react to
// os.Interrupt before they are forcibly killed.
const gracefulKillDelay = 100 * time.Millisecond

// ProcessTracker records running commands keyed by task ID so they can be
// terminated later. Its state lives in a sync.Map, so methods may be called
// from multiple goroutines.
type ProcessTracker struct {
	processes sync.Map // map[int64]*exec.Cmd - tracks running processes by task ID
}

// NewProcessTracker returns an empty ProcessTracker.
func NewProcessTracker() *ProcessTracker {
	return &ProcessTracker{}
}

// Track registers cmd under taskID, replacing any command previously
// registered under the same ID.
func (pt *ProcessTracker) Track(taskID int64, cmd *exec.Cmd) {
	pt.processes.Store(taskID, cmd)
}

// Untrack removes whatever command is registered under taskID. It is a no-op
// when nothing is registered.
func (pt *ProcessTracker) Untrack(taskID int64) {
	pt.processes.Delete(taskID)
}

// Get returns the command registered under taskID and whether an entry
// exists.
func (pt *ProcessTracker) Get(taskID int64) (*exec.Cmd, bool) {
	if val, ok := pt.processes.Load(taskID); ok {
		return val.(*exec.Cmd), true
	}
	return nil, false
}

// Kill terminates the process registered under taskID and untracks it.
//
// It first sends os.Interrupt and, if that was delivered, waits
// gracefulKillDelay. It then force-kills the process. If the interrupt could
// not be delivered (for example on Windows, where os.Interrupt is not
// supported), it force-kills immediately.
//
// Kill returns false, leaving the tracker unchanged, when no entry exists, the
// entry is nil, or the command was never started. Otherwise it returns true.
func (pt *ProcessTracker) Kill(taskID int64) bool {
	cmd, ok := pt.Get(taskID)
	if !ok || cmd == nil || cmd.Process == nil {
		return false
	}
	if err := cmd.Process.Signal(os.Interrupt); err == nil {
		time.Sleep(gracefulKillDelay)
	}
	// Best effort: Kill fails harmlessly if the process already exited.
	_ = cmd.Process.Kill()
	pt.Untrack(taskID)
	return true
}

// KillAll terminates every tracked process and clears the tracker.
//
// All started processes are sent os.Interrupt first. They then share a
// single gracefulKillDelay grace period before being force-killed, so the
// total delay does not grow with the number of processes. Entries that are
// nil or were never started are simply removed.
//
// KillAll returns the number of started processes it terminated.
func (pt *ProcessTracker) KillAll() int {
	var targets []*exec.Cmd
	interrupted := false
	pt.processes.Range(func(key, value interface{}) bool {
		// Remove the entry up front, so a command re-tracked under the same
		// ID during the grace period is not dropped afterwards.
		pt.processes.Delete(key)
		cmd, _ := value.(*exec.Cmd)
		if cmd == nil || cmd.Process == nil {
			return true
		}
		if cmd.Process.Signal(os.Interrupt) == nil {
			interrupted = true
		}
		targets = append(targets, cmd)
		return true
	})

	if interrupted {
		time.Sleep(gracefulKillDelay)
	}
	for _, cmd := range targets {
		// Best effort: Kill fails harmlessly if the process already exited.
		_ = cmd.Process.Kill()
	}
	return len(targets)
}

// Count returns the number of currently tracked entries.
func (pt *ProcessTracker) Count() int {
	count := 0
	pt.processes.Range(func(key, value interface{}) bool {
		count++
		return true
	})
	return count
}
// CommandResult is what RunCommand collects from a finished command: the
// complete standard output and standard error text, plus the exit code
// (0 on success, -1 when no exit status could be determined).
type CommandResult struct {
	Stdout, Stderr string // full captured output streams
	ExitCode       int    // process exit status
}
// RunCommand runs cmdPath with args in dir, waits for it to finish and
// returns its complete stdout and stderr.
//
// When env is non-nil it becomes the child's environment; when nil the child
// inherits the current process environment. If tracker is non-nil, the
// running process is registered under taskID so it can be killed through the
// tracker. It is untracked again when RunCommand returns.
//
// Return values:
//   - Pipe-creation, start or read failures return a nil result and a
//     wrapped error.
//   - Any other error from Wait (including a non-zero exit) returns the
//     populated result together with that error. ExitCode is taken from the
//     *exec.ExitError, or is -1 when the error carries no exit status.
//
// Output is buffered in memory in full. This suits commands with modest
// output, such as metadata extraction.
func RunCommand(
	cmdPath string,
	args []string,
	dir string,
	env []string,
	taskID int64,
	tracker *ProcessTracker,
) (*CommandResult, error) {
	cmd := exec.Command(cmdPath, args...)
	cmd.Dir = dir
	if env != nil {
		cmd.Env = env
	}

	stdoutPipe, err := cmd.StdoutPipe()
	if err != nil {
		return nil, fmt.Errorf("failed to create stdout pipe: %w", err)
	}
	stderrPipe, err := cmd.StderrPipe()
	if err != nil {
		return nil, fmt.Errorf("failed to create stderr pipe: %w", err)
	}
	if err := cmd.Start(); err != nil {
		return nil, fmt.Errorf("failed to start command: %w", err)
	}

	if tracker != nil {
		tracker.Track(taskID, cmd)
		defer tracker.Untrack(taskID)
	}

	// Read both streams concurrently, so that one filling its pipe cannot
	// stall the child while the other is being read.
	var stdoutBuf, stderrBuf []byte
	var stdoutErr, stderrErr error
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		stdoutBuf, stdoutErr = readAll(stdoutPipe)
	}()
	go func() {
		defer wg.Done()
		stderrBuf, stderrErr = readAll(stderrPipe)
	}()

	// Wait closes the pipes once the process exits, so every read must finish
	// before Wait is called. Calling it earlier can lose trailing output or
	// make reads fail with "file already closed". Wait is still called when a
	// read failed, so the child is always reaped.
	wg.Wait()
	waitErr := cmd.Wait()

	if stdoutErr != nil {
		return nil, fmt.Errorf("failed to read stdout: %w", stdoutErr)
	}
	if stderrErr != nil {
		return nil, fmt.Errorf("failed to read stderr: %w", stderrErr)
	}

	result := &CommandResult{
		Stdout: string(stdoutBuf),
		Stderr: string(stderrBuf),
	}
	if waitErr != nil {
		result.ExitCode = -1
		var exitErr *exec.ExitError
		if errors.As(waitErr, &exitErr) {
			result.ExitCode = exitErr.ExitCode()
		}
		return result, waitErr
	}
	return result, nil
}
// readAll reads r until end of stream and returns everything read.
//
// Reaching io.EOF is not an error. Any other read error is returned together
// with the bytes read before it occurred. End of stream is recognised by the
// io.EOF sentinel (via io.ReadAll), not by comparing error message text.
func readAll(r io.Reader) ([]byte, error) {
	return io.ReadAll(r)
}
type (
	// LogSender delivers a single log message for a task at the given level.
	// stepName is forwarded unchanged from the caller of the run helpers.
	LogSender func(taskID int, level types.LogLevel, message string, stepName string)

	// LineFilter inspects one line of command output. It reports whether the
	// line should be dropped (shouldFilter) and, when it is kept, the level it
	// should be logged at.
	LineFilter func(line string) (shouldFilter bool, level types.LogLevel)
)
// maxStreamLineSize is the longest single output line RunCommandWithStreaming
// accepts. bufio.Scanner's default limit is only 64 KiB.
const maxStreamLineSize = 1024 * 1024

// RunCommandWithStreaming runs cmdPath with args in dir and forwards its
// output line by line to logSender while it runs.
//
// Each non-empty stdout line goes to stdoutFilter and each non-empty stderr
// line goes to stderrFilter. Lines the filter keeps are sent to logSender at
// the level the filter returns. Pipe-creation and start failures, output read
// failures, and an unsuccessful exit are also sent to logSender at
// types.LogLevelError. Failures other than output reads are then returned as
// errors. An exit code of 137 is reported with oomMessage instead of the
// generic "command failed" text.
//
// env is assigned to the command as-is; nil inherits the current process
// environment. If tracker is non-nil, the process is registered under taskID
// while it runs. logSender, stdoutFilter and stderrFilter must be non-nil.
func RunCommandWithStreaming(
	cmdPath string,
	args []string,
	dir string,
	env []string,
	taskID int,
	stepName string,
	logSender LogSender,
	stdoutFilter LineFilter,
	stderrFilter LineFilter,
	oomMessage string,
	tracker *ProcessTracker,
) error {
	cmd := exec.Command(cmdPath, args...)
	cmd.Dir = dir
	cmd.Env = env

	// fail logs err at error level and returns it, so every failure path is
	// both reported and returned.
	fail := func(err error) error {
		logSender(taskID, types.LogLevelError, err.Error(), stepName)
		return err
	}

	stdoutPipe, err := cmd.StdoutPipe()
	if err != nil {
		return fail(fmt.Errorf("failed to create stdout pipe: %w", err))
	}
	stderrPipe, err := cmd.StderrPipe()
	if err != nil {
		return fail(fmt.Errorf("failed to create stderr pipe: %w", err))
	}
	if err := cmd.Start(); err != nil {
		return fail(fmt.Errorf("failed to start command: %w", err))
	}

	if tracker != nil {
		tracker.Track(int64(taskID), cmd)
		defer tracker.Untrack(int64(taskID))
	}

	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		streamLines(stdoutPipe, stdoutFilter, logSender, taskID, stepName)
	}()
	go func() {
		defer wg.Done()
		streamLines(stderrPipe, stderrFilter, logSender, taskID, stepName)
	}()

	// Wait closes the pipes once the process exits, so both readers must
	// finish first. Otherwise trailing output can be lost.
	wg.Wait()
	if err := cmd.Wait(); err != nil {
		// NOTE(review): ExitError.ExitCode() is -1 for a process terminated by
		// a signal. So 137 is only seen when the child is a wrapper (e.g. a
		// shell) that itself exits with 128+SIGKILL. A directly-run process
		// killed by SIGKILL falls through to the generic message. Confirm this
		// matches the intended OOM detection.
		var exitErr *exec.ExitError
		if errors.As(err, &exitErr) && exitErr.ExitCode() == 137 {
			return fail(errors.New(oomMessage))
		}
		return fail(fmt.Errorf("command failed: %w", err))
	}
	return nil
}

// streamLines forwards each non-empty line read from r to logSender, unless
// filter drops it. If reading fails (including a line longer than
// maxStreamLineSize), the error is logged and the rest of r is discarded.
// Draining means the child process never blocks writing to a pipe nobody
// reads.
func streamLines(r io.Reader, filter LineFilter, logSender LogSender, taskID int, stepName string) {
	scanner := bufio.NewScanner(r)
	scanner.Buffer(make([]byte, 0, 64*1024), maxStreamLineSize)
	for scanner.Scan() {
		line := scanner.Text()
		if line == "" {
			continue
		}
		if shouldFilter, level := filter(line); !shouldFilter {
			logSender(taskID, level, line, stepName)
		}
	}
	if err := scanner.Err(); err != nil {
		logSender(taskID, types.LogLevelError, fmt.Sprintf("failed to read command output: %v", err), stepName)
		_, _ = io.Copy(io.Discard, r)
	}
}
// ----------------------------------------------------------------------------
// Package-level convenience wrappers bound to DefaultTracker
// ----------------------------------------------------------------------------

// Run forwards to RunCommand with DefaultTracker as the tracker, so the
// process is registered under taskID while it runs. RunCommand's result and
// error are returned unchanged.
func Run(
	cmdPath string,
	args []string,
	dir string,
	env []string,
	taskID int64,
) (*CommandResult, error) {
	tracker := DefaultTracker
	return RunCommand(cmdPath, args, dir, env, taskID, tracker)
}

// RunStreaming forwards to RunCommandWithStreaming with DefaultTracker as
// the tracker, returning its error unchanged.
func RunStreaming(
	cmdPath string, args []string, dir string, env []string,
	taskID int, stepName string,
	logSender LogSender, stdoutFilter, stderrFilter LineFilter,
	oomMessage string,
) error {
	return RunCommandWithStreaming(
		cmdPath, args, dir, env,
		taskID, stepName,
		logSender, stdoutFilter, stderrFilter,
		oomMessage,
		DefaultTracker,
	)
}

// Track registers cmd under taskID in DefaultTracker.
func Track(taskID int64, cmd *exec.Cmd) {
	DefaultTracker.Track(taskID, cmd)
}

// Untrack removes taskID from DefaultTracker.
func Untrack(taskID int64) {
	DefaultTracker.Untrack(taskID)
}

// Kill asks DefaultTracker to terminate the process registered under taskID.
// It reports whether the tracker found a process to kill.
func Kill(taskID int64) bool {
	return DefaultTracker.Kill(taskID)
}

// KillAll asks DefaultTracker to terminate every tracked process. It returns
// the count the tracker reports.
func KillAll() int {
	return DefaultTracker.KillAll()
}

// GetTrackedCount reports how many entries DefaultTracker currently holds.
func GetTrackedCount() int {
	return DefaultTracker.Count()
}