diff --git a/cmdext/builder.go b/cmdext/builder.go index c97d15f..84aa2c1 100644 --- a/cmdext/builder.go +++ b/cmdext/builder.go @@ -6,18 +6,20 @@ import ( ) type CommandRunner struct { - program string - args []string - timeout *time.Duration - env []string + program string + args []string + timeout *time.Duration + env []string + listener []CommandListener } func Runner(program string) *CommandRunner { return &CommandRunner{ - program: program, - args: make([]string, 0), - timeout: nil, - env: make([]string, 0), + program: program, + args: make([]string, 0), + timeout: nil, + env: make([]string, 0), + listener: make([]CommandListener, 0), } } @@ -51,6 +53,21 @@ func (r *CommandRunner) Envs(env []string) *CommandRunner { return r } +func (r *CommandRunner) Listen(lstr CommandListener) *CommandRunner { + r.listener = append(r.listener, lstr) + return r +} + +func (r *CommandRunner) ListenStdout(lstr func(string)) *CommandRunner { + r.listener = append(r.listener, genericCommandListener{_readStdoutLine: &lstr}) + return r +} + +func (r *CommandRunner) ListenStderr(lstr func(string)) *CommandRunner { + r.listener = append(r.listener, genericCommandListener{_readStderrLine: &lstr}) + return r +} + func (r *CommandRunner) Run() (CommandResult, error) { return run(*r) } diff --git a/cmdext/cmdrunner.go b/cmdext/cmdrunner.go index 22b3d06..9681a54 100644 --- a/cmdext/cmdrunner.go +++ b/cmdext/cmdrunner.go @@ -2,6 +2,7 @@ package cmdext import ( "bufio" + "io" "os/exec" "time" ) @@ -34,48 +35,114 @@ func run(opt CommandRunner) (CommandResult, error) { return CommandResult{}, err } - errch := make(chan error, 1) + errch := make(chan error, 3) go func() { errch <- cmd.Wait() }() + // [1] read raw stdout + + stdoutBufferReader, stdoutBufferWriter := io.Pipe() + stdout := "" + go func() { + buf := make([]byte, 128) + for true { + n, out := stdoutPipe.Read(buf) + + if n > 0 { + txt := string(buf[:n]) + stdout += txt + _, _ = stdoutBufferWriter.Write(buf[:n]) + for _, lstr := range opt.listener { + lstr.ReadRawStdout(buf[:n]) + } + } + if out == io.EOF { + break + } + if out != nil { + errch <- out + _ = cmd.Process.Kill() + break + } + } + _ = stdoutBufferWriter.Close() + }() + + // [2] read raw stderr + + stderrBufferReader, stderrBufferWriter := io.Pipe() + stderr := "" + go func() { + buf := make([]byte, 128) + for true { + n, err := stderrPipe.Read(buf) + + if n > 0 { + txt := string(buf[:n]) + stderr += txt + _, _ = stderrBufferWriter.Write(buf[:n]) + for _, lstr := range opt.listener { + lstr.ReadRawStderr(buf[:n]) + } + } + if err == io.EOF { + break + } + if err != nil { + errch <- err + _ = cmd.Process.Kill() + break + } + } + _ = stderrBufferWriter.Close() + }() + combch := make(chan string, 32) stopCombch := make(chan bool) - stdout := "" + // [3] collect stdout line-by-line + go func() { - scanner := bufio.NewScanner(stdoutPipe) + scanner := bufio.NewScanner(stdoutBufferReader) for scanner.Scan() { txt := scanner.Text() - stdout += txt + for _, lstr := range opt.listener { + lstr.ReadStdoutLine(txt) + } combch <- txt } }() - stderr := "" + // [4] collect stderr line-by-line + go func() { - scanner := bufio.NewScanner(stderrPipe) + scanner := bufio.NewScanner(stderrBufferReader) for scanner.Scan() { txt := scanner.Text() - stderr += txt + for _, lstr := range opt.listener { + lstr.ReadStderrLine(txt) + } combch <- txt } }() - defer func() { - stopCombch <- true - }() + defer func() { stopCombch <- true }() + + // [5] combine stdcombined stdcombined := "" go func() { for { select { case txt := <-combch: - stdcombined += txt + stdcombined += txt + "\n" // this comes from bufio.Scanner and has no newlines... case <-stopCombch: return } } }() + // [6] run + var timeoutChan <-chan time.Time = make(chan time.Time, 1) if opt.timeout != nil { timeoutChan = time.After(*opt.timeout) @@ -85,6 +152,9 @@ func run(opt CommandRunner) (CommandResult, error) { case <-timeoutChan: _ = cmd.Process.Kill() + for _, lstr := range opt.listener { + lstr.Timeout() + } return CommandResult{ StdOut: stdout, StdErr: stderr, @@ -95,16 +165,23 @@ func run(opt CommandRunner) (CommandResult, error) { case err := <-errch: if exiterr, ok := err.(*exec.ExitError); ok { + excode := exiterr.ExitCode() + for _, lstr := range opt.listener { + lstr.Finished(excode) + } return CommandResult{ StdOut: stdout, StdErr: stderr, StdCombined: stdcombined, - ExitCode: exiterr.ExitCode(), + ExitCode: excode, CommandTimedOut: false, }, nil } else if err != nil { return CommandResult{}, err } else { + for _, lstr := range opt.listener { + lstr.Finished(0) + } return CommandResult{ StdOut: stdout, StdErr: stderr, diff --git a/cmdext/cmdrunner_test.go b/cmdext/cmdrunner_test.go new file mode 100644 index 0000000..ca4548d --- /dev/null +++ b/cmdext/cmdrunner_test.go @@ -0,0 +1,59 @@ +package cmdext + +import "testing" + +func TestStdout(t *testing.T) { + + res1, err := Runner("printf").Arg("hello").Run() + if err != nil { + t.Errorf("%v", err) + } + if res1.StdErr != "" { + t.Errorf("res1.StdErr == '%v'", res1.StdErr) + } + if res1.StdOut != "hello" { + t.Errorf("res1.StdOut == '%v'", res1.StdOut) + } + if res1.StdCombined != "hello\n" { + t.Errorf("res1.StdCombined == '%v'", res1.StdCombined) + } + +} + +func TestStderr(t *testing.T) { + + res1, err := Runner("python").Arg("-c").Arg("import sys; print(\"error\", file=sys.stderr, end='')").Run() + if err != nil { + t.Errorf("%v", err) + } + if res1.StdErr != "error" { + t.Errorf("res1.StdErr == '%v'", res1.StdErr) + } + if res1.StdOut != "" { + t.Errorf("res1.StdOut == '%v'", res1.StdOut) + } + if res1.StdCombined != "error\n" { + t.Errorf("res1.StdCombined == '%v'", res1.StdCombined) + } + +} + +func TestStdcombined(t *testing.T) { + res1, err := Runner("python"). + Arg("-c"). + Arg("import sys; import time; print(\"1\", file=sys.stderr, flush=True); time.sleep(0.1); print(\"2\", file=sys.stdout, flush=True); time.sleep(0.1); print(\"3\", file=sys.stderr, flush=True)"). + Run() + if err != nil { + t.Errorf("%v", err) + } + if res1.StdErr != "1\n3\n" { + t.Errorf("res1.StdErr == '%v'", res1.StdErr) + } + if res1.StdOut != "2\n" { + t.Errorf("res1.StdOut == '%v'", res1.StdOut) + } + if res1.StdCombined != "1\n2\n3\n" { + t.Errorf("res1.StdCombined == '%v'", res1.StdCombined) + } + +} diff --git a/cmdext/listener.go b/cmdext/listener.go new file mode 100644 index 0000000..b17a520 --- /dev/null +++ b/cmdext/listener.go @@ -0,0 +1,57 @@ +package cmdext + +type CommandListener interface { + ReadRawStdout([]byte) + ReadRawStderr([]byte) + + ReadStdoutLine(string) + ReadStderrLine(string) + + Finished(int) + Timeout() +} + +type genericCommandListener struct { + _readRawStdout *func([]byte) + _readRawStderr *func([]byte) + _readStdoutLine *func(string) + _readStderrLine *func(string) + _finished *func(int) + _timeout *func() +} + +func (g genericCommandListener) ReadRawStdout(v []byte) { + if g._readRawStdout != nil { + (*g._readRawStdout)(v) + } +} + +func (g genericCommandListener) ReadRawStderr(v []byte) { + if g._readRawStderr != nil { + (*g._readRawStderr)(v) + } +} + +func (g genericCommandListener) ReadStdoutLine(v string) { + if g._readStdoutLine != nil { + (*g._readStdoutLine)(v) + } +} + +func (g genericCommandListener) ReadStderrLine(v string) { + if g._readStderrLine != nil { + (*g._readStderrLine)(v) + } +} + +func (g genericCommandListener) Finished(v int) { + if g._finished != nil { + (*g._finished)(v) + } +} + +func (g genericCommandListener) Timeout() { + if g._timeout != nil { + (*g._timeout)() + } +}