diff --git a/bitset.go b/bitset.go index d5f6b94..d4fb61b 100644 --- a/bitset.go +++ b/bitset.go @@ -42,24 +42,56 @@ */ package bitset +/* +#cgo CFLAGS: + +// function to compute the number of set bits in a long integer +#if defined(__GNUC__) +// when a GCC-like compiler is used, call the intrinsic +int pop(unsigned long x) { + return __builtin_popcountl(x); +} +#else +// otherwise use pure C +int pop(unsigned long v) { + v = v - ((v >> 1) & 0x5555555555555555); + v = (v & 0x3333333333333333) + + ((v >> 2) & 0x3333333333333333); + v = ((v + (v >> 4)) & 0x0F0F0F0F0F0F0F0F); + return (int)((v*(0x0101010101010101))>>56); +} +#endif + + +// computes the total number of set bits +unsigned int totalpop(void * v, int n) { + unsigned long * x = (unsigned long *) v; + unsigned int a = 0; + int k = 0; + for(; k < n ; ++k) a+= pop(x[k]); + return a; +} + + +*/ +import "C" + import ( "bytes" "encoding/base64" "encoding/binary" "encoding/json" - "fmt" "errors" + "fmt" + "unsafe" ) - - - +// we use the C code only if longs in C are 64-bit integers, otherwise fall back on pure Go +const useC = (unsafe.Sizeof(uint64(0)) == unsafe.Sizeof(C.ulong(0))) // Word size of a bit set const wordSize = uint(64) - - // for laster arith. 
const log2WordSize = uint(6) @@ -74,15 +106,15 @@ type BitSetError string // fixup b.set to be non-nil and return the field value func (b *BitSet) safeSet() []uint64 { if b.set == nil { - b.set = make([]uint64, wordsNeeded(0)) + b.set = make([]uint64, wordsNeeded(0)) } return b.set } func wordsNeeded(i uint) int { - if i > ((^uint(0)) - wordSize + 1 ) { + if i > ((^uint(0)) - wordSize + 1) { return int((^uint(0)) >> log2WordSize) - } + } return int((i + (wordSize - 1)) >> log2WordSize) } @@ -90,7 +122,7 @@ func New(length uint) *BitSet { return &BitSet{length, make([]uint64, wordsNeeded(length))} } -func Cap() uint { +func Cap() uint { return ^uint(0) } @@ -118,7 +150,7 @@ func (b *BitSet) Test(i uint) bool { if i >= b.length { return false } - return b.set[i>>log2WordSize] & (1<<(i&(wordSize-1))) != 0 + return b.set[i>>log2WordSize]&(1<<(i&(wordSize-1))) != 0 } // Set bit i to 1 @@ -157,20 +189,20 @@ func (b *BitSet) Flip(i uint) *BitSet { // return the next bit set from the specified index, including possibly the current index // along with an error code (true = valid, false = no set bit found) // for i,e := v.NextSet(0); e; i,e = v.NextSet(i + 1) {...} -func (b *BitSet) NextSet(i uint) (uint,bool) { - x := i >> log2WordSize - if x >= b.length { +func (b *BitSet) NextSet(i uint) (uint, bool) { + x := int(i >> log2WordSize) + if x >= len(b.set) { return 0, false } w := b.set[x] w = w >> (i & (wordSize - 1)) if w != 0 { - return i + trailingZeroes64(w),true + return i + trailingZeroes64(w), true } x = x + 1 - for x < uint(len(b.set)) { + for x < len(b.set) { if b.set[x] != 0 { - return x * wordSize + trailingZeroes64(b.set[x]),true + return uint(x)*wordSize + trailingZeroes64(b.set[x]), true } x = x + 1 @@ -196,8 +228,8 @@ func (b *BitSet) wordCount() int { // Clone this BitSet func (b *BitSet) Clone() *BitSet { c := New(b.length) - if b.set != nil {// Clone should not modify current object - copy(c.set, b.set) + if b.set != nil { // Clone should not modify current 
object + copy(c.set, b.set) } return c } @@ -209,8 +241,8 @@ func (b *BitSet) Copy(c *BitSet) (count uint) { if c == nil { return } - if b.set != nil {// Copy should not modify current object - copy(c.set, b.set) + if b.set != nil { // Copy should not modify current object + copy(c.set, b.set) } count = c.length if b.length < c.length { @@ -245,6 +277,9 @@ func popcount_2(x uint64) uint64 { // Count (number of set bits) func (b *BitSet) Count() uint { if b != nil && b.set != nil { + if useC { + return uint(C.totalpop(unsafe.Pointer(&b.set[0]), C.int(len(b.set)))) + } cnt := uint64(0) for _, word := range b.set { cnt += popcount_2(word) @@ -257,9 +292,6 @@ func (b *BitSet) Count() uint { // computes the number of trailing zeroes on the assumption that v is non-zero func trailingZeroes64(v uint64) uint { // NOTE: if 0 == v, then c = 63. - if v&0x1 != 0 { - return 0 - } c := uint(1) if (v & 0xffffffff) == 0 { v >>= 32 @@ -300,7 +332,7 @@ func (b *BitSet) Equal(c *BitSet) bool { } // testing for equality shoud not transform the bitset (no call to safeSet) - for p, v := range b.set { + for p, v := range b.set { if c.set[p] != v { return false } @@ -320,32 +352,48 @@ func (b *BitSet) Difference(compare *BitSet) (result *BitSet) { panicIfNull(b) panicIfNull(compare) result = b.Clone() // clone b (in case b is bigger than compare) - l := int(compare.wordCount()) + l := int(compare.wordCount()) if l > int(b.wordCount()) { - l = int(b.wordCount()) + l = int(b.wordCount()) } - for i := 0; i < l ; i++ { - result.set[i] = b.set[i] &^ compare.set[i] + for i := 0; i < l; i++ { + result.set[i] = b.set[i] &^ compare.set[i] } return } +// computes the cardinality of the difference +func (b *BitSet) DifferenceCardinality(compare *BitSet) uint { + panicIfNull(b) + panicIfNull(compare) + l := int(compare.wordCount()) + if l > int(b.wordCount()) { + l = int(b.wordCount()) + } + cnt := uint64(0) + for i := 0; i < l; i++ { + cnt += popcount_2(b.set[i] &^ compare.set[i]) + } + for i := l; i 
< len(b.set); i++ { + cnt += popcount_2(b.set[i]) + } + return uint(cnt) +} + // Difference of base set and other set // This is the BitSet equivalent of &^ (and not) -func (b *BitSet) InPlaceDifference(compare *BitSet) { +func (b *BitSet) InPlaceDifference(compare *BitSet) { panicIfNull(b) panicIfNull(compare) - l := int(compare.wordCount()) + l := int(compare.wordCount()) if l > int(b.wordCount()) { - l = int(b.wordCount()) + l = int(b.wordCount()) } - for i := 0; i < l ; i++ { - b.set[i] &^= compare.set[i] + for i := 0; i < l; i++ { + b.set[i] &^= compare.set[i] } } - - // Convenience function: return two bitsets ordered by // increasing length. Note: neither can be nil func sortByLength(a *BitSet, b *BitSet) (ap *BitSet, bp *BitSet) { @@ -370,29 +418,39 @@ func (b *BitSet) Intersection(compare *BitSet) (result *BitSet) { return } +// Computes the cardinality of the intersection +func (b *BitSet) IntersectionCardinality(compare *BitSet) uint { + panicIfNull(b) + panicIfNull(compare) + b, compare = sortByLength(b, compare) + cnt := uint64(0) + for i, word := range b.set { + cnt += popcount_2(word & compare.set[i]) + } + return uint(cnt) +} + // Intersection of base set and other set // This is the BitSet equivalent of & (and) -func (b *BitSet) InPlaceIntersection(compare *BitSet) { +func (b *BitSet) InPlaceIntersection(compare *BitSet) { panicIfNull(b) panicIfNull(compare) - l := int(compare.wordCount()) + l := int(compare.wordCount()) if l > int(b.wordCount()) { - l = int(b.wordCount()) + l = int(b.wordCount()) } - for i := 0; i < l ; i++ { + for i := 0; i < l; i++ { b.set[i] &= compare.set[i] } - for i := l; i < len(b.set) ; i++ { + for i := l; i < len(b.set); i++ { b.set[i] = 0 } - if compare.length > 0 { - b.extendSetMaybe(compare.length - 1) + if compare.length > 0 { + b.extendSetMaybe(compare.length - 1) } return } - - // Union of base set and other set // This is the BitSet equivalent of | (or) func (b *BitSet) Union(compare *BitSet) (result *BitSet) { @@ -406,26 
+464,40 @@ func (b *BitSet) Union(compare *BitSet) (result *BitSet) { return } +func (b *BitSet) UnionCardinality(compare *BitSet) uint { + panicIfNull(b) + panicIfNull(compare) + b, compare = sortByLength(b, compare) + cnt := uint64(0) + for i, word := range b.set { + cnt += popcount_2(word | compare.set[i]) + } + for i := len(b.set); i < len(compare.set); i++ { + cnt += popcount_2(compare.set[i]) + } + + return uint(cnt) +} // Union of base set and other set // This is the BitSet equivalent of | (or) -func (b *BitSet) InPlaceUnion(compare *BitSet) { +func (b *BitSet) InPlaceUnion(compare *BitSet) { panicIfNull(b) panicIfNull(compare) - l := int(compare.wordCount()) + l := int(compare.wordCount()) if l > int(b.wordCount()) { - l = int(b.wordCount()) + l = int(b.wordCount()) } - if compare.length > 0 { - b.extendSetMaybe(compare.length - 1) + if compare.length > 0 { + b.extendSetMaybe(compare.length - 1) } - for i := 0; i < l ; i++ { + for i := 0; i < l; i++ { b.set[i] |= compare.set[i] } - if len(compare.set) > l { - for i := l; i < len(compare.set) ; i++ { - b.set[i] = compare.set[i] - } + if len(compare.set) > l { + for i := l; i < len(compare.set); i++ { + b.set[i] = compare.set[i] + } } } @@ -443,25 +515,41 @@ func (b *BitSet) SymmetricDifference(compare *BitSet) (result *BitSet) { return } +// computes the cardinality of the symmetric difference +func (b *BitSet) SymmetricDifferenceCardinality(compare *BitSet) uint { + panicIfNull(b) + panicIfNull(compare) + b, compare = sortByLength(b, compare) + cnt := uint64(0) + for i, word := range b.set { + cnt += popcount_2(word ^ compare.set[i]) + } + for i := len(b.set); i < len(compare.set); i++ { + cnt += popcount_2(compare.set[i]) + } + + return uint(cnt) +} + // SymmetricDifference of base set and other set // This is the BitSet equivalent of ^ (xor) -func (b *BitSet) InPlaceSymmetricDifference(compare *BitSet) { +func (b *BitSet) InPlaceSymmetricDifference(compare *BitSet) { panicIfNull(b) panicIfNull(compare) - 
l := int(compare.wordCount()) + l := int(compare.wordCount()) if l > int(b.wordCount()) { - l = int(b.wordCount()) + l = int(b.wordCount()) } - if compare.length > 0 { - b.extendSetMaybe(compare.length - 1) + if compare.length > 0 { + b.extendSetMaybe(compare.length - 1) } - for i := 0; i < l ; i++ { + for i := 0; i < l; i++ { b.set[i] ^= compare.set[i] } - if len(compare.set) > l { - for i := l; i < len(compare.set) ; i++ { - b.set[i] = compare.set[i] - } + if len(compare.set) > l { + for i := l; i < len(compare.set); i++ { + b.set[i] = compare.set[i] + } } } @@ -574,9 +662,9 @@ func (b *BitSet) UnmarshalJSON(data []byte) error { return err } newset := New(uint(length)) - - if uint64(newset.length) != length { - return errors.New("Unmarshalling error: type mismatch") + + if uint64(newset.length) != length { + return errors.New("Unmarshalling error: type mismatch") } // Read remaining bytes as set diff --git a/bitset_test.go b/bitset_test.go index a1fda65..218a19f 100644 --- a/bitset_test.go +++ b/bitset_test.go @@ -51,7 +51,6 @@ func TestBitSetHuge(t *testing.T) { } } - func TestLen(t *testing.T) { v := New(1000) if v.Len() != 1000 { @@ -105,7 +104,7 @@ func TestIterate(t *testing.T) { v.Set(2) data := make([]uint, 3) c := 0 - for i,e := v.NextSet(0); e; i,e = v.NextSet(i + 1) { + for i, e := v.NextSet(0); e; i, e = v.NextSet(i + 1) { data[c] = i c++ } @@ -122,7 +121,7 @@ func TestIterate(t *testing.T) { v.Set(2000) data = make([]uint, 5) c = 0 - for i,e := v.NextSet(0); e; i,e = v.NextSet(i + 1) { + for i, e := v.NextSet(0); e; i, e = v.NextSet(i + 1) { data[c] = i c++ } @@ -187,7 +186,7 @@ func TestCount(t *testing.T) { v := New(tot) checkLast := true for i := uint(0); i < tot; i++ { - sz := v.Count() + sz := uint(v.Count()) if sz != i { t.Errorf("Count reported as %d, but it should be %d", sz, i) checkLast = false @@ -196,7 +195,7 @@ func TestCount(t *testing.T) { v.Set(i) } if checkLast { - sz := v.Count() + sz := uint(v.Count()) if sz != tot { t.Errorf("After 
all bits set, size reported as %d, but it should be %d", sz, tot) } @@ -208,7 +207,7 @@ func TestCount2(t *testing.T) { tot := uint(64*4 + 11) // just some multi unit64 number v := New(tot) for i := uint(0); i < tot; i += 3 { - sz := v.Count() + sz := uint(v.Count()) if sz != i/3 { t.Errorf("Count reported as %d, but it should be %d", sz, i) break @@ -421,6 +420,13 @@ func TestUnion(t *testing.T) { for i := uint(100); i < 200; i++ { b.Set(i) } + if a.UnionCardinality(b) != 200 { + t.Errorf("Union should have 200 bits set, but had %d", a.UnionCardinality(b)) + } + if a.UnionCardinality(b) != b.UnionCardinality(a) { + t.Errorf("Union should be symmetric") + } + c := a.Union(b) d := b.Union(a) if c.Count() != 200 { @@ -456,7 +462,6 @@ func TestInPlaceUnion(t *testing.T) { } } - func TestIntersection(t *testing.T) { a := New(100) b := New(200) @@ -467,6 +472,12 @@ func TestIntersection(t *testing.T) { for i := uint(100); i < 200; i++ { b.Set(i) } + if a.IntersectionCardinality(b) != 50 { + t.Errorf("Intersection should have 50 bits set, but had %d", a.IntersectionCardinality(b)) + } + if a.IntersectionCardinality(b) != b.IntersectionCardinality(a) { + t.Errorf("Intersection should be symmetric") + } c := a.Intersection(b) d := b.Intersection(a) if c.Count() != 50 { @@ -477,7 +488,6 @@ func TestIntersection(t *testing.T) { } } - func TestInplaceIntersection(t *testing.T) { a := New(100) b := New(200) @@ -503,7 +513,6 @@ func TestInplaceIntersection(t *testing.T) { } } - func TestDifference(t *testing.T) { a := New(100) b := New(200) @@ -514,6 +523,13 @@ func TestDifference(t *testing.T) { for i := uint(100); i < 200; i++ { b.Set(i) } + if a.DifferenceCardinality(b) != 50 { + t.Errorf("a-b Difference should have 50 bits set, but had %d", a.DifferenceCardinality(b)) + } + if b.DifferenceCardinality(a) != 150 { + t.Errorf("b-a Difference should have 150 bits set, but had %d", b.DifferenceCardinality(a)) + } + c := a.Difference(b) d := b.Difference(a) if c.Count() != 50 { 
@@ -527,7 +543,6 @@ func TestDifference(t *testing.T) { } } - func TestInPlaceDifference(t *testing.T) { a := New(100) b := New(200) @@ -563,6 +578,13 @@ func TestSymmetricDifference(t *testing.T) { for i := uint(100); i < 200; i++ { b.Set(i) } + if a.SymmetricDifferenceCardinality(b) != 150 { + t.Errorf("a^b Difference should have 150 bits set, but had %d", a.SymmetricDifferenceCardinality(b)) + } + if b.SymmetricDifferenceCardinality(a) != 150 { + t.Errorf("b^a Difference should have 150 bits set, but had %d", b.SymmetricDifferenceCardinality(a)) + } + c := a.SymmetricDifference(b) d := b.SymmetricDifference(a) if c.Count() != 150 { @@ -684,6 +706,7 @@ func BenchmarkSetExpand(b *testing.B) { } } +// go test -bench=Count func BenchmarkCount(b *testing.B) { b.StopTimer() s := New(100000) @@ -695,3 +718,35 @@ func BenchmarkCount(b *testing.B) { s.Count() } } + +// go test -bench=Iterate +func BenchmarkIterate(b *testing.B) { + b.StopTimer() + s := New(10000) + for i := 0; i < 10000; i += 3 { + s.Set(uint(i)) + } + b.StartTimer() + for j := 0; j < b.N; j++ { + c := uint(0) + for i, e := s.NextSet(0); e; i, e = s.NextSet(i + 1) { + c++ + } + } +} + +// go test -bench=SparseIterate +func BenchmarkSparseIterate(b *testing.B) { + b.StopTimer() + s := New(100000) + for i := 0; i < 100000; i += 30 { + s.Set(uint(i)) + } + b.StartTimer() + for j := 0; j < b.N; j++ { + c := uint(0) + for i, e := s.NextSet(0); e; i, e = s.NextSet(i + 1) { + c++ + } + } +}