367 lines
8.8 KiB
Go
367 lines
8.8 KiB
Go
package executils
|
|
|
|
import (
	"bufio"
	"errors"
	"fmt"
	"io"
	"os"
	"os/exec"
	"sync"
	"time"

	"jiggablend/pkg/types"
)
|
|
|
|
// DefaultTracker is the global default process tracker
|
|
// Use this for processes that should be tracked globally and killed on shutdown
|
|
var DefaultTracker = NewProcessTracker()
|
|
|
|
// ProcessTracker records running commands keyed by task ID so they can be
// looked up and killed later, e.g. during cleanup.
//
// The zero value is ready to use. Storage is a sync.Map, so the tracker's
// map operations are safe to call from multiple goroutines.
type ProcessTracker struct {
	// processes maps a task ID (int64) to the *exec.Cmd running for it.
	processes sync.Map
}
|
|
|
|
// NewProcessTracker creates a new process tracker
|
|
func NewProcessTracker() *ProcessTracker {
|
|
return &ProcessTracker{}
|
|
}
|
|
|
|
// Track registers a process for tracking
|
|
func (pt *ProcessTracker) Track(taskID int64, cmd *exec.Cmd) {
|
|
pt.processes.Store(taskID, cmd)
|
|
}
|
|
|
|
// Untrack removes a process from tracking
|
|
func (pt *ProcessTracker) Untrack(taskID int64) {
|
|
pt.processes.Delete(taskID)
|
|
}
|
|
|
|
// Get returns the command for a task ID if it exists
|
|
func (pt *ProcessTracker) Get(taskID int64) (*exec.Cmd, bool) {
|
|
if val, ok := pt.processes.Load(taskID); ok {
|
|
return val.(*exec.Cmd), true
|
|
}
|
|
return nil, false
|
|
}
|
|
|
|
// Kill kills a specific process by task ID
|
|
// Returns true if the process was found and killed
|
|
func (pt *ProcessTracker) Kill(taskID int64) bool {
|
|
cmd, ok := pt.Get(taskID)
|
|
if !ok || cmd.Process == nil {
|
|
return false
|
|
}
|
|
|
|
// Try graceful kill first (SIGINT)
|
|
if err := cmd.Process.Signal(os.Interrupt); err != nil {
|
|
// If SIGINT fails, try SIGKILL
|
|
cmd.Process.Kill()
|
|
} else {
|
|
// Give it a moment to clean up gracefully
|
|
time.Sleep(100 * time.Millisecond)
|
|
// Force kill if still running
|
|
cmd.Process.Kill()
|
|
}
|
|
|
|
pt.Untrack(taskID)
|
|
return true
|
|
}
|
|
|
|
// KillAll kills all tracked processes
|
|
// Returns the number of processes killed
|
|
func (pt *ProcessTracker) KillAll() int {
|
|
var killedCount int
|
|
pt.processes.Range(func(key, value interface{}) bool {
|
|
taskID := key.(int64)
|
|
cmd := value.(*exec.Cmd)
|
|
if cmd.Process != nil {
|
|
// Try graceful kill first (SIGINT)
|
|
if err := cmd.Process.Signal(os.Interrupt); err == nil {
|
|
// Give it a moment to clean up
|
|
time.Sleep(100 * time.Millisecond)
|
|
}
|
|
// Force kill
|
|
cmd.Process.Kill()
|
|
killedCount++
|
|
}
|
|
pt.processes.Delete(taskID)
|
|
return true
|
|
})
|
|
return killedCount
|
|
}
|
|
|
|
// Count returns the number of tracked processes
|
|
func (pt *ProcessTracker) Count() int {
|
|
count := 0
|
|
pt.processes.Range(func(key, value interface{}) bool {
|
|
count++
|
|
return true
|
|
})
|
|
return count
|
|
}
|
|
|
|
// CommandResult is the captured outcome of a finished command.
type CommandResult struct {
	// Stdout and Stderr hold everything the process wrote to the
	// corresponding stream.
	Stdout, Stderr string

	// ExitCode is the process's exit status (0 on success). RunCommand sets
	// it to -1 when the wait error is not an *exec.ExitError.
	ExitCode int
}
|
|
|
|
// RunCommand executes a command and returns the output
|
|
// If tracker is provided, the process will be registered for tracking
|
|
// This is useful for commands where you need to capture output (like metadata extraction)
|
|
func RunCommand(
|
|
cmdPath string,
|
|
args []string,
|
|
dir string,
|
|
env []string,
|
|
taskID int64,
|
|
tracker *ProcessTracker,
|
|
) (*CommandResult, error) {
|
|
cmd := exec.Command(cmdPath, args...)
|
|
cmd.Dir = dir
|
|
if env != nil {
|
|
cmd.Env = env
|
|
}
|
|
|
|
stdoutPipe, err := cmd.StdoutPipe()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create stdout pipe: %w", err)
|
|
}
|
|
|
|
stderrPipe, err := cmd.StderrPipe()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create stderr pipe: %w", err)
|
|
}
|
|
|
|
if err := cmd.Start(); err != nil {
|
|
return nil, fmt.Errorf("failed to start command: %w", err)
|
|
}
|
|
|
|
// Track the process if tracker is provided
|
|
if tracker != nil {
|
|
tracker.Track(taskID, cmd)
|
|
defer tracker.Untrack(taskID)
|
|
}
|
|
|
|
// Collect stdout
|
|
var stdoutBuf, stderrBuf []byte
|
|
var stdoutErr, stderrErr error
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
stdoutBuf, stdoutErr = readAll(stdoutPipe)
|
|
}()
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
stderrBuf, stderrErr = readAll(stderrPipe)
|
|
}()
|
|
|
|
waitErr := cmd.Wait()
|
|
wg.Wait()
|
|
|
|
// Check for read errors
|
|
if stdoutErr != nil {
|
|
return nil, fmt.Errorf("failed to read stdout: %w", stdoutErr)
|
|
}
|
|
if stderrErr != nil {
|
|
return nil, fmt.Errorf("failed to read stderr: %w", stderrErr)
|
|
}
|
|
|
|
result := &CommandResult{
|
|
Stdout: string(stdoutBuf),
|
|
Stderr: string(stderrBuf),
|
|
}
|
|
|
|
if waitErr != nil {
|
|
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
|
result.ExitCode = exitErr.ExitCode()
|
|
} else {
|
|
result.ExitCode = -1
|
|
}
|
|
return result, waitErr
|
|
}
|
|
|
|
result.ExitCode = 0
|
|
return result, nil
|
|
}
|
|
|
|
// readAll reads r until end of stream and returns everything read.
//
// Reaching io.EOF is not reported as an error. On any other read error,
// the data read so far is returned together with that error. EOF is
// detected by identity with io.EOF, so an unrelated error whose message
// merely reads "EOF" is still reported.
func readAll(r io.Reader) ([]byte, error) {
	return io.ReadAll(r)
}
|
|
|
|
// LogSender is a function type for sending logs
|
|
type LogSender func(taskID int, level types.LogLevel, message string, stepName string)
|
|
|
|
// LineFilter is a function that processes a line and returns whether to filter it out and the log level
|
|
type LineFilter func(line string) (shouldFilter bool, level types.LogLevel)
|
|
|
|
// RunCommandWithStreaming executes a command with streaming output and OOM detection
|
|
// If tracker is provided, the process will be registered for tracking
|
|
func RunCommandWithStreaming(
|
|
cmdPath string,
|
|
args []string,
|
|
dir string,
|
|
env []string,
|
|
taskID int,
|
|
stepName string,
|
|
logSender LogSender,
|
|
stdoutFilter LineFilter,
|
|
stderrFilter LineFilter,
|
|
oomMessage string,
|
|
tracker *ProcessTracker,
|
|
) error {
|
|
cmd := exec.Command(cmdPath, args...)
|
|
cmd.Dir = dir
|
|
cmd.Env = env
|
|
|
|
stdoutPipe, err := cmd.StdoutPipe()
|
|
if err != nil {
|
|
errMsg := fmt.Sprintf("failed to create stdout pipe: %v", err)
|
|
logSender(taskID, types.LogLevelError, errMsg, stepName)
|
|
return errors.New(errMsg)
|
|
}
|
|
|
|
stderrPipe, err := cmd.StderrPipe()
|
|
if err != nil {
|
|
errMsg := fmt.Sprintf("failed to create stderr pipe: %v", err)
|
|
logSender(taskID, types.LogLevelError, errMsg, stepName)
|
|
return errors.New(errMsg)
|
|
}
|
|
|
|
if err := cmd.Start(); err != nil {
|
|
errMsg := fmt.Sprintf("failed to start command: %v", err)
|
|
logSender(taskID, types.LogLevelError, errMsg, stepName)
|
|
return errors.New(errMsg)
|
|
}
|
|
|
|
// Track the process if tracker is provided
|
|
if tracker != nil {
|
|
tracker.Track(int64(taskID), cmd)
|
|
defer tracker.Untrack(int64(taskID))
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
scanner := bufio.NewScanner(stdoutPipe)
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
if line != "" {
|
|
shouldFilter, level := stdoutFilter(line)
|
|
if !shouldFilter {
|
|
logSender(taskID, level, line, stepName)
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
scanner := bufio.NewScanner(stderrPipe)
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
if line != "" {
|
|
shouldFilter, level := stderrFilter(line)
|
|
if !shouldFilter {
|
|
logSender(taskID, level, line, stepName)
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
err = cmd.Wait()
|
|
wg.Wait()
|
|
|
|
if err != nil {
|
|
var errMsg string
|
|
if exitErr, ok := err.(*exec.ExitError); ok {
|
|
if exitErr.ExitCode() == 137 {
|
|
errMsg = oomMessage
|
|
} else {
|
|
errMsg = fmt.Sprintf("command failed: %v", err)
|
|
}
|
|
} else {
|
|
errMsg = fmt.Sprintf("command failed: %v", err)
|
|
}
|
|
logSender(taskID, types.LogLevelError, errMsg, stepName)
|
|
return errors.New(errMsg)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ============================================================================
|
|
// Helper functions using DefaultTracker
|
|
// ============================================================================
|
|
|
|
// Run executes a command using the default tracker and returns the output
|
|
// This is a convenience wrapper around RunCommand that uses DefaultTracker
|
|
func Run(cmdPath string, args []string, dir string, env []string, taskID int64) (*CommandResult, error) {
|
|
return RunCommand(cmdPath, args, dir, env, taskID, DefaultTracker)
|
|
}
|
|
|
|
// RunStreaming executes a command with streaming output using the default tracker
|
|
// This is a convenience wrapper around RunCommandWithStreaming that uses DefaultTracker
|
|
func RunStreaming(
|
|
cmdPath string,
|
|
args []string,
|
|
dir string,
|
|
env []string,
|
|
taskID int,
|
|
stepName string,
|
|
logSender LogSender,
|
|
stdoutFilter LineFilter,
|
|
stderrFilter LineFilter,
|
|
oomMessage string,
|
|
) error {
|
|
return RunCommandWithStreaming(cmdPath, args, dir, env, taskID, stepName, logSender, stdoutFilter, stderrFilter, oomMessage, DefaultTracker)
|
|
}
|
|
|
|
// KillAll kills all processes tracked by the default tracker
|
|
// Returns the number of processes killed
|
|
func KillAll() int {
|
|
return DefaultTracker.KillAll()
|
|
}
|
|
|
|
// Kill kills a specific process by task ID using the default tracker
|
|
// Returns true if the process was found and killed
|
|
func Kill(taskID int64) bool {
|
|
return DefaultTracker.Kill(taskID)
|
|
}
|
|
|
|
// Track registers a process with the default tracker
|
|
func Track(taskID int64, cmd *exec.Cmd) {
|
|
DefaultTracker.Track(taskID, cmd)
|
|
}
|
|
|
|
// Untrack removes a process from the default tracker
|
|
func Untrack(taskID int64) {
|
|
DefaultTracker.Untrack(taskID)
|
|
}
|
|
|
|
// GetTrackedCount returns the number of processes tracked by the default tracker
|
|
func GetTrackedCount() int {
|
|
return DefaultTracker.Count()
|
|
}
|