package uint64set import ( "math/bits" "sort" "sync" "sync/atomic" "unsafe" ) // Set is a fast set for uint64. // // It should work faster than map[uint64]struct{} for semi-sparse uint64 values // such as MetricIDs generated by lib/storage. // // It is unsafe calling Set methods from concurrent goroutines. type Set struct { itemsCount int buckets bucket32Sorter // Most likely the buckets contains only a single item, so put it here for performance reasons // in order to improve memory locality. scratchBuckets [1]bucket32 } type bucket32Sorter []bucket32 func (s *bucket32Sorter) Len() int { return len(*s) } func (s *bucket32Sorter) Less(i, j int) bool { a := *s return a[i].hi < a[j].hi } func (s *bucket32Sorter) Swap(i, j int) { a := *s a[i], a[j] = a[j], a[i] } // Clone returns an independent copy of s. func (s *Set) Clone() *Set { if s == nil || s.itemsCount == 0 { // Return an empty set, so data could be added into it later. return &Set{} } var dst Set dst.itemsCount = s.itemsCount if len(s.buckets) == 1 { dst.buckets = dst.scratchBuckets[:] } else { dst.buckets = make([]bucket32, len(s.buckets)) } for i := range s.buckets { s.buckets[i].copyTo(&dst.buckets[i]) } return &dst } func (s *Set) fixItemsCount() { n := 0 for i := range s.buckets { n += s.buckets[i].getLen() } s.itemsCount = n } func (s *Set) cloneShallow() *Set { var dst Set dst.itemsCount = s.itemsCount if len(s.buckets) == 1 { dst.buckets = dst.scratchBuckets[:] } dst.buckets = append(dst.buckets[:0], s.buckets...) return &dst } // SizeBytes returns an estimate size of s in RAM. func (s *Set) SizeBytes() uint64 { if s == nil { return 0 } n := uint64(unsafe.Sizeof(*s)) for i := range s.buckets { n += s.buckets[i].sizeBytes() } return n } // Len returns the number of distinct uint64 values in s. func (s *Set) Len() int { if s == nil { return 0 } return s.itemsCount } // Add adds x to s. func (s *Set) Add(x uint64) { hi32 := uint32(x >> 32) lo32 := uint32(x) bs := s.buckets if len(bs) > 0 && bs[0].hi == hi32 { // Manually inline bucket32.add for performance reasons. hi16 := uint16(lo32 >> 16) lo16 := uint16(lo32) b32 := &bs[0] his := b32.b16his if n := b32.getHint(); n < uint32(len(his)) && his[n] == hi16 { bs := b32.buckets if n < uint32(len(bs)) && bs[n].add(lo16) { s.itemsCount++ } return } if b32.addSlow(hi16, lo16) { s.itemsCount++ } return } for i := range bs { b32 := &bs[i] if b32.hi == hi32 { if b32.add(lo32) { s.itemsCount++ } return } } b32 := s.addBucket32() b32.hi = hi32 _ = b32.add(lo32) s.itemsCount++ } // AddMulti adds all the items from a to s. // // It is usually faster than calling s.Add() for each item in a. // // The caller is responsible for splitting a into items with clustered values. func (s *Set) AddMulti(a []uint64) { if len(a) == 0 { return } hiPrev := uint32(a[0] >> 32) i := 0 for j, x := range a { hi := uint32(x >> 32) if hi == hiPrev { continue } b32 := s.getOrCreateBucket32(hiPrev) s.itemsCount += b32.addMulti(a[i:j]) hiPrev = hi i = j } b32 := s.getOrCreateBucket32(hiPrev) s.itemsCount += b32.addMulti(a[i:]) } func (s *Set) getOrCreateBucket32(hi uint32) *bucket32 { bs := s.buckets for i := range bs { if bs[i].hi == hi { return &bs[i] } } b32 := s.addBucket32() b32.hi = hi return b32 } func (s *Set) addBucket32() *bucket32 { if len(s.buckets) == 0 { s.buckets = s.scratchBuckets[:] } else { s.buckets = append(s.buckets, bucket32{}) } return &s.buckets[len(s.buckets)-1] } // Has verifies whether x exists in s. func (s *Set) Has(x uint64) bool { if s == nil { return false } hi32 := uint32(x >> 32) lo32 := uint32(x) bs := s.buckets for i := range bs { b32 := &bs[i] if b32.hi == hi32 { return b32.has(lo32) } } return false } // Del deletes x from s. func (s *Set) Del(x uint64) { hi := uint32(x >> 32) lo := uint32(x) bs := s.buckets if len(bs) > 0 && bs[0].hi == hi { if bs[0].del(lo) { s.itemsCount-- } return } for i := range bs { b32 := &bs[i] if b32.hi == hi { if b32.del(lo) { s.itemsCount-- } return } } } // AppendTo appends all the items from the set to dst and returns the result. // // The returned items are sorted. // // AppendTo can mutate s. func (s *Set) AppendTo(dst []uint64) []uint64 { if s == nil { return dst } // pre-allocate memory for dst dstLen := len(dst) sLen := s.Len() if n := dstLen + sLen - cap(dst); n > 0 { dst = append(dst[:cap(dst)], make([]uint64, n)...) dst = dst[:dstLen] } s.sort() for i := range s.buckets { dst = s.buckets[i].appendTo(dst) } return dst } func (s *Set) sort() { // sort s.buckets if it isn't sorted yet if !sort.IsSorted(&s.buckets) { sort.Sort(&s.buckets) } } // Union adds all the items from a to s. func (s *Set) Union(a *Set) { s.union(a, false) } // UnionMayOwn adds all the items from a to s. // // It may own a if s is empty. This means that `a` cannot be used // after the call to UnionMayOwn. func (s *Set) UnionMayOwn(a *Set) { s.union(a, true) } func (s *Set) union(a *Set, mayOwn bool) { if a.Len() == 0 { // Fast path - nothing to union. return } if s.Len() == 0 { // Fast path - copy `a` to `s`. if !mayOwn { a = a.Clone() } *s = *a return } // Make shallow copy of `a`, since it can be modified by a.sort(). if !mayOwn { a = a.cloneShallow() } a.sort() s.sort() i := 0 j := 0 sBucketsLen := len(s.buckets) for { for i < sBucketsLen && j < len(a.buckets) && s.buckets[i].hi < a.buckets[j].hi { i++ } if i >= sBucketsLen { for j < len(a.buckets) { b32 := s.addBucket32() a.buckets[j].copyTo(b32) j++ } break } for j < len(a.buckets) && a.buckets[j].hi < s.buckets[i].hi { b32 := s.addBucket32() a.buckets[j].copyTo(b32) j++ } if j >= len(a.buckets) { break } if s.buckets[i].hi == a.buckets[j].hi { s.buckets[i].union(&a.buckets[j], mayOwn) i++ j++ } } s.fixItemsCount() } // Intersect removes all the items missing in a from s. func (s *Set) Intersect(a *Set) { if s.Len() == 0 || a.Len() == 0 { // Fast path - the result is empty. *s = Set{} return } // Make shallow copy of `a`, since it can be modified by a.sort(). a = a.cloneShallow() a.sort() s.sort() i := 0 j := 0 for { for i < len(s.buckets) && j < len(a.buckets) && s.buckets[i].hi < a.buckets[j].hi { s.buckets[i] = bucket32{} i++ } if i >= len(s.buckets) { break } for j < len(a.buckets) && a.buckets[j].hi < s.buckets[i].hi { j++ } if j >= len(a.buckets) { for i < len(s.buckets) { s.buckets[i] = bucket32{} i++ } break } if s.buckets[i].hi == a.buckets[j].hi { s.buckets[i].intersect(&a.buckets[j]) i++ j++ } } s.fixItemsCount() } // Subtract removes from s all the shared items between s and a. func (s *Set) Subtract(a *Set) { if s.Len() == 0 || a.Len() == 0 { // Fast path - nothing to subtract. return } a.ForEach(func(part []uint64) bool { for _, x := range part { s.Del(x) } return true }) } // Equal returns true if s contains the same items as a. func (s *Set) Equal(a *Set) bool { if s.Len() != a.Len() { return false } equal := true a.ForEach(func(part []uint64) bool { for _, x := range part { if !s.Has(x) { equal = false return false } } return true }) return equal } // ForEach calls f for all the items stored in s. // // Each call to f contains part with arbitrary part of items stored in the set. // The iteration is stopped if f returns false. func (s *Set) ForEach(f func(part []uint64) bool) { if s == nil { return } for i := range s.buckets { if !s.buckets[i].forEach(f) { return } } } type bucket32 struct { hi uint32 // hint may contain bucket index for the last successful operation. // This allows saving CPU time on subsequent calls to the same bucket. hint uint32 // b16his contains high 16 bits for each bucket in buckets. // // It is always sorted. b16his []uint16 // buckets are sorted by b16his buckets []*bucket16 } func (b *bucket32) getLen() int { n := 0 for i := range b.buckets { n += b.buckets[i].getLen() } return n } func (b *bucket32) union(a *bucket32, mayOwn bool) { i := 0 j := 0 bBucketsLen := len(b.buckets) for { for i < bBucketsLen && j < len(a.b16his) && b.b16his[i] < a.b16his[j] { i++ } if i >= bBucketsLen { for j < len(a.b16his) { b16 := b.addBucket16(a.b16his[j]) if mayOwn { *b16 = *a.buckets[j] } else { a.buckets[j].copyTo(b16) } j++ } break } for j < len(a.b16his) && a.b16his[j] < b.b16his[i] { b16 := b.addBucket16(a.b16his[j]) if mayOwn { *b16 = *a.buckets[j] } else { a.buckets[j].copyTo(b16) } j++ } if j >= len(a.b16his) { break } if b.b16his[i] == a.b16his[j] { b.buckets[i].union(a.buckets[j]) i++ j++ } } // Restore buckets order, which could be violated during the merge above. if !sort.IsSorted(b) { sort.Sort(b) } } // This is for sort.Interface used in bucket32.union func (b *bucket32) Len() int { return len(b.b16his) } func (b *bucket32) Less(i, j int) bool { return b.b16his[i] < b.b16his[j] } func (b *bucket32) Swap(i, j int) { his := b.b16his buckets := b.buckets his[i], his[j] = his[j], his[i] buckets[i], buckets[j] = buckets[j], buckets[i] } func (b *bucket32) intersect(a *bucket32) { i := 0 j := 0 for { for i < len(b.b16his) && j < len(a.b16his) && b.b16his[i] < a.b16his[j] { *b.buckets[i] = bucket16{} i++ } if i >= len(b.b16his) { break } for j < len(a.b16his) && a.b16his[j] < b.b16his[i] { j++ } if j >= len(a.b16his) { for i < len(b.b16his) { *b.buckets[i] = bucket16{} i++ } break } if b.b16his[i] == a.b16his[j] { b.buckets[i].intersect(a.buckets[j]) i++ j++ } } // Remove zero buckets b16his := b.b16his[:0] bs := b.buckets[:0] for i, b16 := range b.buckets { if b16.isZero() { continue } b16his = append(b16his, b.b16his[i]) bs = append(bs, b16) } for i := len(bs); i < len(b.buckets); i++ { b.buckets[i] = nil } b.hint = 0 b.b16his = b16his b.buckets = bs } func (b *bucket32) forEach(f func(part []uint64) bool) bool { xbuf := partBufPool.Get().(*[]uint64) buf := *xbuf for i, b16 := range b.buckets { hi16 := b.b16his[i] buf = b16.appendTo(buf[:0], b.hi, hi16) if !f(buf) { return false } } *xbuf = buf partBufPool.Put(xbuf) return true } var partBufPool = &sync.Pool{ New: func() interface{} { buf := make([]uint64, 0, bitsPerBucket) return &buf }, } func (b *bucket32) sizeBytes() uint64 { n := uint64(unsafe.Sizeof(*b)) n += 2 * uint64(len(b.b16his)) for _, b16 := range b.buckets { n += b16.sizeBytes() } return n } func (b *bucket32) copyTo(dst *bucket32) { dst.hi = b.hi dst.b16his = append(dst.b16his[:0], b.b16his...) // Do not reuse dst.buckets, since it may be used in other places. dst.buckets = nil if len(b.buckets) > 0 { dst.buckets = make([]*bucket16, len(b.buckets)) for i, b16 := range b.buckets { b16Dst := &bucket16{} b16.copyTo(b16Dst) dst.buckets[i] = b16Dst } } } func (b *bucket32) getHint() uint32 { return atomic.LoadUint32(&b.hint) } func (b *bucket32) setHint(n int) { atomic.StoreUint32(&b.hint, uint32(n)) } func (b *bucket32) add(x uint32) bool { hi := uint16(x >> 16) lo := uint16(x) his := b.b16his if n := b.getHint(); n < uint32(len(his)) && his[n] == hi { // Fast path - add to the previously used bucket. bs := b.buckets return n < uint32(len(bs)) && bs[n].add(lo) } return b.addSlow(hi, lo) } func (b *bucket32) addMulti(a []uint64) int { if len(a) == 0 { return 0 } count := 0 hiPrev := uint16(a[0] >> 16) i := 0 for j, x := range a { hi := uint16(x >> 16) if hi == hiPrev { continue } b16 := b.getOrCreateBucket16(hiPrev) count += b16.addMulti(a[i:j]) hiPrev = hi i = j } b16 := b.getOrCreateBucket16(hiPrev) count += b16.addMulti(a[i:]) return count } func (b *bucket32) getOrCreateBucket16(hi uint16) *bucket16 { his := b.b16his bs := b.buckets n := binarySearch16(his, hi) if n < 0 || n >= len(his) || his[n] != hi { return b.addBucketAtPos(hi, n) } return bs[n] } func (b *bucket32) addSlow(hi, lo uint16) bool { his := b.b16his n := binarySearch16(his, hi) if n < 0 || n >= len(his) || his[n] != hi { b16 := b.addBucketAtPos(hi, n) b16.add(lo) return true } b.setHint(n) bs := b.buckets return n < len(bs) && bs[n].add(lo) } func (b *bucket32) addBucket16(hi uint16) *bucket16 { b.b16his = append(b.b16his, hi) b.buckets = append(b.buckets, &bucket16{}) return b.buckets[len(b.buckets)-1] } func (b *bucket32) addBucketAtPos(hi uint16, pos int) *bucket16 { if pos < 0 { // This is a hint to Go compiler to remove automatic bounds checks below. return nil } if pos >= len(b.b16his) { return b.addBucket16(hi) } b.b16his = append(b.b16his[:pos+1], b.b16his[pos:]...) b.b16his[pos] = hi b.buckets = append(b.buckets[:pos+1], b.buckets[pos:]...) b16 := &bucket16{} b.buckets[pos] = b16 return b16 } func (b *bucket32) has(x uint32) bool { hi := uint16(x >> 16) lo := uint16(x) his := b.b16his n := binarySearch16(his, hi) if n < 0 || n >= len(his) || his[n] != hi { return false } // Do not call b.setHint(n) here, since this may trash performance // when many concurrent goroutines call b.has() method from many CPU cores. bs := b.buckets return n < len(bs) && bs[n].has(lo) } func (b *bucket32) del(x uint32) bool { hi := uint16(x >> 16) lo := uint16(x) his := b.b16his if n := b.getHint(); n < uint32(len(his)) && his[n] == hi { // Fast path - use the bucket from the previous operation. bs := b.buckets return n < uint32(len(bs)) && bs[n].del(lo) } return b.delSlow(hi, lo) } func (b *bucket32) delSlow(hi, lo uint16) bool { his := b.b16his n := binarySearch16(his, hi) if n < 0 || n >= len(his) || his[n] != hi { return false } b.setHint(n) bs := b.buckets return n < len(bs) && bs[n].del(lo) } func (b *bucket32) appendTo(dst []uint64) []uint64 { for i := range b.buckets { hi16 := b.b16his[i] dst = b.buckets[i].appendTo(dst, b.hi, hi16) } return dst } const ( bitsPerBucket = 1 << 16 wordsPerBucket = bitsPerBucket / 64 ) type bucket16 struct { bits *[wordsPerBucket]uint64 smallPool [smallPoolSize]uint16 smallPoolLen int } const smallPoolSize = 56 func (b *bucket16) isZero() bool { return b.bits == nil && b.smallPoolLen == 0 } func (b *bucket16) getLen() int { if b.bits == nil { return b.smallPoolLen } n := 0 for _, x := range b.bits { if x > 0 { n += bits.OnesCount64(x) } } return n } func (b *bucket16) union(a *bucket16) { if a.bits != nil && b.bits != nil { // Fast path - use bitwise ops. ab := a.bits bb := b.bits _ = bb[len(ab)-1] for i, ax := range ab { bx := bb[i] bx |= ax bb[i] = bx } return } // Slow path xbuf := partBufPool.Get().(*[]uint64) buf := *xbuf buf = a.appendTo(buf[:0], 0, 0) for _, x := range buf { x16 := uint16(x) b.add(x16) } *xbuf = buf partBufPool.Put(xbuf) } func (b *bucket16) intersect(a *bucket16) { if a.bits != nil && b.bits != nil { // Fast path - use bitwise ops ab := a.bits bb := b.bits _ = bb[len(ab)-1] for i, ax := range ab { bx := bb[i] bx &= ax bb[i] = bx } return } // Slow path xbuf := partBufPool.Get().(*[]uint64) buf := *xbuf buf = b.appendTo(buf[:0], 0, 0) for _, x := range buf { x16 := uint16(x) if !a.has(x16) { b.del(x16) } } *xbuf = buf partBufPool.Put(xbuf) } func (b *bucket16) sizeBytes() uint64 { n := unsafe.Sizeof(*b) if b.bits != nil { n += unsafe.Sizeof(*b.bits) } return uint64(n) } func (b *bucket16) copyTo(dst *bucket16) { // Do not reuse dst.bits, since it may be used in other places. dst.bits = nil if b.bits != nil { bits := *b.bits dst.bits = &bits } dst.smallPool = b.smallPool dst.smallPoolLen = b.smallPoolLen } func (b *bucket16) add(x uint16) bool { bits := b.bits if bits == nil { return b.addToSmallPool(x) } wordNum, bitMask := getWordNumBitMask(x) ok := bits[wordNum]&bitMask == 0 if ok { bits[wordNum] |= bitMask } return ok } func (b *bucket16) addMulti(a []uint64) int { if b.bits == nil { if b.smallPoolLen+len(a) > len(b.smallPool) { b.switchSmallPoolToBits() goto fastPath } // Slow path count := 0 for _, x := range a { if b.addToSmallPool(uint16(x)) { count++ } } return count } fastPath: count := 0 bits := b.bits for _, x := range a { wordNum, bitMask := getWordNumBitMask(uint16(x)) if bits[wordNum]&bitMask == 0 { bits[wordNum] |= bitMask count++ } } return count } func (b *bucket16) addToSmallPool(x uint16) bool { if b.hasInSmallPool(x) { return false } sp := b.smallPool[:] if b.smallPoolLen < len(sp) { sp[b.smallPoolLen] = x b.smallPoolLen++ return true } b.switchSmallPoolToBits() b.add(x) return true } func (b *bucket16) switchSmallPoolToBits() { smallPoolLen := b.smallPoolLen b.smallPoolLen = 0 var bits [wordsPerBucket]uint64 b.bits = &bits for _, v := range b.smallPool[:smallPoolLen] { b.add(v) } } func (b *bucket16) has(x uint16) bool { if b.bits == nil { return b.hasInSmallPool(x) } wordNum, bitMask := getWordNumBitMask(x) return b.bits[wordNum]&bitMask != 0 } func (b *bucket16) hasInSmallPool(x uint16) bool { for _, v := range b.smallPool[:b.smallPoolLen] { if v == x { return true } } return false } func (b *bucket16) del(x uint16) bool { if b.bits == nil { return b.delFromSmallPool(x) } wordNum, bitMask := getWordNumBitMask(x) word := &b.bits[wordNum] ok := *word&bitMask != 0 *word &^= bitMask return ok } func (b *bucket16) delFromSmallPool(x uint16) bool { sp := b.smallPool[:] for i, v := range sp[:b.smallPoolLen] { if v == x { copy(sp[i:], sp[i+1:]) b.smallPoolLen-- return true } } return false } func (b *bucket16) appendTo(dst []uint64, hi uint32, hi16 uint16) []uint64 { hi64 := uint64(hi)<<32 | uint64(hi16)<<16 if b.bits == nil { // Use smallPoolSorter instead of sort.Slice here in order to reduce memory allocations. sps := smallPoolSorterPool.Get().(*smallPoolSorter) // Sort a copy of b.smallPool, since b must be readonly in order to prevent from data races // when b.appendTo is called from concurrent goroutines. sps.smallPool = b.smallPool sps.a = sps.smallPool[:b.smallPoolLen] if len(sps.a) > 1 && !sort.IsSorted(sps) { sort.Sort(sps) } for _, v := range sps.a { x := hi64 | uint64(v) dst = append(dst, x) } smallPoolSorterPool.Put(sps) return dst } var wordNum uint64 for _, word := range b.bits { if word == 0 { wordNum++ continue } x64 := hi64 | (wordNum * 64) for { tzn := uint64(bits.TrailingZeros64(word)) if tzn >= 64 { break } word &^= uint64(1) << tzn x := x64 | tzn dst = append(dst, x) } wordNum++ } return dst } var smallPoolSorterPool = &sync.Pool{ New: func() interface{} { return &smallPoolSorter{} }, } type smallPoolSorter struct { smallPool [smallPoolSize]uint16 a []uint16 } func (sps *smallPoolSorter) Len() int { return len(sps.a) } func (sps *smallPoolSorter) Less(i, j int) bool { a := sps.a return a[i] < a[j] } func (sps *smallPoolSorter) Swap(i, j int) { a := sps.a a[i], a[j] = a[j], a[i] } func getWordNumBitMask(x uint16) (uint16, uint64) { wordNum := x / 64 bitMask := uint64(1) << (x & 63) return wordNum, bitMask } func binarySearch16(u16 []uint16, x uint16) int { // The code has been adapted from sort.Search. n := len(u16) if n > 0 && u16[n-1] < x { // Fast path for values scanned in ascending order. return n } i, j := 0, n for i < j { h := int(uint(i+j) >> 1) if h >= 0 && h < len(u16) && u16[h] < x { i = h + 1 } else { j = h } } return i }