diff --git a/lib/uint64set/uint64set.go b/lib/uint64set/uint64set.go index 0c63332c9..0686473e5 100644 --- a/lib/uint64set/uint64set.go +++ b/lib/uint64set/uint64set.go @@ -4,6 +4,7 @@ import ( "math/bits" "sort" "sync" + "sync/atomic" "unsafe" ) @@ -16,6 +17,10 @@ import ( type Set struct { itemsCount int buckets bucket32Sorter + + // Most likely the buckets contains only a single item, so put it here for performance reasons + // in order to improve memory locality. + scratchBuckets [1]bucket32 } type bucket32Sorter []bucket32 @@ -38,7 +43,11 @@ func (s *Set) Clone() *Set { } var dst Set dst.itemsCount = s.itemsCount - dst.buckets = make([]bucket32, len(s.buckets)) + if len(s.buckets) == 1 { + dst.buckets = dst.scratchBuckets[:] + } else { + dst.buckets = make([]bucket32, len(s.buckets)) + } for i := range s.buckets { s.buckets[i].copyTo(&dst.buckets[i]) } @@ -56,6 +65,9 @@ func (s *Set) fixItemsCount() { func (s *Set) cloneShallow() *Set { var dst Set dst.itemsCount = s.itemsCount + if len(s.buckets) == 1 { + dst.buckets = dst.scratchBuckets[:] + } dst.buckets = append(dst.buckets[:0], s.buckets...) return &dst } @@ -84,18 +96,37 @@ func (s *Set) Len() int { // Add adds x to s. func (s *Set) Add(x uint64) { - hi := uint32(x >> 32) - lo := uint32(x) - for i := range s.buckets { - b32 := &s.buckets[i] - if b32.hi == hi { - if b32.add(lo) { + hi32 := uint32(x >> 32) + lo32 := uint32(x) + bs := s.buckets + if len(bs) > 0 && bs[0].hi == hi32 { + // Manually inline bucket32.add for performance reasons. + hi16 := uint16(lo32 >> 16) + lo16 := uint16(lo32) + b32 := &bs[0] + his := b32.b16his + if n := b32.getHint(); n < uint32(len(his)) && his[n] == hi16 { + bs := b32.buckets + if n < uint32(len(bs)) && bs[n].add(lo16) { + s.itemsCount++ + } + return + } + if b32.addSlow(hi16, lo16) { + s.itemsCount++ + } + return + } + for i := range bs { + b32 := &bs[i] + if b32.hi == hi32 { + if b32.add(lo32) { s.itemsCount++ } return } } - s.addAlloc(hi, lo) + s.addAlloc(hi32, lo32) } func (s *Set) addAlloc(hi, lo uint32) { @@ -106,7 +137,11 @@ func (s *Set) addAlloc(hi, lo uint32) { } func (s *Set) addBucket32() *bucket32 { - s.buckets = append(s.buckets, bucket32{}) + if len(s.buckets) == 0 { + s.buckets = s.scratchBuckets[:] + } else { + s.buckets = append(s.buckets, bucket32{}) + } return &s.buckets[len(s.buckets)-1] } @@ -115,12 +150,26 @@ func (s *Set) Has(x uint64) bool { if s == nil { return false } - hi := uint32(x >> 32) - lo := uint32(x) - for i := range s.buckets { - b32 := &s.buckets[i] - if b32.hi == hi { - return b32.has(lo) + hi32 := uint32(x >> 32) + lo32 := uint32(x) + bs := s.buckets + if len(bs) > 0 && bs[0].hi == hi32 { + // Manually inline bucket32.has for performance reasons. + hi16 := uint16(lo32 >> 16) + lo16 := uint16(lo32) + b32 := &bs[0] + his := b32.b16his + if n := b32.getHint(); n < uint32(len(his)) && his[n] == hi16 { + // Fast path - check the previously used bucket. + bs := b32.buckets + return n < uint32(len(bs)) && bs[n].has(lo16) + } + return b32.hasSlow(hi16, lo16) + } + for i := range bs { + b32 := &bs[i] + if b32.hi == hi32 { + return b32.has(lo32) } } return false @@ -130,8 +179,15 @@ func (s *Set) Has(x uint64) bool { func (s *Set) Del(x uint64) { hi := uint32(x >> 32) lo := uint32(x) - for i := range s.buckets { - b32 := &s.buckets[i] + bs := s.buckets + if len(bs) > 0 && bs[0].hi == hi { + if bs[0].del(lo) { + s.itemsCount-- + } + return + } + for i := range bs { + b32 := &bs[i] if b32.hi == hi { if b32.del(lo) { s.itemsCount-- @@ -205,12 +261,11 @@ func (s *Set) union(a *Set, mayOwn bool) { s.sort() i := 0 j := 0 - sbuckets := s.buckets for { - for i < len(sbuckets) && j < len(a.buckets) && sbuckets[i].hi < a.buckets[j].hi { + for i < len(s.buckets) && j < len(a.buckets) && s.buckets[i].hi < a.buckets[j].hi { i++ } - if i >= len(sbuckets) { + if i >= len(s.buckets) { for j < len(a.buckets) { b32 := s.addBucket32() a.buckets[j].copyTo(b32) @@ -218,7 +273,7 @@ func (s *Set) union(a *Set, mayOwn bool) { } break } - for j < len(a.buckets) && a.buckets[j].hi < sbuckets[i].hi { + for j < len(a.buckets) && a.buckets[j].hi < s.buckets[i].hi { b32 := s.addBucket32() a.buckets[j].copyTo(b32) j++ @@ -226,8 +281,8 @@ func (s *Set) union(a *Set, mayOwn bool) { if j >= len(a.buckets) { break } - if sbuckets[i].hi == a.buckets[j].hi { - sbuckets[i].union(&a.buckets[j], mayOwn) + if s.buckets[i].hi == a.buckets[j].hi { + s.buckets[i].union(&a.buckets[j], mayOwn) i++ j++ } @@ -323,22 +378,19 @@ func (s *Set) ForEach(f func(part []uint64) bool) { } type bucket32 struct { - hi uint32 - b16his []uint16 - buckets []bucket16 + hi uint32 - // hint may contain bucket index for the last successful add or del operation. + // hint may contain bucket index for the last successful operation. // This allows saving CPU time on subsequent calls to the same bucket. - hint int -} + hint uint32 -func (b *bucket32) cloneShallow() *bucket32 { - var dst bucket32 - dst.hi = b.hi - dst.b16his = append(dst.b16his[:0], b.b16his...) - dst.buckets = append(dst.buckets[:0], b.buckets...) - dst.hint = b.hint - return &dst + // b16his contains high 16 bits for each bucket in buckets. + // + // It is always sorted. + b16his []uint16 + + // buckets are sorted by b16his + buckets []bucket16 } func (b *bucket32) getLen() int { @@ -350,49 +402,59 @@ func (b *bucket32) getLen() int { } func (b *bucket32) union(a *bucket32, mayOwn bool) { - if !mayOwn { - a = a.cloneShallow() // clone a, since it is sorted below. - } - a.sort() - b.sort() i := 0 j := 0 - bb16his := b.b16his for { - for i < len(bb16his) && j < len(a.b16his) && bb16his[i] < a.b16his[j] { + for i < len(b.b16his) && j < len(a.b16his) && b.b16his[i] < a.b16his[j] { i++ } - if i >= len(bb16his) { + if i >= len(b.b16his) { for j < len(a.b16his) { - b.b16his = append(b.b16his, a.b16his[j]) - b16 := b.addBucket16() - a.buckets[j].copyTo(b16) + b16 := b.addBucket16(a.b16his[j]) + if mayOwn { + *b16 = a.buckets[j] + } else { + a.buckets[j].copyTo(b16) + } j++ } break } - for j < len(a.b16his) && a.b16his[j] < bb16his[i] { - b.b16his = append(b.b16his, a.b16his[j]) - b16 := b.addBucket16() - a.buckets[j].copyTo(b16) + for j < len(a.b16his) && a.b16his[j] < b.b16his[i] { + b16 := b.addBucket16(a.b16his[j]) + if mayOwn { + *b16 = a.buckets[j] + } else { + a.buckets[j].copyTo(b16) + } j++ } if j >= len(a.b16his) { break } - if bb16his[i] == a.b16his[j] { + if b.b16his[i] == a.b16his[j] { b.buckets[i].union(&a.buckets[j]) i++ j++ } } - b.sort() + // Restore buckets order, which could be violated during the merge above. + if !sort.IsSorted(b) { + sort.Sort(b) + } +} + +// This is for sort.Interface used in bucket32.union +func (b *bucket32) Len() int { return len(b.b16his) } +func (b *bucket32) Less(i, j int) bool { return b.b16his[i] < b.b16his[j] } +func (b *bucket32) Swap(i, j int) { + his := b.b16his + buckets := b.buckets + his[i], his[j] = his[j], his[i] + buckets[i], buckets[j] = buckets[j], buckets[i] } func (b *bucket32) intersect(a *bucket32) { - a = a.cloneShallow() // clone a, since it is sorted below. - a.sort() - b.sort() i := 0 j := 0 for { @@ -465,73 +527,53 @@ func (b *bucket32) copyTo(dst *bucket32) { b.buckets[i].copyTo(&dst.buckets[i]) } } - dst.hint = b.hint } -// This is for sort.Interface -func (b *bucket32) Len() int { return len(b.b16his) } -func (b *bucket32) Less(i, j int) bool { return b.b16his[i] < b.b16his[j] } -func (b *bucket32) Swap(i, j int) { - his := b.b16his - buckets := b.buckets - his[i], his[j] = his[j], his[i] - buckets[i], buckets[j] = buckets[j], buckets[i] +func (b *bucket32) getHint() uint32 { + return atomic.LoadUint32(&b.hint) } -const maxUnsortedBuckets = 32 +func (b *bucket32) setHint(n int) { + atomic.StoreUint32(&b.hint, uint32(n)) +} func (b *bucket32) add(x uint32) bool { hi := uint16(x >> 16) lo := uint16(x) - if n := b.hint; n < len(b.b16his) && b.b16his[n] == hi { + his := b.b16his + if n := b.getHint(); n < uint32(len(his)) && his[n] == hi { // Fast path - add to the previously used bucket. - return n < len(b.buckets) && b.buckets[n].add(lo) + bs := b.buckets + return n < uint32(len(bs)) && bs[n].add(lo) } return b.addSlow(hi, lo) } func (b *bucket32) addSlow(hi, lo uint16) bool { - if len(b.buckets) > maxUnsortedBuckets { - n := binarySearch16(b.b16his, hi) - b.hint = n - if n < 0 || n >= len(b.b16his) || b.b16his[n] != hi { - b.addAllocBig(hi, lo, n) - return true - } - return n < len(b.buckets) && b.buckets[n].add(lo) + his := b.b16his + n := binarySearch16(his, hi) + if n < 0 || n >= len(his) || his[n] != hi { + b.addAlloc(hi, lo, n) + return true } - for i, hi16 := range b.b16his { - if hi16 == hi { - b.hint = i - return i < len(b.buckets) && b.buckets[i].add(lo) - } - } - b.addAllocSmall(hi, lo) - return true + b.setHint(n) + bs := b.buckets + return n < len(bs) && bs[n].add(lo) } -func (b *bucket32) addAllocSmall(hi, lo uint16) { +func (b *bucket32) addBucket16(hi uint16) *bucket16 { b.b16his = append(b.b16his, hi) - b16 := b.addBucket16() - _ = b16.add(lo) - if len(b.buckets) > maxUnsortedBuckets { - sort.Sort(b) - } -} - -func (b *bucket32) addBucket16() *bucket16 { b.buckets = append(b.buckets, bucket16{}) return &b.buckets[len(b.buckets)-1] } -func (b *bucket32) addAllocBig(hi, lo uint16, n int) { +func (b *bucket32) addAlloc(hi, lo uint16, n int) { if n < 0 { // This is a hint to Go compiler to remove automatic bounds checks below. return } if n >= len(b.b16his) { - b.b16his = append(b.b16his, hi) - b16 := b.addBucket16() + b16 := b.addBucket16(hi) _ = b16.add(lo) return } @@ -546,57 +588,50 @@ func (b *bucket32) addAllocBig(hi, lo uint16, n int) { func (b *bucket32) has(x uint32) bool { hi := uint16(x >> 16) lo := uint16(x) - if len(b.buckets) > maxUnsortedBuckets { - return b.hasSlow(hi, lo) + his := b.b16his + if n := b.getHint(); n < uint32(len(his)) && his[n] == hi { + // Fast path - check the previously used bucket. + bs := b.buckets + return n < uint32(len(bs)) && bs[n].has(lo) } - for i, hi16 := range b.b16his { - if hi16 == hi { - return i < len(b.buckets) && b.buckets[i].has(lo) - } - } - return false + return b.hasSlow(hi, lo) } func (b *bucket32) hasSlow(hi, lo uint16) bool { - n := binarySearch16(b.b16his, hi) - if n < 0 || n >= len(b.b16his) || b.b16his[n] != hi { + his := b.b16his + n := binarySearch16(his, hi) + if n < 0 || n >= len(his) || his[n] != hi { return false } - return n < len(b.buckets) && b.buckets[n].has(lo) + b.setHint(n) + bs := b.buckets + return n < len(bs) && bs[n].has(lo) } func (b *bucket32) del(x uint32) bool { hi := uint16(x >> 16) lo := uint16(x) - if n := b.hint; n < len(b.b16his) && b.b16his[n] == hi { + his := b.b16his + if n := b.getHint(); n < uint32(len(his)) && his[n] == hi { // Fast path - use the bucket from the previous operation. - return n < len(b.buckets) && b.buckets[n].del(lo) + bs := b.buckets + return n < uint32(len(bs)) && bs[n].del(lo) } return b.delSlow(hi, lo) } func (b *bucket32) delSlow(hi, lo uint16) bool { - if len(b.buckets) > maxUnsortedBuckets { - n := binarySearch16(b.b16his, hi) - b.hint = n - if n < 0 || n >= len(b.b16his) || b.b16his[n] != hi { - return false - } - return n < len(b.buckets) && b.buckets[n].del(lo) + his := b.b16his + n := binarySearch16(his, hi) + if n < 0 || n >= len(his) || his[n] != hi { + return false } - for i, hi16 := range b.b16his { - if hi16 == hi { - b.hint = i - return i < len(b.buckets) && b.buckets[i].del(lo) - } - } - return false + b.setHint(n) + bs := b.buckets + return n < len(bs) && bs[n].del(lo) } func (b *bucket32) appendTo(dst []uint64) []uint64 { - if len(b.buckets) <= maxUnsortedBuckets { - b.sort() - } for i := range b.buckets { hi16 := b.b16his[i] dst = b.buckets[i].appendTo(dst, b.hi, hi16) @@ -604,12 +639,6 @@ func (b *bucket32) appendTo(dst []uint64) []uint64 { return dst } -func (b *bucket32) sort() { - if !sort.IsSorted(b) { - sort.Sort(b) - } -} - const ( bitsPerBucket = 1 << 16 wordsPerBucket = bitsPerBucket / 64 @@ -637,10 +666,13 @@ func (b *bucket16) getLen() int { func (b *bucket16) union(a *bucket16) { if a.bits != nil && b.bits != nil { // Fast path - use bitwise ops. - for i, ax := range a.bits { - bx := b.bits[i] + ab := a.bits + bb := b.bits + _ = bb[len(ab)-1] + for i, ax := range ab { + bx := bb[i] bx |= ax - b.bits[i] = bx + bb[i] = bx } return } @@ -660,10 +692,13 @@ func (b *bucket16) union(a *bucket16) { func (b *bucket16) intersect(a *bucket16) { if a.bits != nil && b.bits != nil { // Fast path - use bitwise ops - for i, ax := range a.bits { - bx := b.bits[i] + ab := a.bits + bb := b.bits + _ = bb[len(ab)-1] + for i, ax := range ab { + bx := bb[i] bx &= ax - b.bits[i] = bx + bb[i] = bx } return } diff --git a/lib/uint64set/uint64set_test.go b/lib/uint64set/uint64set_test.go index ae8f8c6f9..46021d514 100644 --- a/lib/uint64set/uint64set_test.go +++ b/lib/uint64set/uint64set_test.go @@ -164,10 +164,29 @@ func TestSetOps(t *testing.T) { a = append(a, uint64(i)) } var b []uint64 - for i := 0; i < 100; i++ { - b = append(b, uint64(i+1<<16)) + for i := 1 << 16; i < 1<<16+1000; i++ { + b = append(b, uint64(i)) } f(a, b) + + for i := 1<<16 - 100; i < 1<<16+100; i++ { + a = append(a, uint64(i)) + } + for i := uint64(1) << 32; i < 1<<32+1<<16+200; i++ { + b = append(b, i) + } + f(a, b) + + rng := rand.New(rand.NewSource(0)) + for i := 0; i < 10; i++ { + a = nil + b = nil + for j := 0; j < 1000; j++ { + a = append(a, uint64(rng.Intn(1e6))) + b = append(b, uint64(rng.Intn(1e6))) + } + f(a, b) + } } func expectEqual(s *Set, m map[uint64]bool) error { @@ -202,7 +221,7 @@ func expectEqual(s *Set, m map[uint64]bool) error { } func TestSetBasicOps(t *testing.T) { - for _, itemsCount := range []int{1, 2, 3, 4, 5, 6, 1e2, 1e3, 1e4, 1e5, 1e6, maxUnsortedBuckets * bitsPerBucket * 2} { + for _, itemsCount := range []int{1, 2, 3, 4, 5, 6, 1e2, 1e3, 1e4, 1e5, 1e6} { t.Run(fmt.Sprintf("items_%d", itemsCount), func(t *testing.T) { testSetBasicOps(t, itemsCount) })