diff --git a/bitset.go b/bitset.go index e7842bf..60e5b1f 100644 --- a/bitset.go +++ b/bitset.go @@ -371,6 +371,25 @@ func (b *BitSet) Difference(compare *BitSet) (result *BitSet) { return } +// computes the cardinality of the differnce +func (b *BitSet) DifferenceCardinality(compare *BitSet) (uint) { + panicIfNull(b) + panicIfNull(compare) + l := int(compare.wordCount()) + if l > int(b.wordCount()) { + l = int(b.wordCount()) + } + cnt := uint64(0) + for i := 0; i < l ; i++ { + cnt += popcount_2(b.set[i] &^ compare.set[i]) + } + for i := l; i < len(b.set) ; i++ { + cnt += popcount_2(b.set[i]) + } + return uint(cnt) +} + + // Difference of base set and other set // This is the BitSet equivalent of &^ (and not) func (b *BitSet) InPlaceDifference(compare *BitSet) { @@ -411,6 +430,20 @@ func (b *BitSet) Intersection(compare *BitSet) (result *BitSet) { return } + +// Computes the cardinality of the union +func (b *BitSet) IntersectionCardinality(compare *BitSet) (uint) { + panicIfNull(b) + panicIfNull(compare) + b, compare = sortByLength(b, compare) + cnt := uint64(0) + for i, word := range b.set { + cnt += popcount_2(word & compare.set[i]) + } + return uint(cnt) +} + + // Intersection of base set and other set // This is the BitSet equivalent of & (and) func (b *BitSet) InPlaceIntersection(compare *BitSet) { @@ -447,6 +480,21 @@ func (b *BitSet) Union(compare *BitSet) (result *BitSet) { return } +func (b *BitSet) UnionCardinality(compare *BitSet) (uint) { + panicIfNull(b) + panicIfNull(compare) + b, compare = sortByLength(b, compare) + cnt := uint64(0) + for i, word := range b.set { + cnt += popcount_2(word | compare.set[i]) + } + for i := len(b.set); i < len(compare.set) ; i++ { + cnt += popcount_2(compare.set[i]) + } + + return uint(cnt) +} + // Union of base set and other set // This is the BitSet equivalent of | (or) @@ -484,6 +532,23 @@ func (b *BitSet) SymmetricDifference(compare *BitSet) (result *BitSet) { return } +// computes the cardinality of the symmetric difference +func (b *BitSet) SymmetricDifferenceCardinality(compare *BitSet) (uint) { + panicIfNull(b) + panicIfNull(compare) + b, compare = sortByLength(b, compare) + cnt := uint64(0) + for i, word := range b.set { + cnt += popcount_2(word ^ compare.set[i]) + } + for i := len(b.set); i < len(compare.set) ; i++ { + cnt += popcount_2(compare.set[i]) + } + + return uint(cnt) +} + + // SymmetricDifference of base set and other set // This is the BitSet equivalent of ^ (xor) func (b *BitSet) InPlaceSymmetricDifference(compare *BitSet) { diff --git a/bitset_test.go b/bitset_test.go index 68d7a64..eedc8bf 100644 --- a/bitset_test.go +++ b/bitset_test.go @@ -188,7 +188,7 @@ func TestCount(t *testing.T) { v := New(tot) checkLast := true for i := uint(0); i < tot; i++ { - sz := v.Count() + sz := uint(v.Count()) if sz != i { t.Errorf("Count reported as %d, but it should be %d", sz, i) checkLast = false @@ -197,7 +197,7 @@ func TestCount(t *testing.T) { v.Set(i) } if checkLast { - sz := v.Count() + sz := uint(v.Count()) if sz != tot { t.Errorf("After all bits set, size reported as %d, but it should be %d", sz, tot) } @@ -209,7 +209,7 @@ func TestCount2(t *testing.T) { tot := uint(64*4 + 11) // just some multi unit64 number v := New(tot) for i := uint(0); i < tot; i += 3 { - sz := v.Count() + sz := uint(v.Count()) if sz != i/3 { t.Errorf("Count reported as %d, but it should be %d", sz, i) break @@ -422,6 +422,13 @@ func TestUnion(t *testing.T) { for i := uint(100); i < 200; i++ { b.Set(i) } + if a.UnionCardinality(b) != 200 { + t.Errorf("Union should have 200 bits set, but had %d", a.UnionCardinality(b)) + } + if a.UnionCardinality(b) != b.UnionCardinality(a) { + t.Errorf("Union should be symmetric") + } + c := a.Union(b) d := b.Union(a) if c.Count() != 200 { @@ -468,6 +475,12 @@ func TestIntersection(t *testing.T) { for i := uint(100); i < 200; i++ { b.Set(i) } + if a.IntersectionCardinality(b) != 50 { + t.Errorf("Intersection should have 50 bits set, but had %d", a.IntersectionCardinality(b)) + } + if a.IntersectionCardinality(b) != b.IntersectionCardinality(a) { + t.Errorf("Intersection should be symmetric") + } c := a.Intersection(b) d := b.Intersection(a) if c.Count() != 50 { @@ -515,6 +528,13 @@ func TestDifference(t *testing.T) { for i := uint(100); i < 200; i++ { b.Set(i) } + if a.DifferenceCardinality(b) != 50 { + t.Errorf("a-b Difference should have 50 bits set, but had %d", a.DifferenceCardinality(b)) + } + if b.DifferenceCardinality(a) != 150 { + t.Errorf("b-a Difference should have 150 bits set, but had %d", b.DifferenceCardinality(a)) + } + c := a.Difference(b) d := b.Difference(a) if c.Count() != 50 { @@ -564,6 +584,13 @@ func TestSymmetricDifference(t *testing.T) { for i := uint(100); i < 200; i++ { b.Set(i) } + if a.SymmetricDifferenceCardinality(b) != 150 { + t.Errorf("a^b Difference should have 150 bits set, but had %d", a.SymmetricDifferenceCardinality(b)) + } + if b.SymmetricDifferenceCardinality(a) != 150 { + t.Errorf("b^a Difference should have 150 bits set, but had %d", b.SymmetricDifferenceCardinality(a)) + } + c := a.SymmetricDifference(b) d := b.SymmetricDifference(a) if c.Count() != 150 {