diff --git a/lib/uint64set/uint64set.go b/lib/uint64set/uint64set.go index ad826a244..f9d6209fb 100644 --- a/lib/uint64set/uint64set.go +++ b/lib/uint64set/uint64set.go @@ -45,6 +45,14 @@ func (s *Set) Clone() *Set { return &dst } +func (s *Set) fixItemsCount() { + n := 0 + for i := range s.buckets { + n += s.buckets[i].getLen() + } + s.itemsCount = n +} + func (s *Set) cloneShallow() *Set { var dst Set dst.itemsCount = s.itemsCount @@ -177,27 +185,54 @@ func (s *Set) UnionMayOwn(a *Set) { } func (s *Set) union(a *Set, mayOwn bool) { - if mayOwn && s.Len() < a.Len() { - // Swap `a` with `s` in order to reduce the number of iterations in ForEach loop below. - // This operation is safe only if `a` is no longer used after the call to union. - *a, *s = *s, *a - } if a.Len() == 0 { // Fast path - nothing to union. return } if s.Len() == 0 { - // Fast path - just copy a. - aCopy := a.Clone() - *s = *aCopy + // Fast path - copy `a` to `s`. + if !mayOwn { + a = a.Clone() + } + *s = *a return } - a.ForEach(func(part []uint64) bool { - for _, x := range part { - s.Add(x) + // Make shallow copy of `a`, since it can be modified by a.sort(). + if !mayOwn { + a = a.cloneShallow() + } + a.sort() + s.sort() + i := 0 + j := 0 + sbuckets := s.buckets + for { + for i < len(sbuckets) && j < len(a.buckets) && sbuckets[i].hi < a.buckets[j].hi { + i++ } - return true - }) + if i >= len(sbuckets) { + for j < len(a.buckets) { + b32 := s.addBucket32() + a.buckets[j].copyTo(b32) + j++ + } + break + } + for j < len(a.buckets) && a.buckets[j].hi < sbuckets[i].hi { + b32 := s.addBucket32() + a.buckets[j].copyTo(b32) + j++ + } + if j >= len(a.buckets) { + break + } + if sbuckets[i].hi == a.buckets[j].hi { + sbuckets[i].union(&a.buckets[j], mayOwn) + i++ + j++ + } + } + s.fixItemsCount() } // Intersect removes all the items missing in a from s. @@ -207,11 +242,10 @@ func (s *Set) Intersect(a *Set) { *s = Set{} return } - // Make shallow copy of `a`, since it can be modified below. + // Make shallow copy of `a`, since it can be modified by a.sort(). a = a.cloneShallow() a.sort() s.sort() - itemsCount := 0 i := 0 j := 0 for { @@ -233,12 +267,12 @@ func (s *Set) Intersect(a *Set) { break } if s.buckets[i].hi == a.buckets[j].hi { - itemsCount += s.buckets[i].intersect(&a.buckets[j]) + s.buckets[i].intersect(&a.buckets[j]) i++ j++ } } - s.itemsCount = itemsCount + s.fixItemsCount() } // Subtract removes from s all the shared items between s and a. @@ -307,11 +341,56 @@ func (b *bucket32) cloneShallow() *bucket32 { return &dst } -func (b *bucket32) intersect(a *bucket32) int { - a = a.cloneShallow() // clone a, since is is sorted below. +func (b *bucket32) getLen() int { + n := 0 + for i := range b.buckets { + n += b.buckets[i].getLen() + } + return n +} + +func (b *bucket32) union(a *bucket32, mayOwn bool) { + if !mayOwn { + a = a.cloneShallow() // clone a, since it is sorted below. + } + a.sort() + b.sort() + i := 0 + j := 0 + bb16his := b.b16his + for { + for i < len(bb16his) && j < len(a.b16his) && bb16his[i] < a.b16his[j] { + i++ + } + if i >= len(bb16his) { + for j < len(a.b16his) { + b16 := b.addBucket16() + a.buckets[j].copyTo(b16) + j++ + } + break + } + for j < len(a.b16his) && a.b16his[j] < bb16his[i] { + b16 := b.addBucket16() + a.buckets[j].copyTo(b16) + j++ + } + if j >= len(a.b16his) { + break + } + if bb16his[i] == a.b16his[j] { + b.buckets[i].union(&a.buckets[j]) + i++ + j++ + } + } + b.sort() +} + +func (b *bucket32) intersect(a *bucket32) { + a = a.cloneShallow() // clone a, since it is sorted below. a.sort() b.sort() - itemsCount := 0 i := 0 j := 0 for { @@ -333,12 +412,11 @@ func (b *bucket32) intersect(a *bucket32) int { break } if b.b16his[i] == a.b16his[j] { - itemsCount += b.buckets[i].intersect(&a.buckets[j]) + b.buckets[i].intersect(&a.buckets[j]) i++ j++ } } - return itemsCount } func (b *bucket32) forEach(f func(part []uint64) bool) bool { @@ -541,36 +619,65 @@ type bucket16 struct { smallPool [56]uint16 } -func (b *bucket16) intersect(a *bucket16) int { - itemsCount := 0 +func (b *bucket16) getLen() int { + if b.bits == nil { + return b.smallPoolLen + } + n := 0 + for _, x := range b.bits { + if x > 0 { + n += bits.OnesCount64(x) + } + } + return n +} + +func (b *bucket16) union(a *bucket16) { + if a.bits != nil && b.bits != nil { + // Fast path - use bitwise ops. + for i, ax := range a.bits { + bx := b.bits[i] + bx |= ax + b.bits[i] = bx + } + return + } + + // Slow path + xbuf := partBufPool.Get().(*[]uint64) + buf := *xbuf + buf = a.appendTo(buf[:0], 0, 0) + for _, x := range buf { + x16 := uint16(x) + b.add(x16) + } + *xbuf = buf + partBufPool.Put(xbuf) +} + +func (b *bucket16) intersect(a *bucket16) { if a.bits != nil && b.bits != nil { // Fast path - use bitwise ops for i, ax := range a.bits { bx := b.bits[i] bx &= ax - if bx > 0 { - itemsCount += bits.OnesCount64(bx) - } b.bits[i] = bx } - return itemsCount + return } // Slow path xbuf := partBufPool.Get().(*[]uint64) buf := *xbuf buf = b.appendTo(buf[:0], 0, 0) - itemsCount = len(buf) for _, x := range buf { x16 := uint16(x) if !a.has(x16) { b.del(x16) - itemsCount-- } } *xbuf = buf partBufPool.Put(xbuf) - return itemsCount } func (b *bucket16) sizeBytes() uint64 {