diff --git a/syncext/atomic.go b/syncext/atomic.go index 7e23522..a03d1a8 100644 --- a/syncext/atomic.go +++ b/syncext/atomic.go @@ -1,16 +1,21 @@ package dataext -import "sync/atomic" +import ( + "context" + "sync/atomic" + "time" +) type AtomicBool struct { - v int32 + v int32 + waiter chan bool // unbuffered } func NewAtomicBool(value bool) *AtomicBool { if value { - return &AtomicBool{v: 0} + return &AtomicBool{v: 0, waiter: make(chan bool)} } else { - return &AtomicBool{v: 1} + return &AtomicBool{v: 1, waiter: make(chan bool)} } } @@ -24,4 +29,70 @@ func (a *AtomicBool) Set(value bool) { } else { atomic.StoreInt32(&a.v, 0) } + + select { + case a.waiter <- value: + // message sent + default: + // no receiver on channel + } +} + +func (a *AtomicBool) Wait(waitFor bool) { + if a.Get() == waitFor { + return + } + + for { + if v, ok := ReadChannelWithTimeout(a.waiter, 128*time.Millisecond); ok { + if v == waitFor { + return + } + } else { + if a.Get() == waitFor { + return + } + } + } +} + +func (a *AtomicBool) WaitWithContext(ctx context.Context, waitFor bool) error { + if err := ctx.Err(); err != nil { + return err + } + + if a.Get() == waitFor { + return nil + } + + for { + if err := ctx.Err(); err != nil { + return err + } + + timeOut := 128 * time.Millisecond + + if dl, ok := ctx.Deadline(); ok { + timeOutMax := dl.Sub(time.Now()) + if timeOutMax <= 0 { + timeOut = 0 + } else if 0 < timeOutMax && timeOutMax < timeOut { + timeOut = timeOutMax + } + } + + if v, ok := ReadChannelWithTimeout(a.waiter, timeOut); ok { + if v == waitFor { + return nil + } + } else { + if err := ctx.Err(); err != nil { + return err + } + + if a.Get() == waitFor { + return nil + } + } + } } diff --git a/syncext/channel.go b/syncext/channel.go new file mode 100644 index 0000000..4c4c03b --- /dev/null +++ b/syncext/channel.go @@ -0,0 +1,14 @@ +package dataext + +import "time" + +func ReadChannelWithTimeout[T any](c chan T, timeout time.Duration) (T, bool) { + afterCh := time.After(timeout) + select { + case rv := <-c: + return rv, true + case <-afterCh: + return *new(T), false + } + +} diff --git a/syncext/channel_test.go b/syncext/channel_test.go new file mode 100644 index 0000000..fa3d843 --- /dev/null +++ b/syncext/channel_test.go @@ -0,0 +1,121 @@ +package dataext + +import ( + "testing" + "time" +) + +func TestTimeoutReadBuffered(t *testing.T) { + c := make(chan int, 1) + + go func() { + time.Sleep(200 * time.Millisecond) + c <- 112 + }() + + _, ok := ReadChannelWithTimeout(c, 100*time.Millisecond) + + if ok { + t.Error("Read success, but should timeout") + } +} + +func TestTimeoutReadBigBuffered(t *testing.T) { + c := make(chan int, 128) + + go func() { + time.Sleep(200 * time.Millisecond) + c <- 112 + }() + + _, ok := ReadChannelWithTimeout(c, 100*time.Millisecond) + + if ok { + t.Error("Read success, but should timeout") + } +} + +func TestTimeoutReadUnbuffered(t *testing.T) { + c := make(chan int) + + go func() { + time.Sleep(200 * time.Millisecond) + c <- 112 + }() + + _, ok := ReadChannelWithTimeout(c, 100*time.Millisecond) + + if ok { + t.Error("Read success, but should timeout") + } +} + +func TestNoTimeoutAfterStartReadBuffered(t *testing.T) { + c := make(chan int, 1) + + go func() { + time.Sleep(10 * time.Millisecond) + c <- 112 + }() + + _, ok := ReadChannelWithTimeout(c, 100*time.Millisecond) + + if !ok { + t.Error("Read timeout, but should have succeeded") + } +} + +func TestNoTimeoutAfterStartReadBigBuffered(t *testing.T) { + c := make(chan int, 128) + + go func() { + time.Sleep(10 * time.Millisecond) + c <- 112 + }() + + _, ok := ReadChannelWithTimeout(c, 100*time.Millisecond) + + if !ok { + t.Error("Read timeout, but should have succeeded") + } +} + +func TestNoTimeoutAfterStartReadUnbuffered(t *testing.T) { + c := make(chan int) + + go func() { + time.Sleep(10 * time.Millisecond) + c <- 112 + }() + + _, ok := ReadChannelWithTimeout(c, 100*time.Millisecond) + + if !ok { + t.Error("Read timeout, but should have succeeded") + } + +} + +func TestNoTimeoutBeforeStartReadBuffered(t *testing.T) { + c := make(chan int, 1) + + c <- 112 + + _, ok := ReadChannelWithTimeout(c, 10*time.Millisecond) + + if !ok { + t.Error("Read timeout, but should have succeeded") + } +} + +func TestNoTimeoutBeforeStartReadBigBuffered(t *testing.T) { + c := make(chan int, 128) + + c <- 112 + + _, ok := ReadChannelWithTimeout(c, 10*time.Millisecond) + + if !ok { + t.Error("Read timeout, but should have succeeded") + } +}