package executils import ( "bufio" "errors" "fmt" "os" "os/exec" "sync" "time" "jiggablend/pkg/types" ) // DefaultTracker is the global default process tracker // Use this for processes that should be tracked globally and killed on shutdown var DefaultTracker = NewProcessTracker() // ProcessTracker tracks running processes for cleanup type ProcessTracker struct { processes sync.Map // map[int64]*exec.Cmd - tracks running processes by task ID } // NewProcessTracker creates a new process tracker func NewProcessTracker() *ProcessTracker { return &ProcessTracker{} } // Track registers a process for tracking func (pt *ProcessTracker) Track(taskID int64, cmd *exec.Cmd) { pt.processes.Store(taskID, cmd) } // Untrack removes a process from tracking func (pt *ProcessTracker) Untrack(taskID int64) { pt.processes.Delete(taskID) } // Get returns the command for a task ID if it exists func (pt *ProcessTracker) Get(taskID int64) (*exec.Cmd, bool) { if val, ok := pt.processes.Load(taskID); ok { return val.(*exec.Cmd), true } return nil, false } // Kill kills a specific process by task ID // Returns true if the process was found and killed func (pt *ProcessTracker) Kill(taskID int64) bool { cmd, ok := pt.Get(taskID) if !ok || cmd.Process == nil { return false } // Try graceful kill first (SIGINT) if err := cmd.Process.Signal(os.Interrupt); err != nil { // If SIGINT fails, try SIGKILL cmd.Process.Kill() } else { // Give it a moment to clean up gracefully time.Sleep(100 * time.Millisecond) // Force kill if still running cmd.Process.Kill() } pt.Untrack(taskID) return true } // KillAll kills all tracked processes // Returns the number of processes killed func (pt *ProcessTracker) KillAll() int { var killedCount int pt.processes.Range(func(key, value interface{}) bool { taskID := key.(int64) cmd := value.(*exec.Cmd) if cmd.Process != nil { // Try graceful kill first (SIGINT) if err := cmd.Process.Signal(os.Interrupt); err == nil { // Give it a moment to clean up time.Sleep(100 * time.Millisecond) } // Force kill cmd.Process.Kill() killedCount++ } pt.processes.Delete(taskID) return true }) return killedCount } // Count returns the number of tracked processes func (pt *ProcessTracker) Count() int { count := 0 pt.processes.Range(func(key, value interface{}) bool { count++ return true }) return count } // CommandResult holds the output from a command execution type CommandResult struct { Stdout string Stderr string ExitCode int } // RunCommand executes a command and returns the output // If tracker is provided, the process will be registered for tracking // This is useful for commands where you need to capture output (like metadata extraction) func RunCommand( cmdPath string, args []string, dir string, env []string, taskID int64, tracker *ProcessTracker, ) (*CommandResult, error) { cmd := exec.Command(cmdPath, args...) cmd.Dir = dir if env != nil { cmd.Env = env } stdoutPipe, err := cmd.StdoutPipe() if err != nil { return nil, fmt.Errorf("failed to create stdout pipe: %w", err) } stderrPipe, err := cmd.StderrPipe() if err != nil { return nil, fmt.Errorf("failed to create stderr pipe: %w", err) } if err := cmd.Start(); err != nil { return nil, fmt.Errorf("failed to start command: %w", err) } // Track the process if tracker is provided if tracker != nil { tracker.Track(taskID, cmd) defer tracker.Untrack(taskID) } // Collect stdout var stdoutBuf, stderrBuf []byte var stdoutErr, stderrErr error var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() stdoutBuf, stdoutErr = readAll(stdoutPipe) }() go func() { defer wg.Done() stderrBuf, stderrErr = readAll(stderrPipe) }() waitErr := cmd.Wait() wg.Wait() // Check for read errors if stdoutErr != nil { return nil, fmt.Errorf("failed to read stdout: %w", stdoutErr) } if stderrErr != nil { return nil, fmt.Errorf("failed to read stderr: %w", stderrErr) } result := &CommandResult{ Stdout: string(stdoutBuf), Stderr: string(stderrBuf), } if waitErr != nil { if exitErr, ok := waitErr.(*exec.ExitError); ok { result.ExitCode = exitErr.ExitCode() } else { result.ExitCode = -1 } return result, waitErr } result.ExitCode = 0 return result, nil } // readAll reads all data from a reader func readAll(r interface{ Read([]byte) (int, error) }) ([]byte, error) { var buf []byte tmp := make([]byte, 4096) for { n, err := r.Read(tmp) if n > 0 { buf = append(buf, tmp[:n]...) } if err != nil { if err.Error() == "EOF" { break } return buf, err } } return buf, nil } // LogSender is a function type for sending logs type LogSender func(taskID int, level types.LogLevel, message string, stepName string) // LineFilter is a function that processes a line and returns whether to filter it out and the log level type LineFilter func(line string) (shouldFilter bool, level types.LogLevel) // RunCommandWithStreaming executes a command with streaming output and OOM detection // If tracker is provided, the process will be registered for tracking func RunCommandWithStreaming( cmdPath string, args []string, dir string, env []string, taskID int, stepName string, logSender LogSender, stdoutFilter LineFilter, stderrFilter LineFilter, oomMessage string, tracker *ProcessTracker, ) error { cmd := exec.Command(cmdPath, args...) cmd.Dir = dir cmd.Env = env stdoutPipe, err := cmd.StdoutPipe() if err != nil { errMsg := fmt.Sprintf("failed to create stdout pipe: %v", err) logSender(taskID, types.LogLevelError, errMsg, stepName) return errors.New(errMsg) } stderrPipe, err := cmd.StderrPipe() if err != nil { errMsg := fmt.Sprintf("failed to create stderr pipe: %v", err) logSender(taskID, types.LogLevelError, errMsg, stepName) return errors.New(errMsg) } if err := cmd.Start(); err != nil { errMsg := fmt.Sprintf("failed to start command: %v", err) logSender(taskID, types.LogLevelError, errMsg, stepName) return errors.New(errMsg) } // Track the process if tracker is provided if tracker != nil { tracker.Track(int64(taskID), cmd) defer tracker.Untrack(int64(taskID)) } var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() scanner := bufio.NewScanner(stdoutPipe) for scanner.Scan() { line := scanner.Text() if line != "" { shouldFilter, level := stdoutFilter(line) if !shouldFilter { logSender(taskID, level, line, stepName) } } } }() go func() { defer wg.Done() scanner := bufio.NewScanner(stderrPipe) for scanner.Scan() { line := scanner.Text() if line != "" { shouldFilter, level := stderrFilter(line) if !shouldFilter { logSender(taskID, level, line, stepName) } } } }() err = cmd.Wait() wg.Wait() if err != nil { var errMsg string if exitErr, ok := err.(*exec.ExitError); ok { if exitErr.ExitCode() == 137 { errMsg = oomMessage } else { errMsg = fmt.Sprintf("command failed: %v", err) } } else { errMsg = fmt.Sprintf("command failed: %v", err) } logSender(taskID, types.LogLevelError, errMsg, stepName) return errors.New(errMsg) } return nil } // ============================================================================ // Helper functions using DefaultTracker // ============================================================================ // Run executes a command using the default tracker and returns the output // This is a convenience wrapper around RunCommand that uses DefaultTracker func Run(cmdPath string, args []string, dir string, env []string, taskID int64) (*CommandResult, error) { return RunCommand(cmdPath, args, dir, env, taskID, DefaultTracker) } // RunStreaming executes a command with streaming output using the default tracker // This is a convenience wrapper around RunCommandWithStreaming that uses DefaultTracker func RunStreaming( cmdPath string, args []string, dir string, env []string, taskID int, stepName string, logSender LogSender, stdoutFilter LineFilter, stderrFilter LineFilter, oomMessage string, ) error { return RunCommandWithStreaming(cmdPath, args, dir, env, taskID, stepName, logSender, stdoutFilter, stderrFilter, oomMessage, DefaultTracker) } // KillAll kills all processes tracked by the default tracker // Returns the number of processes killed func KillAll() int { return DefaultTracker.KillAll() } // Kill kills a specific process by task ID using the default tracker // Returns true if the process was found and killed func Kill(taskID int64) bool { return DefaultTracker.Kill(taskID) } // Track registers a process with the default tracker func Track(taskID int64, cmd *exec.Cmd) { DefaultTracker.Track(taskID, cmd) } // Untrack removes a process from the default tracker func Untrack(taskID int64) { DefaultTracker.Untrack(taskID) } // GetTrackedCount returns the number of processes tracked by the default tracker func GetTrackedCount() int { return DefaultTracker.Count() }