VictoriaMetrics/lib/backup/fslocal/bandwidth_limiter.go

134 lines
2.4 KiB
Go

package fslocal
import (
"io"
"sync"
"time"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/logger"
)
// bandwidthLimiter hands out a byte quota that is refilled to perSecondLimit
// once per second by a background goroutine (perSecondUpdater).
type bandwidthLimiter struct {
	// perSecondLimit is the value quota is reset to on every refill tick.
	perSecondLimit int

	// c coordinates consumers blocked in GetQuota with the refill goroutine;
	// quota is read and written only while holding c.L.
	c *sync.Cond
	// quota is the number of bytes still available in the current second.
	quota int

	// stopCh is closed by MustStop to make perSecondUpdater return;
	// wg lets MustStop wait until that goroutine has exited.
	stopCh chan struct{}
	wg     sync.WaitGroup
}
// newBandwidthLimiter creates a limiter allowing up to perSecondLimit bytes per
// second and starts its background refill goroutine.
//
// The initial quota is zero, so callers of GetQuota block until the first
// refill tick. Call MustStop to shut the goroutine down. It panics (via
// logger.Panicf) when perSecondLimit is not positive.
func newBandwidthLimiter(perSecondLimit int) *bandwidthLimiter {
	if perSecondLimit <= 0 {
		logger.Panicf("BUG: perSecondLimit must be positive; got %d", perSecondLimit)
	}
	bl := &bandwidthLimiter{
		perSecondLimit: perSecondLimit,
		c:              sync.NewCond(&sync.Mutex{}),
		stopCh:         make(chan struct{}),
	}

	bl.wg.Add(1)
	go func() {
		defer bl.wg.Done()
		bl.perSecondUpdater()
	}()
	return bl
}
// MustStop terminates the refill goroutine started by newBandwidthLimiter and
// blocks until it has exited.
//
// Call it at most once: a second call would close an already-closed channel
// and panic.
func (bl *bandwidthLimiter) MustStop() {
	stopCh := bl.stopCh
	// perSecondUpdater returns as soon as it observes the closed channel.
	close(stopCh)
	bl.wg.Wait()
}
// NewReadCloser wraps rc so that its reads are throttled by bl.
// Closing the returned reader closes rc.
func (bl *bandwidthLimiter) NewReadCloser(rc io.ReadCloser) *bandwidthLimitedReader {
	limited := &bandwidthLimitedReader{}
	limited.rc = rc
	limited.bl = bl
	return limited
}
// NewWriteCloser wraps wc so that its writes are throttled by bl.
// Closing the returned writer closes wc.
func (bl *bandwidthLimiter) NewWriteCloser(wc io.WriteCloser) *bandwidthLimitedWriter {
	limited := &bandwidthLimitedWriter{}
	limited.wc = wc
	limited.bl = bl
	return limited
}
// bandwidthLimitedReader is an io.ReadCloser whose reads from rc are
// throttled by the quota handed out by bl.
type bandwidthLimitedReader struct {
	rc io.ReadCloser
	bl *bandwidthLimiter
}

// Read reads at most the currently granted quota from the underlying reader
// into p, blocking until bl grants at least one byte.
//
// A zero-length p is forwarded to the underlying reader without requesting
// quota: GetQuota panics on non-positive sizes, whereas io.Reader permits
// len(p) == 0 calls.
func (blr *bandwidthLimitedReader) Read(p []byte) (int, error) {
	if len(p) == 0 {
		return blr.rc.Read(p)
	}
	quota := blr.bl.GetQuota(len(p))
	// NOTE(review): quota left unused by a short read is not returned to bl.
	return blr.rc.Read(p[:quota])
}

// Close closes the underlying reader.
func (blr *bandwidthLimitedReader) Close() error {
	return blr.rc.Close()
}
// bandwidthLimitedWriter is an io.WriteCloser whose writes to wc are
// throttled by the quota handed out by bl.
type bandwidthLimitedWriter struct {
	wc io.WriteCloser
	bl *bandwidthLimiter
}

// Write writes all of p to the underlying writer. It works in chunks no larger
// than the quota granted by bl and blocks whenever the quota is exhausted.
//
// It returns the total number of bytes written and the first error
// encountered. If the underlying writer accepts fewer bytes than requested
// without reporting an error, Write returns io.ErrShortWrite rather than
// silently skipping the unwritten part of the chunk.
func (blw *bandwidthLimitedWriter) Write(p []byte) (int, error) {
	nn := 0
	for len(p) > 0 {
		quota := blw.bl.GetQuota(len(p))
		n, err := blw.wc.Write(p[:quota])
		nn += n
		if err != nil {
			return nn, err
		}
		if n < quota {
			// io.Writer requires a non-nil error for short writes. Without this
			// check the tail of the chunk would be dropped and nil returned.
			return nn, io.ErrShortWrite
		}
		p = p[quota:]
	}
	return nn, nil
}

// Close closes the underlying writer.
func (blw *bandwidthLimitedWriter) Close() error {
	return blw.wc.Close()
}
// perSecondUpdater resets bl.quota to bl.perSecondLimit once per second and
// wakes one goroutine waiting in GetQuota. It returns when bl.stopCh is closed.
func (bl *bandwidthLimiter) perSecondUpdater() {
	tc := time.NewTicker(time.Second)
	// Release the ticker when the updater exits; it was previously never stopped.
	defer tc.Stop()
	c := bl.c
	for {
		select {
		case <-tc.C:
		case <-bl.stopCh:
			return
		}
		c.L.Lock()
		// Quota left over from the previous second is overwritten, not accumulated.
		bl.quota = bl.perSecondLimit
		// Wake a single waiter; GetQuota passes the wakeup on while quota remains.
		c.Signal()
		c.L.Unlock()
	}
}
// GetQuota reserves and returns between 1 and n bytes of the current quota,
// i.e. min(n, available quota).
//
// It blocks until at least one byte of quota is available. If quota remains
// after the reservation, another waiter is woken so it can take its share.
// It panics (via logger.Panicf) when n is not positive.
func (bl *bandwidthLimiter) GetQuota(n int) int {
	if n <= 0 {
		logger.Panicf("BUG: n must be positive; got %d", n)
	}
	mu := bl.c.L
	mu.Lock()
	defer mu.Unlock()

	// Sleep until perSecondUpdater or a previous consumer signals available quota.
	for bl.quota <= 0 {
		bl.c.Wait()
	}
	granted := n
	if bl.quota < n {
		granted = bl.quota
	}
	bl.quota -= granted
	if bl.quota > 0 {
		// Pass the wakeup along so the remaining quota is not left idle.
		bl.c.Signal()
	}
	return granted
}