something
This commit is contained in:
366
pkg/executils/exec.go
Normal file
366
pkg/executils/exec.go
Normal file
@@ -0,0 +1,366 @@
|
||||
package executils
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"jiggablend/pkg/types"
|
||||
)
|
||||
|
||||
// DefaultTracker is the global default process tracker
|
||||
// Use this for processes that should be tracked globally and killed on shutdown
|
||||
var DefaultTracker = NewProcessTracker()
|
||||
|
||||
// ProcessTracker tracks running processes for cleanup
|
||||
type ProcessTracker struct {
|
||||
processes sync.Map // map[int64]*exec.Cmd - tracks running processes by task ID
|
||||
}
|
||||
|
||||
// NewProcessTracker creates a new process tracker
|
||||
func NewProcessTracker() *ProcessTracker {
|
||||
return &ProcessTracker{}
|
||||
}
|
||||
|
||||
// Track registers a process for tracking
|
||||
func (pt *ProcessTracker) Track(taskID int64, cmd *exec.Cmd) {
|
||||
pt.processes.Store(taskID, cmd)
|
||||
}
|
||||
|
||||
// Untrack removes a process from tracking
|
||||
func (pt *ProcessTracker) Untrack(taskID int64) {
|
||||
pt.processes.Delete(taskID)
|
||||
}
|
||||
|
||||
// Get returns the command for a task ID if it exists
|
||||
func (pt *ProcessTracker) Get(taskID int64) (*exec.Cmd, bool) {
|
||||
if val, ok := pt.processes.Load(taskID); ok {
|
||||
return val.(*exec.Cmd), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Kill kills a specific process by task ID
|
||||
// Returns true if the process was found and killed
|
||||
func (pt *ProcessTracker) Kill(taskID int64) bool {
|
||||
cmd, ok := pt.Get(taskID)
|
||||
if !ok || cmd.Process == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Try graceful kill first (SIGINT)
|
||||
if err := cmd.Process.Signal(os.Interrupt); err != nil {
|
||||
// If SIGINT fails, try SIGKILL
|
||||
cmd.Process.Kill()
|
||||
} else {
|
||||
// Give it a moment to clean up gracefully
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
// Force kill if still running
|
||||
cmd.Process.Kill()
|
||||
}
|
||||
|
||||
pt.Untrack(taskID)
|
||||
return true
|
||||
}
|
||||
|
||||
// KillAll kills all tracked processes
|
||||
// Returns the number of processes killed
|
||||
func (pt *ProcessTracker) KillAll() int {
|
||||
var killedCount int
|
||||
pt.processes.Range(func(key, value interface{}) bool {
|
||||
taskID := key.(int64)
|
||||
cmd := value.(*exec.Cmd)
|
||||
if cmd.Process != nil {
|
||||
// Try graceful kill first (SIGINT)
|
||||
if err := cmd.Process.Signal(os.Interrupt); err == nil {
|
||||
// Give it a moment to clean up
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
// Force kill
|
||||
cmd.Process.Kill()
|
||||
killedCount++
|
||||
}
|
||||
pt.processes.Delete(taskID)
|
||||
return true
|
||||
})
|
||||
return killedCount
|
||||
}
|
||||
|
||||
// Count returns the number of tracked processes
|
||||
func (pt *ProcessTracker) Count() int {
|
||||
count := 0
|
||||
pt.processes.Range(func(key, value interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
return count
|
||||
}
|
||||
|
||||
// CommandResult holds the output from a command execution
|
||||
type CommandResult struct {
|
||||
Stdout string
|
||||
Stderr string
|
||||
ExitCode int
|
||||
}
|
||||
|
||||
// RunCommand executes a command and returns the output
|
||||
// If tracker is provided, the process will be registered for tracking
|
||||
// This is useful for commands where you need to capture output (like metadata extraction)
|
||||
func RunCommand(
|
||||
cmdPath string,
|
||||
args []string,
|
||||
dir string,
|
||||
env []string,
|
||||
taskID int64,
|
||||
tracker *ProcessTracker,
|
||||
) (*CommandResult, error) {
|
||||
cmd := exec.Command(cmdPath, args...)
|
||||
cmd.Dir = dir
|
||||
if env != nil {
|
||||
cmd.Env = env
|
||||
}
|
||||
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
stderrPipe, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("failed to start command: %w", err)
|
||||
}
|
||||
|
||||
// Track the process if tracker is provided
|
||||
if tracker != nil {
|
||||
tracker.Track(taskID, cmd)
|
||||
defer tracker.Untrack(taskID)
|
||||
}
|
||||
|
||||
// Collect stdout
|
||||
var stdoutBuf, stderrBuf []byte
|
||||
var stdoutErr, stderrErr error
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
stdoutBuf, stdoutErr = readAll(stdoutPipe)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
stderrBuf, stderrErr = readAll(stderrPipe)
|
||||
}()
|
||||
|
||||
waitErr := cmd.Wait()
|
||||
wg.Wait()
|
||||
|
||||
// Check for read errors
|
||||
if stdoutErr != nil {
|
||||
return nil, fmt.Errorf("failed to read stdout: %w", stdoutErr)
|
||||
}
|
||||
if stderrErr != nil {
|
||||
return nil, fmt.Errorf("failed to read stderr: %w", stderrErr)
|
||||
}
|
||||
|
||||
result := &CommandResult{
|
||||
Stdout: string(stdoutBuf),
|
||||
Stderr: string(stderrBuf),
|
||||
}
|
||||
|
||||
if waitErr != nil {
|
||||
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||
result.ExitCode = exitErr.ExitCode()
|
||||
} else {
|
||||
result.ExitCode = -1
|
||||
}
|
||||
return result, waitErr
|
||||
}
|
||||
|
||||
result.ExitCode = 0
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// readAll reads all data from a reader
|
||||
func readAll(r interface{ Read([]byte) (int, error) }) ([]byte, error) {
|
||||
var buf []byte
|
||||
tmp := make([]byte, 4096)
|
||||
for {
|
||||
n, err := r.Read(tmp)
|
||||
if n > 0 {
|
||||
buf = append(buf, tmp[:n]...)
|
||||
}
|
||||
if err != nil {
|
||||
if err.Error() == "EOF" {
|
||||
break
|
||||
}
|
||||
return buf, err
|
||||
}
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// LogSender is a function type for sending logs
|
||||
type LogSender func(taskID int, level types.LogLevel, message string, stepName string)
|
||||
|
||||
// LineFilter is a function that processes a line and returns whether to filter it out and the log level
|
||||
type LineFilter func(line string) (shouldFilter bool, level types.LogLevel)
|
||||
|
||||
// RunCommandWithStreaming executes a command with streaming output and OOM detection
|
||||
// If tracker is provided, the process will be registered for tracking
|
||||
func RunCommandWithStreaming(
|
||||
cmdPath string,
|
||||
args []string,
|
||||
dir string,
|
||||
env []string,
|
||||
taskID int,
|
||||
stepName string,
|
||||
logSender LogSender,
|
||||
stdoutFilter LineFilter,
|
||||
stderrFilter LineFilter,
|
||||
oomMessage string,
|
||||
tracker *ProcessTracker,
|
||||
) error {
|
||||
cmd := exec.Command(cmdPath, args...)
|
||||
cmd.Dir = dir
|
||||
cmd.Env = env
|
||||
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
errMsg := fmt.Sprintf("failed to create stdout pipe: %v", err)
|
||||
logSender(taskID, types.LogLevelError, errMsg, stepName)
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
|
||||
stderrPipe, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
errMsg := fmt.Sprintf("failed to create stderr pipe: %v", err)
|
||||
logSender(taskID, types.LogLevelError, errMsg, stepName)
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
errMsg := fmt.Sprintf("failed to start command: %v", err)
|
||||
logSender(taskID, types.LogLevelError, errMsg, stepName)
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
|
||||
// Track the process if tracker is provided
|
||||
if tracker != nil {
|
||||
tracker.Track(int64(taskID), cmd)
|
||||
defer tracker.Untrack(int64(taskID))
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
scanner := bufio.NewScanner(stdoutPipe)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line != "" {
|
||||
shouldFilter, level := stdoutFilter(line)
|
||||
if !shouldFilter {
|
||||
logSender(taskID, level, line, stepName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
scanner := bufio.NewScanner(stderrPipe)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line != "" {
|
||||
shouldFilter, level := stderrFilter(line)
|
||||
if !shouldFilter {
|
||||
logSender(taskID, level, line, stepName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
err = cmd.Wait()
|
||||
wg.Wait()
|
||||
|
||||
if err != nil {
|
||||
var errMsg string
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
if exitErr.ExitCode() == 137 {
|
||||
errMsg = oomMessage
|
||||
} else {
|
||||
errMsg = fmt.Sprintf("command failed: %v", err)
|
||||
}
|
||||
} else {
|
||||
errMsg = fmt.Sprintf("command failed: %v", err)
|
||||
}
|
||||
logSender(taskID, types.LogLevelError, errMsg, stepName)
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Helper functions using DefaultTracker
|
||||
// ============================================================================
|
||||
|
||||
// Run executes a command using the default tracker and returns the output
|
||||
// This is a convenience wrapper around RunCommand that uses DefaultTracker
|
||||
func Run(cmdPath string, args []string, dir string, env []string, taskID int64) (*CommandResult, error) {
|
||||
return RunCommand(cmdPath, args, dir, env, taskID, DefaultTracker)
|
||||
}
|
||||
|
||||
// RunStreaming executes a command with streaming output using the default tracker
|
||||
// This is a convenience wrapper around RunCommandWithStreaming that uses DefaultTracker
|
||||
func RunStreaming(
|
||||
cmdPath string,
|
||||
args []string,
|
||||
dir string,
|
||||
env []string,
|
||||
taskID int,
|
||||
stepName string,
|
||||
logSender LogSender,
|
||||
stdoutFilter LineFilter,
|
||||
stderrFilter LineFilter,
|
||||
oomMessage string,
|
||||
) error {
|
||||
return RunCommandWithStreaming(cmdPath, args, dir, env, taskID, stepName, logSender, stdoutFilter, stderrFilter, oomMessage, DefaultTracker)
|
||||
}
|
||||
|
||||
// KillAll kills all processes tracked by the default tracker
|
||||
// Returns the number of processes killed
|
||||
func KillAll() int {
|
||||
return DefaultTracker.KillAll()
|
||||
}
|
||||
|
||||
// Kill kills a specific process by task ID using the default tracker
|
||||
// Returns true if the process was found and killed
|
||||
func Kill(taskID int64) bool {
|
||||
return DefaultTracker.Kill(taskID)
|
||||
}
|
||||
|
||||
// Track registers a process with the default tracker
|
||||
func Track(taskID int64, cmd *exec.Cmd) {
|
||||
DefaultTracker.Track(taskID, cmd)
|
||||
}
|
||||
|
||||
// Untrack removes a process from the default tracker
|
||||
func Untrack(taskID int64) {
|
||||
DefaultTracker.Untrack(taskID)
|
||||
}
|
||||
|
||||
// GetTrackedCount returns the number of processes tracked by the default tracker
|
||||
func GetTrackedCount() int {
|
||||
return DefaultTracker.Count()
|
||||
}
|
||||
Reference in New Issue
Block a user