package syncext

import (
	"context"
	"gogs.mikescher.com/BlackForestBytes/goext/langext"
	"sync"
	"time"
)

// AtomicBool is a mutex-guarded boolean whose changes can be waited for.
// Construct it with NewAtomicBool, which initializes the listener map.
type AtomicBool struct {
	// v is the current value.
	v bool
	// listener holds the channels of active waiters, keyed by a per-waiter ID.
	// Set pushes each new value to every listener that is ready to receive and
	// removes the ones that are not.
	listener map[string]chan bool
	// lock guards v and listener.
	lock sync.Mutex
}

// NewAtomicBool returns an AtomicBool initialized to value, with an empty
// listener map ready for waiters. The mutex needs no explicit initialization:
// its zero value is an unlocked mutex.
func NewAtomicBool(value bool) *AtomicBool {
	b := new(AtomicBool)
	b.v = value
	b.listener = make(map[string]chan bool)
	return b
}

// Get returns the current value, read while holding the lock.
func (a *AtomicBool) Get() bool {
	a.lock.Lock()
	current := a.v
	a.lock.Unlock()
	return current
}

// Set stores value and returns the value it replaced.
//
// Every registered listener is offered the new value with a non-blocking send.
// A listener that is not ready to receive at that moment is removed from the
// listener map. Notification is best-effort and never blocks Set.
func (a *AtomicBool) Set(value bool) bool {
	a.lock.Lock()
	defer a.lock.Unlock()

	var previous bool
	previous, a.v = a.v, value

	for id, ch := range a.listener {
		select {
		case ch <- value:
		default:
			// Nobody is receiving on ch right now; drop it (deleting during
			// range is permitted in Go).
			delete(a.listener, id)
		}
	}

	return previous
}

// Wait blocks, with no time limit, until the value equals waitFor.
func (a *AtomicBool) Wait(waitFor bool) {
	ctx := context.Background()
	// WaitWithContext only fails with ctx.Err(), and a Background context is
	// never cancelled, so there is no error worth reporting here.
	_ = a.WaitWithContext(ctx, waitFor)
}

// WaitWithTimeout blocks until the value equals waitFor or timeout elapses.
// It returns nil on a match and context.DeadlineExceeded on timeout. A
// non-positive timeout yields an already-expired context.
func (a *AtomicBool) WaitWithTimeout(timeout time.Duration, waitFor bool) error {
	timeoutCtx, release := context.WithTimeout(context.Background(), timeout)
	defer release() // free the context's timer as soon as we return
	return a.WaitWithContext(timeoutCtx, waitFor)
}

// WaitWithContext blocks until the value equals waitFor or ctx is done.
//
// It returns nil once the value matches, or ctx.Err() when the context is
// cancelled or its deadline passes first. If ctx is already done on entry,
// ctx.Err() is returned without inspecting the value.
//
// The waiter registers a listener channel that Set notifies. Set drops any
// listener that is not ready to receive at the moment of a change. As a
// fallback, the waiter therefore re-checks the value, and re-registers itself,
// on a fixed polling interval.
func (a *AtomicBool) WaitWithContext(ctx context.Context, waitFor bool) error {
	if err := ctx.Err(); err != nil {
		return err
	}

	// A failed UUID generation only yields an empty key. The polling fallback
	// below still guarantees progress, so the error is deliberately ignored.
	uuid, _ := langext.NewHexUUID()

	waitchan := make(chan bool)

	// checkOrRegister reports whether the value already equals waitFor. If it
	// does not, it (re-)registers waitchan as a listener. Doing both under one
	// lock closes the window in which a Set landing between the check and the
	// registration would go unnoticed.
	checkOrRegister := func() bool {
		a.lock.Lock()
		defer a.lock.Unlock()
		if a.v == waitFor {
			return true
		}
		if a.listener == nil {
			// Tolerate a zero-value AtomicBool (not built via NewAtomicBool).
			a.listener = make(map[string]chan bool)
		}
		a.listener[uuid] = waitchan
		return false
	}

	if checkOrRegister() {
		return nil
	}
	defer func() {
		a.lock.Lock()
		delete(a.listener, uuid)
		a.lock.Unlock()
	}()

	ticker := time.NewTicker(1024 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			// Selecting on Done makes cancellation immediate, even for
			// contexts without a deadline.
			return ctx.Err()
		case v := <-waitchan:
			if v == waitFor {
				return nil
			}
		case <-ticker.C:
			// Set may have dropped our listener while we were not receiving,
			// so re-check the value and re-register if it still differs.
			if checkOrRegister() {
				return nil
			}
		}
	}
}