diff --git a/runtime/Go/antlr/utils.go b/runtime/Go/antlr/utils.go index 9c7d0a6cda..8fc8763345 100644 --- a/runtime/Go/antlr/utils.go +++ b/runtime/Go/antlr/utils.go @@ -8,7 +8,7 @@ import ( "bytes" "errors" "fmt" - "sort" + "math/bits" "strconv" "strings" ) @@ -71,59 +71,87 @@ type hasher interface { hash() int } +const bitsPerWord = 64 + +func indexForBit(bit int) int { + return bit / bitsPerWord +} + +func wordForBit(data []uint64, bit int) uint64 { + idx := indexForBit(bit) + if idx >= len(data) { + return 0 + } + return data[idx] +} + +func maskForBit(bit int) uint64 { + return uint64(1) << (bit % bitsPerWord) +} + +func wordsNeeded(bit int) int { + return indexForBit(bit) + 1 +} + type BitSet struct { - data map[int]bool + data []uint64 } func NewBitSet() *BitSet { - b := new(BitSet) - b.data = make(map[int]bool) - return b + return &BitSet{} } func (b *BitSet) add(value int) { - b.data[value] = true + idx := indexForBit(value) + if idx >= len(b.data) { + size := wordsNeeded(value) + data := make([]uint64, size) + copy(data, b.data) + b.data = data + } + b.data[idx] |= maskForBit(value) } func (b *BitSet) clear(index int) { - delete(b.data, index) + idx := indexForBit(index) + if idx >= len(b.data) { + return + } + b.data[idx] &= ^maskForBit(index) } func (b *BitSet) or(set *BitSet) { - for k := range set.data { - b.add(k) + size := intMax(b.minLen(), set.minLen()) + if size > len(b.data) { + data := make([]uint64, size) + copy(data, b.data) + b.data = data + } + for i := 0; i < size; i++ { + b.data[i] |= set.data[i] } } func (b *BitSet) remove(value int) { - delete(b.data, value) + b.clear(value) } func (b *BitSet) contains(value int) bool { - return b.data[value] -} - -func (b *BitSet) values() []int { - ks := make([]int, len(b.data)) - i := 0 - for k := range b.data { - ks[i] = k - i++ + idx := indexForBit(value) + if idx >= len(b.data) { + return false } - sort.Ints(ks) - return ks + return (b.data[idx] & maskForBit(value)) != 0 } func (b *BitSet) minValue() int { - min := 2147483647 - - for k := range b.data { - if k < min { - min = k + for i, v := range b.data { + if v == 0 { + continue } + return i*bitsPerWord + bits.TrailingZeros64(v) } - - return min + return 2147483647 } func (b *BitSet) equals(other interface{}) bool { @@ -132,12 +160,16 @@ func (b *BitSet) equals(other interface{}) bool { return false } + if b == otherBitSet { + return true + } + if len(b.data) != len(otherBitSet.data) { return false } - for k, v := range b.data { - if otherBitSet.data[k] != v { + for k := range b.data { + if b.data[k] != otherBitSet.data[k] { return false } } @@ -145,18 +177,35 @@ func (b *BitSet) equals(other interface{}) bool { return true } +func (b *BitSet) minLen() int { + for i := len(b.data); i > 0; i-- { + if b.data[i-1] != 0 { + return i + } + } + return 0 +} + func (b *BitSet) length() int { - return len(b.data) + cnt := 0 + for _, val := range b.data { + cnt += bits.OnesCount64(val) + } + return cnt } func (b *BitSet) String() string { - vals := b.values() - valsS := make([]string, len(vals)) + vals := make([]string, 0, b.length()) - for i, val := range vals { - valsS[i] = strconv.Itoa(val) + for i, v := range b.data { + for v != 0 { + n := bits.TrailingZeros64(v) + vals = append(vals, strconv.Itoa(i*bitsPerWord+n)) + v &= ^(uint64(1) << n) + } } - return "{" + strings.Join(valsS, ", ") + "}" + + return "{" + strings.Join(vals, ", ") + "}" } type AltDict struct { diff --git a/runtime/Go/antlr/utils_test.go b/runtime/Go/antlr/utils_test.go new file mode 100644 index 0000000000..88e7aa9fe8 --- /dev/null +++ b/runtime/Go/antlr/utils_test.go @@ -0,0 +1,151 @@ +package antlr + +import "testing" + +func TestBitSet(t *testing.T) { + bs1 := NewBitSet() + if got, want := bs1.String(), "{}"; got != want { + t.Errorf("String() = %q, want %q", got, want) + } + if got, want := bs1.length(), 0; got != want { + t.Errorf("length() = %q, want %q", got, want) + } + if got, want := bs1.contains(1), false; got != want { + t.Errorf("contains(%v) = %v, want %v", 1, got, want) + } + if got, want := bs1.minValue(), 2147483647; got != want { + t.Errorf("minValue() = %v, want %v", got, want) + } + bs1.add(0) + if got, want := bs1.String(), "{0}"; got != want { + t.Errorf("String() = %q, want %q", got, want) + } + if got, want := bs1.length(), 1; got != want { + t.Errorf("length() = %q, want %q", got, want) + } + if got, want := bs1.contains(0), true; got != want { + t.Errorf("contains(%v) = %v, want %v", 1, got, want) + } + if got, want := bs1.minValue(), 0; got != want { + t.Errorf("minValue() = %v, want %v", got, want) + } + bs1.add(63) + if got, want := bs1.String(), "{0, 63}"; got != want { + t.Errorf("String() = %q, want %q", got, want) + } + if got, want := bs1.length(), 2; got != want { + t.Errorf("length() = %q, want %q", got, want) + } + if got, want := bs1.contains(1), false; got != want { + t.Errorf("contains(%v) = %v, want %v", 1, got, want) + } + if got, want := bs1.contains(0), true; got != want { + t.Errorf("contains(%v) = %v, want %v", 1, got, want) + } + if got, want := bs1.contains(63), true; got != want { + t.Errorf("contains(%v) = %v, want %v", 1, got, want) + } + if got, want := bs1.minValue(), 0; got != want { + t.Errorf("minValue() = %v, want %v", got, want) + } + bs1.remove(0) + if got, want := bs1.String(), "{63}"; got != want { + t.Errorf("String() = %q, want %q", got, want) + } + if got, want := bs1.length(), 1; got != want { + t.Errorf("length() = %q, want %q", got, want) + } + if got, want := bs1.contains(0), false; got != want { + t.Errorf("contains(%v) = %v, want %v", 1, got, want) + } + if got, want := bs1.contains(63), true; got != want { + t.Errorf("contains(%v) = %v, want %v", 1, got, want) + } + if got, want := bs1.minValue(), 63; got != want { + t.Errorf("minValue() = %v, want %v", got, want) + } + bs1.add(20) + if got, want := bs1.String(), "{20, 63}"; got != want { + t.Errorf("String() = %q, want %q", got, want) + } + if got, want := bs1.length(), 2; got != want { + t.Errorf("length() = %q, want %q", got, want) + } + if got, want := bs1.contains(0), false; got != want { + t.Errorf("contains(%v) = %v, want %v", 1, got, want) + } + if got, want := bs1.contains(20), true; got != want { + t.Errorf("contains(%v) = %v, want %v", 1, got, want) + } + if got, want := bs1.contains(63), true; got != want { + t.Errorf("contains(%v) = %v, want %v", 1, got, want) + } + if got, want := bs1.minValue(), 20; got != want { + t.Errorf("minValue() = %v, want %v", got, want) + } + bs1.clear(63) + if got, want := bs1.String(), "{20}"; got != want { + t.Errorf("String() = %q, want %q", got, want) + } + if got, want := bs1.length(), 1; got != want { + t.Errorf("length() = %q, want %q", got, want) + } + if got, want := bs1.contains(0), false; got != want { + t.Errorf("contains(%v) = %v, want %v", 1, got, want) + } + if got, want := bs1.contains(20), true; got != want { + t.Errorf("contains(%v) = %v, want %v", 1, got, want) + } + if got, want := bs1.contains(63), false; got != want { + t.Errorf("contains(%v) = %v, want %v", 1, got, want) + } + if got, want := bs1.minValue(), 20; got != want { + t.Errorf("minValue() = %v, want %v", got, want) + } + bs2 := NewBitSet() + bs2.add(64) + bs1.or(bs2) + if got, want := bs1.String(), "{20, 64}"; got != want { + t.Errorf("String() = %q, want %q", got, want) + } + if got, want := bs1.length(), 2; got != want { + t.Errorf("length() = %q, want %q", got, want) + } + if got, want := bs1.contains(0), false; got != want { + t.Errorf("contains(%v) = %v, want %v", 1, got, want) + } + if got, want := bs1.contains(20), true; got != want { + t.Errorf("contains(%v) = %v, want %v", 1, got, want) + } + if got, want := bs1.contains(63), false; got != want { + t.Errorf("contains(%v) = %v, want %v", 1, got, want) + } + if got, want := bs1.contains(64), true; got != want { + t.Errorf("contains(%v) = %v, want %v", 1, got, want) + } + if got, want := bs1.minValue(), 20; got != want { + t.Errorf("minValue() = %v, want %v", got, want) + } + bs1.remove(20) + if got, want := bs1.String(), "{64}"; got != want { + t.Errorf("String() = %q, want %q", got, want) + } + if got, want := bs1.length(), 1; got != want { + t.Errorf("length() = %q, want %q", got, want) + } + if got, want := bs1.contains(0), false; got != want { + t.Errorf("contains(%v) = %v, want %v", 1, got, want) + } + if got, want := bs1.contains(20), false; got != want { + t.Errorf("contains(%v) = %v, want %v", 1, got, want) + } + if got, want := bs1.contains(63), false; got != want { + t.Errorf("contains(%v) = %v, want %v", 1, got, want) + } + if got, want := bs1.contains(64), true; got != want { + t.Errorf("contains(%v) = %v, want %v", 1, got, want) + } + if got, want := bs1.minValue(), 64; got != want { + t.Errorf("minValue() = %v, want %v", got, want) + } +}