diff --git a/cmdext/builder.go b/cmdext/builder.go index b707643..f95a811 100644 --- a/cmdext/builder.go +++ b/cmdext/builder.go @@ -14,6 +14,7 @@ type CommandRunner struct { listener []CommandListener enforceExitCodes *[]int enforceNoTimeout bool + enforceNoStderr bool } func Runner(program string) *CommandRunner { @@ -25,6 +26,7 @@ func Runner(program string) *CommandRunner { listener: make([]CommandListener, 0), enforceExitCodes: nil, enforceNoTimeout: false, + enforceNoStderr: false, } } @@ -73,6 +75,11 @@ func (r *CommandRunner) FailOnTimeout() *CommandRunner { return r } +func (r *CommandRunner) FailOnStderr() *CommandRunner { + r.enforceNoStderr = true + return r +} + func (r *CommandRunner) Listen(lstr CommandListener) *CommandRunner { r.listener = append(r.listener, lstr) return r diff --git a/cmdext/cmdrunner.go b/cmdext/cmdrunner.go index e471a09..7007b4b 100644 --- a/cmdext/cmdrunner.go +++ b/cmdext/cmdrunner.go @@ -11,6 +11,7 @@ import ( var ErrExitCode = errors.New("process exited with an unexpected exitcode") var ErrTimeout = errors.New("process did not exit after the specified timeout") +var ErrStderrPrint = errors.New("process did print to stderr stream") type CommandResult struct { StdOut string @@ -53,12 +54,27 @@ func run(opt CommandRunner) (CommandResult, error) { err error } + stderrFailChan := make(chan bool) + outputChan := make(chan resultObj) go func() { // we need to first fully read the pipes and then call Wait // see https://pkg.go.dev/os/exec#Cmd.StdoutPipe - stdout, stderr, stdcombined, err := preader.Read(opt.listener) + listener := make([]CommandListener, 0) + listener = append(listener, opt.listener...) + + if opt.enforceNoStderr { + listener = append(listener, genericCommandListener{ + _readRawStderr: langext.Ptr(func(v []byte) { + if len(v) > 0 { + stderrFailChan <- true + } + }), + }) + } + + stdout, stderr, stdcombined, err := preader.Read(listener) if err != nil { outputChan <- resultObj{stdout, stderr, stdcombined, err} _ = cmd.Process.Kill() @@ -115,6 +131,34 @@ func run(opt CommandRunner) (CommandResult, error) { return res, nil } + case <-stderrFailChan: + _ = cmd.Process.Kill() + for _, lstr := range opt.listener { + lstr.Timeout() + } + + if fallback, ok := syncext.ReadChannelWithTimeout(outputChan, 32*time.Millisecond); ok { + // most of the time the cmd.Process.Kill() should also have finished the pipereader + // and we can at least return the already collected stdout, stderr, etc + res := CommandResult{ + StdOut: fallback.stdout, + StdErr: fallback.stderr, + StdCombined: fallback.stdcombined, + ExitCode: -1, + CommandTimedOut: false, + } + return res, ErrStderrPrint + } else { + res := CommandResult{ + StdOut: "", + StdErr: "", + StdCombined: "", + ExitCode: -1, + CommandTimedOut: false, + } + return res, ErrStderrPrint + } + case outobj := <-outputChan: if exiterr, ok := outobj.err.(*exec.ExitError); ok { excode := exiterr.ExitCode() diff --git a/cmdext/cmdrunner_test.go b/cmdext/cmdrunner_test.go index 8651cf4..d83351a 100644 --- a/cmdext/cmdrunner_test.go +++ b/cmdext/cmdrunner_test.go @@ -1,6 +1,7 @@ package cmdext import ( + "errors" "fmt" "testing" "time" @@ -289,16 +290,40 @@ func TestLongStdout(t *testing.T) { func TestFailOnTimeout(t *testing.T) { _, err := Runner("sleep").Arg("2").Timeout(200 * time.Millisecond).FailOnTimeout().Run() - if err != ErrTimeout { + if !errors.Is(err, ErrTimeout) { t.Errorf("wrong err := %v", err) } } +func TestFailOnStderr(t *testing.T) { + + res1, err := Runner("python").Arg("-c").Arg("import sys; print(\"error\", file=sys.stderr, end='')").FailOnStderr().Run() + if err == nil { + t.Errorf("no err") + } + if res1.CommandTimedOut { + t.Errorf("Timeout") + } + if res1.ExitCode != -1 { + t.Errorf("res1.ExitCode == %v", res1.ExitCode) + } + if res1.StdErr != "error" { + t.Errorf("res1.StdErr == '%v'", res1.StdErr) + } + if res1.StdOut != "" { + t.Errorf("res1.StdOut == '%v'", res1.StdOut) + } + if res1.StdCombined != "error\n" { + t.Errorf("res1.StdCombined == '%v'", res1.StdCombined) + } + +} + func TestFailOnExitcode(t *testing.T) { _, err := Runner("false").Timeout(200 * time.Millisecond).FailOnExitCode().Run() - if err != ErrExitCode { + if !errors.Is(err, ErrExitCode) { t.Errorf("wrong err := %v", err) } diff --git a/goextVersion.go b/goextVersion.go index a2c1988..4f0d359 100644 --- a/goextVersion.go +++ b/goextVersion.go @@ -1,5 +1,5 @@ package goext -const GoextVersion = "0.0.235" +const GoextVersion = "0.0.236" -const GoextVersionTimestamp = "2023-08-09T14:40:16+0200" +const GoextVersionTimestamp = "2023-08-09T17:48:06+0200"