package cmdext

import (
	"errors"
	"gogs.mikescher.com/BlackForestBytes/goext/langext"
	"gogs.mikescher.com/BlackForestBytes/goext/mathext"
	"gogs.mikescher.com/BlackForestBytes/goext/syncext"
	"os/exec"
	"time"
)

// Sentinel errors returned by the command runner; they are returned
// unwrapped, so they can be compared directly or with errors.Is.
var (
	// ErrExitCode is returned when a set of allowed exit codes is configured
	// and the process exited with a code outside that set.
	ErrExitCode = errors.New("process exited with an unexpected exitcode")

	// ErrTimeout is returned when the configured timeout elapsed and timeouts
	// are configured to be treated as errors.
	ErrTimeout = errors.New("process did not exit after the specified timeout")

	// ErrStderrPrint is returned when the process wrote to stderr while
	// stderr output is configured to be forbidden.
	ErrStderrPrint = errors.New("process did print to stderr stream")
)

// CommandResult is the outcome of running a command: the captured output
// streams, the exit code, and whether the configured timeout was hit.
type CommandResult struct {
	// Captured output: stdout only, stderr only, and both streams combined
	// (as assembled by the pipe reader).
	StdOut, StdErr, StdCombined string

	// ExitCode is the process exit code, or -1 when the process was killed by
	// the runner (timeout, enforced no-stderr) or did not exit normally.
	ExitCode int

	// CommandTimedOut reports whether the timeout elapsed and the process was killed.
	CommandTimedOut bool
}

// run executes the command described by opt and collects its output.
//
// The process is started with opt.program/opt.args, and opt.env is appended
// to cmd.Env. Stdout and stderr are read through a pipeReader, which also
// forwards the data to opt.listener. Depending on opt, run may stop early:
//
//   - opt.timeout elapsed: the process is killed, every listener's Timeout()
//     is called, and the result has CommandTimedOut=true and ExitCode=-1. The
//     error is ErrTimeout if opt.enforceNoTimeout is set, nil otherwise.
//   - opt.enforceNoStderr is set and the process wrote to stderr: the process
//     is killed and ErrStderrPrint is returned with ExitCode=-1.
//
// Otherwise every listener's Finished(exitCode) is called. ErrExitCode is
// returned if opt.enforceExitCodes is set and does not contain the exit code.
//
// When the process is killed, the output collected so far is returned on a
// best-effort basis (the reader gets at most 32ms to deliver it). Any other
// failure (creating the pipes, starting the process, reading the pipes, or
// waiting for it) is returned as-is with an empty CommandResult.
func run(opt CommandRunner) (CommandResult, error) {
	cmd := exec.Command(opt.program, opt.args...)

	// NOTE: cmd.Env starts out nil. When opt.env is non-empty the child
	// therefore gets ONLY these variables; a nil Env would inherit the
	// parent's environment (see os/exec Cmd.Env).
	cmd.Env = append(cmd.Env, opt.env...)

	stdoutPipe, err := cmd.StdoutPipe()
	if err != nil {
		return CommandResult{}, err
	}

	stderrPipe, err := cmd.StderrPipe()
	if err != nil {
		return CommandResult{}, err
	}

	preader := pipeReader{
		lineBufferSize: langext.Ptr(128 * 1024 * 1024), // 128MB max size of a single line, is hopefully enough....
		stdout:         stdoutPipe,
		stderr:         stderrPipe,
	}

	err = cmd.Start()
	if err != nil {
		return CommandResult{}, err
	}

	type resultObj struct {
		stdout      string
		stderr      string
		stdcombined string
		err         error
	}

	// Capacity 1 plus a non-blocking send: the stderr listener may fire many
	// times, possibly after run has already returned, and must never block
	// the pipe reader. Only the first signal matters.
	stderrFailChan := make(chan struct{}, 1)

	// Capacity 1 so the reader goroutine can always deliver its result and
	// exit, even when run has already returned (timeout / stderr failure).
	outputChan := make(chan resultObj, 1)

	go func() {
		// we need to first fully read the pipes and then call Wait
		// see https://pkg.go.dev/os/exec#Cmd.StdoutPipe

		listener := make([]CommandListener, 0, len(opt.listener)+1)
		listener = append(listener, opt.listener...)

		if opt.enforceNoStderr {
			listener = append(listener, genericCommandListener{
				_readRawStderr: langext.Ptr(func(v []byte) {
					if len(v) > 0 {
						select {
						case stderrFailChan <- struct{}{}:
						default: // a failure signal is already pending
						}
					}
				}),
			})
		}

		stdout, stderr, stdcombined, err := preader.Read(listener)
		if err != nil {
			// Reading failed: kill the process and reap it so it does not
			// linger as a zombie, then report the read error.
			_ = cmd.Process.Kill()
			_ = cmd.Wait()
			outputChan <- resultObj{stdout, stderr, stdcombined, err}
			return
		}

		err = cmd.Wait()
		outputChan <- resultObj{stdout, stderr, stdcombined, err}
	}()

	// A nil channel blocks forever, so without a timeout this case never fires.
	var timeoutChan <-chan time.Time
	if opt.timeout != nil {
		timeoutChan = time.After(*opt.timeout)
	}

	// killedResult builds the result for a process that run killed:
	// ExitCode is -1 and out holds whatever output was collected.
	killedResult := func(out resultObj, timedOut bool) CommandResult {
		return CommandResult{
			StdOut:          out.stdout,
			StdErr:          out.stderr,
			StdCombined:     out.stdcombined,
			ExitCode:        -1,
			CommandTimedOut: timedOut,
		}
	}

	// collectAfterKill gives the reader goroutine a short grace period to hand
	// over the output collected so far. Usually Kill also ends the pipe reader.
	// It returns an empty resultObj if nothing arrives in time.
	collectAfterKill := func(grace time.Duration) resultObj {
		if out, ok := syncext.ReadChannelWithTimeout(outputChan, grace); ok {
			return out
		}
		return resultObj{}
	}

	select {
	case <-timeoutChan:
		_ = cmd.Process.Kill()
		for _, lstr := range opt.listener {
			lstr.Timeout()
		}

		res := killedResult(collectAfterKill(mathext.Min(32*time.Millisecond, *opt.timeout)), true)
		if opt.enforceNoTimeout {
			return res, ErrTimeout
		}
		return res, nil

	case <-stderrFailChan:
		_ = cmd.Process.Kill()
		return killedResult(collectAfterKill(32*time.Millisecond), false), ErrStderrPrint

	case outobj := <-outputChan:
		// The stderr signal is queued before the reader can finish. By now both
		// channels may be ready and select picks at random, so a pending stderr
		// failure must still take precedence.
		select {
		case <-stderrFailChan:
			return killedResult(outobj, false), ErrStderrPrint
		default:
		}

		excode := 0
		if outobj.err != nil {
			var exiterr *exec.ExitError
			if !errors.As(outobj.err, &exiterr) {
				// Reading the pipes or waiting failed for a reason other than
				// the exit status; this must not be reported as a clean exit.
				return CommandResult{}, outobj.err
			}
			excode = exiterr.ExitCode()
		}

		for _, lstr := range opt.listener {
			lstr.Finished(excode)
		}

		res := CommandResult{
			StdOut:          outobj.stdout,
			StdErr:          outobj.stderr,
			StdCombined:     outobj.stdcombined,
			ExitCode:        excode,
			CommandTimedOut: false,
		}
		if opt.enforceExitCodes != nil && !langext.InArray(excode, *opt.enforceExitCodes) {
			return res, ErrExitCode
		}
		return res, nil
	}
}