diff --git a/cmdext/builder.go b/cmdext/builder.go index 84aa2c1..b707643 100644 --- a/cmdext/builder.go +++ b/cmdext/builder.go @@ -2,24 +2,29 @@ package cmdext import ( "fmt" + "gogs.mikescher.com/BlackForestBytes/goext/langext" "time" ) type CommandRunner struct { - program string - args []string - timeout *time.Duration - env []string - listener []CommandListener + program string + args []string + timeout *time.Duration + env []string + listener []CommandListener + enforceExitCodes *[]int + enforceNoTimeout bool } func Runner(program string) *CommandRunner { return &CommandRunner{ - program: program, - args: make([]string, 0), - timeout: nil, - env: make([]string, 0), - listener: make([]CommandListener, 0), + program: program, + args: make([]string, 0), + timeout: nil, + env: make([]string, 0), + listener: make([]CommandListener, 0), + enforceExitCodes: nil, + enforceNoTimeout: false, } } @@ -53,6 +58,21 @@ func (r *CommandRunner) Envs(env []string) *CommandRunner { return r } +func (r *CommandRunner) EnsureExitcode(arg ...int) *CommandRunner { + r.enforceExitCodes = langext.Ptr(langext.ForceArray(arg)) + return r +} + +func (r *CommandRunner) FailOnExitCode() *CommandRunner { + r.enforceExitCodes = langext.Ptr([]int{0}) + return r +} + +func (r *CommandRunner) FailOnTimeout() *CommandRunner { + r.enforceNoTimeout = true + return r +} + func (r *CommandRunner) Listen(lstr CommandListener) *CommandRunner { r.listener = append(r.listener, lstr) return r diff --git a/cmdext/cmdrunner.go b/cmdext/cmdrunner.go index 1a3550f..ab61f8f 100644 --- a/cmdext/cmdrunner.go +++ b/cmdext/cmdrunner.go @@ -1,12 +1,17 @@ package cmdext import ( + "errors" + "gogs.mikescher.com/BlackForestBytes/goext/langext" "gogs.mikescher.com/BlackForestBytes/goext/mathext" "gogs.mikescher.com/BlackForestBytes/goext/syncext" "os/exec" "time" ) +var ErrExitCode = errors.New("process exited with an unexpected exitcode") +var ErrTimeout = errors.New("process did not exit after the specified timeout") + type CommandResult struct { StdOut string StdErr string @@ -81,21 +86,29 @@ func run(opt CommandRunner) (CommandResult, error) { if fallback, ok := syncext.ReadChannelWithTimeout(outputChan, mathext.Min(32*time.Millisecond, *opt.timeout)); ok { // most of the time the cmd.Process.Kill() should also ahve finished the pipereader // and we can at least return the already collected stdout, stderr, etc - return CommandResult{ + res := CommandResult{ StdOut: fallback.stdout, StdErr: fallback.stderr, StdCombined: fallback.stdcombined, ExitCode: -1, CommandTimedOut: true, - }, nil + } + if opt.enforceNoTimeout { + return res, ErrTimeout + } + return res, nil } else { - return CommandResult{ + res := CommandResult{ StdOut: "", StdErr: "", StdCombined: "", ExitCode: -1, CommandTimedOut: true, - }, nil + } + if opt.enforceNoTimeout { + return res, ErrTimeout + } + return res, nil } case outobj := <-outputChan: @@ -104,26 +117,34 @@ func run(opt CommandRunner) (CommandResult, error) { for _, lstr := range opt.listener { lstr.Finished(excode) } - return CommandResult{ + res := CommandResult{ StdOut: outobj.stdout, StdErr: outobj.stderr, StdCombined: outobj.stdcombined, ExitCode: excode, CommandTimedOut: false, - }, nil + } + if opt.enforceExitCodes != nil && !langext.InArray(excode, *opt.enforceExitCodes) { + return res, ErrExitCode + } + return res, nil } else if err != nil { return CommandResult{}, err } else { for _, lstr := range opt.listener { lstr.Finished(0) } - return CommandResult{ + res := CommandResult{ StdOut: outobj.stdout, StdErr: outobj.stderr, StdCombined: outobj.stdcombined, ExitCode: 0, CommandTimedOut: false, - }, nil + } + if opt.enforceExitCodes != nil && !langext.InArray(0, *opt.enforceExitCodes) { + return res, ErrExitCode + } + return res, nil } } } diff --git a/cmdext/cmdrunner_test.go b/cmdext/cmdrunner_test.go index ac65670..fa50bc2 100644 --- a/cmdext/cmdrunner_test.go +++ b/cmdext/cmdrunner_test.go @@ -12,6 +12,12 @@ func TestStdout(t *testing.T) { if err != nil { t.Errorf("%v", err) } + if res1.CommandTimedOut { + t.Errorf("Timeout") + } + if res1.ExitCode != 0 { + t.Errorf("res1.ExitCode == %v", res1.ExitCode) + } if res1.StdErr != "" { t.Errorf("res1.StdErr == '%v'", res1.StdErr) } @@ -30,6 +36,12 @@ func TestStderr(t *testing.T) { if err != nil { t.Errorf("%v", err) } + if res1.CommandTimedOut { + t.Errorf("Timeout") + } + if res1.ExitCode != 0 { + t.Errorf("res1.ExitCode == %v", res1.ExitCode) + } if res1.StdErr != "error" { t.Errorf("res1.StdErr == '%v'", res1.StdErr) } @@ -50,6 +62,12 @@ func TestStdcombined(t *testing.T) { if err != nil { t.Errorf("%v", err) } + if res1.CommandTimedOut { + t.Errorf("Timeout") + } + if res1.ExitCode != 0 { + t.Errorf("res1.ExitCode == %v", res1.ExitCode) + } if res1.StdErr != "1\n3\n" { t.Errorf("res1.StdErr == '%v'", res1.StdErr) } @@ -116,6 +134,12 @@ func TestReadUnflushedStdout(t *testing.T) { if err != nil { t.Errorf("%v", err) } + if res1.CommandTimedOut { + t.Errorf("Timeout") + } + if res1.ExitCode != 0 { + t.Errorf("res1.ExitCode == %v", res1.ExitCode) + } if res1.StdErr != "" { t.Errorf("res1.StdErr == '%v'", res1.StdErr) } @@ -134,6 +158,12 @@ func TestReadUnflushedStderr(t *testing.T) { if err != nil { t.Errorf("%v", err) } + if res1.CommandTimedOut { + t.Errorf("Timeout") + } + if res1.ExitCode != 0 { + t.Errorf("res1.ExitCode == %v", res1.ExitCode) + } if res1.StdErr != "message101" { t.Errorf("res1.StdErr == '%v'", res1.StdErr) } @@ -200,7 +230,7 @@ func TestPartialReadUnflushedStderr(t *testing.T) { func TestListener(t *testing.T) { - _, err := Runner("python"). + res1, err := Runner("python"). Arg("-c"). Arg("import sys;" + "import time;" + @@ -223,23 +253,71 @@ func TestListener(t *testing.T) { if err != nil { panic(err) } + if res1.CommandTimedOut { + t.Errorf("Timeout") + } + if res1.ExitCode != 0 { + t.Errorf("res1.ExitCode == %v", res1.ExitCode) + } } func TestLongStdout(t *testing.T) { res1, err := Runner("python"). Arg("-c"). - Arg("import sys; import time; print(\"1234\" * 10000 + \"\\n\"); print(\"1234\" * 10000 + \"\\n\"); print(\"1234\" * 10000 + \"\\n\");"). - Timeout(100 * time.Millisecond). + Arg("import sys; import time; print(\"X\" * 125001 + \"\\n\"); print(\"Y\" * 125001 + \"\\n\"); print(\"Z\" * 125001 + \"\\n\");"). + Timeout(5000 * time.Millisecond). Run() if err != nil { t.Errorf("%v", err) } + if res1.CommandTimedOut { + t.Errorf("Timeout") + } + if res1.ExitCode != 0 { + t.Errorf("res1.ExitCode == %v", res1.ExitCode) + } if res1.StdErr != "" { t.Errorf("res1.StdErr == '%v'", res1.StdErr) } - if len(res1.StdOut) != 120006 { + if len(res1.StdOut) != 375006 { t.Errorf("len(res1.StdOut) == '%v'", len(res1.StdOut)) } } + +func TestFailOnTimeout(t *testing.T) { + + _, err := Runner("sleep").Arg("2").Timeout(200 * time.Millisecond).FailOnTimeout().Run() + if err != ErrTimeout { + t.Errorf("wrong err := %v", err) + } + +} + +func TestFailOnExitcode(t *testing.T) { + + _, err := Runner("false").Timeout(200 * time.Millisecond).FailOnExitCode().Run() + if err != ErrExitCode { + t.Errorf("wrong err := %v", err) + } + +} + +func TestEnsureExitcode1(t *testing.T) { + + _, err := Runner("false").Timeout(200 * time.Millisecond).EnsureExitcode(1).Run() + if err != nil { + t.Errorf("wrong err := %v", err) + } + +} + +func TestEnsureExitcode2(t *testing.T) { + + _, err := Runner("false").Timeout(200*time.Millisecond).EnsureExitcode(0, 2, 3).Run() + if err != ErrExitCode { + t.Errorf("wrong err := %v", err) + } + +}