diff --git a/lib/uint64set/uint64set_test.go b/lib/uint64set/uint64set_test.go index bbb9675bc..d569db856 100644 --- a/lib/uint64set/uint64set_test.go +++ b/lib/uint64set/uint64set_test.go @@ -9,6 +9,198 @@ import ( "time" ) +func TestSetOps(t *testing.T) { + f := func(a, b []uint64) { + t.Helper() + mUnion := make(map[uint64]bool) + mIntersect := make(map[uint64]bool) + ma := make(map[uint64]bool) + sa := &Set{} + sb := &Set{} + for _, v := range a { + sa.Add(v) + ma[v] = true + mUnion[v] = true + } + for _, v := range b { + sb.Add(v) + mUnion[v] = true + if ma[v] { + mIntersect[v] = true + } + } + saOrig := sa.Clone() + if !saOrig.Equal(sa) { + t.Fatalf("saOrig must be equal to sa; got\n%v\nvs\n%v", saOrig, sa) + } + sbOrig := sb.Clone() + if !sbOrig.Equal(sb) { + t.Fatalf("sbOrig must be equal to sb; got\n%v\nvs\n%v", sbOrig, sb) + } + + // Verify sa.Union(sb) + sa.Union(sb) + if err := expectEqual(sa, mUnion); err != nil { + t.Fatalf("ivalid sa.Union(sb): %s", err) + } + if !sbOrig.Equal(sb) { + t.Fatalf("sbOrig must be equal to sb after sa.Union(sb); got\n%v\nvs\n%v", sbOrig, sb) + } + + // Verify sb.Union(sa) + sa = saOrig.Clone() + sb.Union(sa) + if err := expectEqual(sb, mUnion); err != nil { + t.Fatalf("invalid sb.Union(sa): %s", err) + } + if !saOrig.Equal(sa) { + t.Fatalf("saOrig must be equal to sa after sb.Union(sa); got\n%v\nvs\n%v", saOrig, sa) + } + + // Verify sa.UnionMayOwn(sb) + sa = saOrig.Clone() + sb = sbOrig.Clone() + sa.UnionMayOwn(sb) + if err := expectEqual(sa, mUnion); err != nil { + t.Fatalf("invalid sa.UnionMayOwn(sb): %s", err) + } + if !sbOrig.Equal(sb) { + t.Fatalf("sbOrig must be equal to sb after sa.UnionMayOwn(sb); got\n%v\nvs\n%v", sbOrig, sb) + } + + // Verify sb.UnionMayOwn(sa) + sa = saOrig.Clone() + sb.UnionMayOwn(sa) + if err := expectEqual(sb, mUnion); err != nil { + t.Fatalf("invalid sb.UnionMayOwn(sa): %s", err) + } + if !saOrig.Equal(sa) { + t.Fatalf("saOrig must be equal to sa after sb.UnionMayOwn(sa); got\n%v\nvs\n%v", saOrig, sa) + } + + // Verify sa.Intersect(sb) + sa = saOrig.Clone() + sb = sbOrig.Clone() + sa.Intersect(sb) + if err := expectEqual(sa, mIntersect); err != nil { + t.Fatalf("invalid sa.Intersect(sb): %s", err) + } + if !sbOrig.Equal(sb) { + t.Fatalf("sbOrig must be equal to sb after sa.Intersect(sb); got\n%v\nvs\n%v", sbOrig, sb) + } + + // Verify sb.Intersect(sa) + sa = saOrig.Clone() + sb.Intersect(sa) + if err := expectEqual(sb, mIntersect); err != nil { + t.Fatalf("invalid sb.Intersect(sa): %s", err) + } + if !saOrig.Equal(sa) { + t.Fatalf("saOrig must be equal to sa after sb.Intersect(sa); got\n%v\nvs\n%v", saOrig, sa) + } + + // Verify sa.Subtract(sb) + mSubtractAB := make(map[uint64]bool) + for _, v := range a { + mSubtractAB[v] = true + } + for _, v := range b { + delete(mSubtractAB, v) + } + sa = saOrig.Clone() + sb = sbOrig.Clone() + sa.Subtract(sb) + if err := expectEqual(sa, mSubtractAB); err != nil { + t.Fatalf("invalid sa.Subtract(sb): %s", err) + } + if !sbOrig.Equal(sb) { + t.Fatalf("sbOrig must be equal to sb after sa.Subtract(sb); got\n%v\nvs\n%v", sbOrig, sb) + } + + // Verify sb.Subtract(sa) + mSubtractBA := make(map[uint64]bool) + for _, v := range b { + mSubtractBA[v] = true + } + for _, v := range a { + delete(mSubtractBA, v) + } + sa = saOrig.Clone() + sb.Subtract(sa) + if err := expectEqual(sb, mSubtractBA); err != nil { + t.Fatalf("invalid sb.Subtract(sa): %s", err) + } + if !saOrig.Equal(sa) { + t.Fatalf("saOrig must be equal to sa after sb.Subtract(sa); got\n%v\nvs\n%v", saOrig, sa) + } + } + + f(nil, nil) + f([]uint64{1}, nil) + f([]uint64{1, 2, 3}, nil) + f([]uint64{1, 2, 3, 1 << 16, 1 << 32, 2 << 32}, nil) + f([]uint64{1}, []uint64{1}) + f([]uint64{0}, []uint64{1 << 16}) + f([]uint64{1}, []uint64{1 << 16}) + f([]uint64{1}, []uint64{4 << 16}) + f([]uint64{1}, []uint64{1 << 32}) + f([]uint64{1}, []uint64{1 << 32, 2 << 32}) + f([]uint64{1}, []uint64{2 << 32}) + f([]uint64{1, 1<<16 - 1}, []uint64{1 << 16}) + f([]uint64{0, 1<<16 - 1}, []uint64{1 << 16, 1<<16 - 1}) + f([]uint64{0, 1<<16 - 1}, []uint64{1 << 16, 1<<16 - 1, 2 << 16, 8 << 16}) + f([]uint64{0}, []uint64{1 << 16, 1<<16 - 1, 2 << 16, 8 << 16}) + f([]uint64{0,2<<16}, []uint64{1<<16}) + f([]uint64{0,2<<16}, []uint64{1<<16,3<<16}) + f([]uint64{0,2<<16}, []uint64{1<<16,2<<16}) + f([]uint64{0,2<<16}, []uint64{1<<16,2<<16,3<<16}) + f([]uint64{0,2<<32}, []uint64{1<<32}) + f([]uint64{0,2<<32}, []uint64{1<<32,3<<32}) + f([]uint64{0,2<<32}, []uint64{1<<32,2<<32}) + f([]uint64{0,2<<32}, []uint64{1<<32,2<<32,3<<32}) + + var a []uint64 + for i := 0; i < 100; i++ { + a = append(a, uint64(i)) + } + var b []uint64 + for i := 0; i < 100; i++ { + b = append(b, uint64(i+1<<16)) + } + f(a, b) +} + +func expectEqual(s *Set, m map[uint64]bool) error { + if s.Len() != len(m) { + return fmt.Errorf("unexpected s.Len(); got %d; want %d\ns=%v\nm=%v", s.Len(), len(m), s.AppendTo(nil), m) + } + for _, v := range s.AppendTo(nil) { + if !m[v] { + return fmt.Errorf("missing value %d in m; s=%v\nm=%v", v, s.AppendTo(nil), m) + } + } + + // Additional check via s.Has() + for v := range m { + if !s.Has(v) { + return fmt.Errorf("missing value %d in s; s=%v\nm=%v", v, s.AppendTo(nil), m) + } + } + + // Extra check via s.ForEach() + var err error + s.ForEach(func(part []uint64) bool { + for _, v := range part { + if !m[v] { + err = fmt.Errorf("miising value %d in m inside s.ForEach; s=%v\nm=%v", v, s.AppendTo(nil), m) + return false + } + } + return true + }) + return err +} + func TestSetBasicOps(t *testing.T) { for _, itemsCount := range []int{1, 2, 3, 4, 5, 6, 1e2, 1e3, 1e4, 1e5, 1e6, maxUnsortedBuckets * bitsPerBucket * 2} { t.Run(fmt.Sprintf("items_%d", itemsCount), func(t *testing.T) {