package pacelimiter

import (
	"fmt"
	"runtime"
	"sync"
	"testing"
	"time"
)

// TestPacelimiter exercises the limiter returned by New through four
// scenarios: callers that must never be delayed, callers that must stay
// blocked between Inc and Dec, a counter driven negative by extra Dec calls,
// and concurrent balanced Inc/Dec pairs.
func TestPacelimiter(t *testing.T) {
	// waitForWorkers receives n completion signals from ch and fails the test
	// if they do not all arrive within a single shared 5 second deadline.
	waitForWorkers := func(t *testing.T, ch <-chan struct{}, n int) {
		t.Helper()
		timeoutCh := time.After(5 * time.Second)
		for i := 0; i < n; i++ {
			select {
			case <-ch:
			case <-timeoutCh:
				t.Fatalf("timeout")
			}
		}
	}

	t.Run("nonblocking", func(t *testing.T) {
		pl := New()
		ch := make(chan struct{}, 10)
		for i := 0; i < cap(ch); i++ {
			go func() {
				for j := 0; j < 10; j++ {
					pl.WaitIfNeeded()
					runtime.Gosched()
				}
				ch <- struct{}{}
			}()
		}

		// Check that all the goroutines are finished.
		waitForWorkers(t, ch, cap(ch))
		if n := pl.DelaysTotal(); n > 0 {
			t.Fatalf("unexpected non-zero number of delays: %d", n)
		}
	})

	t.Run("blocking", func(t *testing.T) {
		pl := New()
		pl.Inc()
		ch := make(chan struct{}, 10)
		var wg sync.WaitGroup
		for i := 0; i < cap(ch); i++ {
			wg.Add(1)
			go func() {
				// Signal only that the goroutine has started; it is expected
				// to stay inside WaitIfNeeded until pl.Dec() is called below.
				wg.Done()
				for j := 0; j < 10; j++ {
					pl.WaitIfNeeded()
				}
				ch <- struct{}{}
			}()
		}

		// Check that all the goroutines created above are started
		// and that none of them has been able to finish.
		wg.Wait()
		select {
		case <-ch:
			t.Fatalf("the pl must be blocked")
		default:
		}

		// Unblock goroutines and check that they are unblocked.
		pl.Dec()
		waitForWorkers(t, ch, cap(ch))
		if n := pl.DelaysTotal(); n == 0 {
			t.Fatalf("expecting non-zero number of delays")
		}

		// Verify that the pl is unblocked now.
		pl.WaitIfNeeded()

		// Verify that negative count doesn't block pl.
		pl.Dec()
		pl.WaitIfNeeded()
		if n := pl.DelaysTotal(); n == 0 {
			t.Fatalf("expecting non-zero number of delays after subsequent pl.Dec()")
		}
	})

	t.Run("negative_count", func(t *testing.T) {
		n := 10
		pl := New()
		for i := 0; i < n; i++ {
			pl.Dec()
		}

		// The buffer of one lets the worker deliver its single result and exit
		// even if the test has already given up on the timeout branch below,
		// so the goroutine is never left blocked on the send.
		doneCh := make(chan error, 1)
		go func() {
			defer close(doneCh)
			for i := 0; i < n; i++ {
				pl.Inc()
				pl.WaitIfNeeded()
				if delays := pl.DelaysTotal(); delays != 0 {
					doneCh <- fmt.Errorf("expecting zero number of delays")
					return
				}
			}
			doneCh <- nil
		}()

		select {
		case err := <-doneCh:
			if err != nil {
				t.Fatalf("unexpected error: %s", err)
			}
		case <-time.After(5 * time.Second):
			t.Fatalf("timeout")
		}
	})

	t.Run("concurrent_inc_dec", func(t *testing.T) {
		pl := New()
		ch := make(chan struct{}, 10)
		for i := 0; i < cap(ch); i++ {
			go func() {
				for j := 0; j < 10; j++ {
					pl.Inc()
					runtime.Gosched()
					pl.Dec()
				}
				ch <- struct{}{}
			}()
		}

		// Verify that all the goroutines are finished.
		waitForWorkers(t, ch, cap(ch))

		// Verify that the pl is unblocked.
		pl.WaitIfNeeded()
		if n := pl.DelaysTotal(); n > 0 {
			t.Fatalf("expecting zero number of delays; got %d", n)
		}
	})
}