diff --git a/runtime/Go/antlr/atn_config.go b/runtime/Go/antlr/atn_config.go index 604c968caa..13f6109f65 100644 --- a/runtime/Go/antlr/atn_config.go +++ b/runtime/Go/antlr/atn_config.go @@ -8,21 +8,14 @@ import ( "fmt" ) -type comparable interface { - equals(other interface{}) bool -} - // ATNConfig is a tuple: (ATN state, predicted alt, syntactic, semantic // context). The syntactic context is a graph-structured stack node whose // path(s) to the root is the rule invocation(s) chain used to arrive at the // state. The semantic context is the tree of semantic predicates encountered // before reaching an ATN state. type ATNConfig interface { - comparable - - gequals(other Collectable[ATNConfig]) bool - - hash() int + Equals(o Collectable[ATNConfig]) bool + Hash() int GetState() ATNState GetAlt() int @@ -136,15 +129,17 @@ func (b *BaseATNConfig) GetReachesIntoOuterContext() int { func (b *BaseATNConfig) SetReachesIntoOuterContext(v int) { b.reachesIntoOuterContext = v } -func (b *BaseATNConfig) equals(o interface{}) bool { - return b.gequals(o.(Collectable[ATNConfig])) -} +// Equals is the default comparison function for an ATNConfig when no specialist implementation is required +// for a collection. +// // An ATN configuration is equal to another if both have the same state, they // predict the same alternative, and syntactic/semantic contexts are the same. -func (b *BaseATNConfig) gequals(o Collectable[ATNConfig]) bool { +func (b *BaseATNConfig) Equals(o Collectable[ATNConfig]) bool { if b == o { return true + } else if o == nil { + return false } var other, ok = o.(*BaseATNConfig) @@ -158,30 +153,32 @@ func (b *BaseATNConfig) gequals(o Collectable[ATNConfig]) bool { if b.context == nil { equal = other.context == nil } else { - equal = b.context.gequals(other.context) + equal = b.context.Equals(other.context) } var ( nums = b.state.GetStateNumber() == other.state.GetStateNumber() alts = b.alt == other.alt - cons = b.semanticContext.equals(other.semanticContext) + cons = b.semanticContext.Equals(other.semanticContext) sups = b.precedenceFilterSuppressed == other.precedenceFilterSuppressed ) return nums && alts && cons && sups && equal } -func (b *BaseATNConfig) hash() int { +// Hash is the default hash function for BaseATNConfig, when no specialist hash function +// is required for a collection +func (b *BaseATNConfig) Hash() int { var c int if b.context != nil { - c = b.context.hash() + c = b.context.Hash() } h := murmurInit(7) h = murmurUpdate(h, b.state.GetStateNumber()) h = murmurUpdate(h, b.alt) h = murmurUpdate(h, c) - h = murmurUpdate(h, b.semanticContext.hash()) + h = murmurUpdate(h, b.semanticContext.Hash()) return murmurFinish(h, 4) } @@ -248,7 +245,9 @@ func NewLexerATNConfig1(state ATNState, alt int, context PredictionContext) *Lex return &LexerATNConfig{BaseATNConfig: NewBaseATNConfig5(state, alt, context, SemanticContextNone)} } -func (l *LexerATNConfig) hash() int { +// Hash is the default hash function for LexerATNConfig objects, it can be used directly or via +// the default comparator [ObjEqComparator]. +func (l *LexerATNConfig) Hash() int { var f int if l.passedThroughNonGreedyDecision { f = 1 @@ -258,19 +257,20 @@ func (l *LexerATNConfig) hash() int { h := murmurInit(7) h = murmurUpdate(h, l.state.GetStateNumber()) h = murmurUpdate(h, l.alt) - h = murmurUpdate(h, l.context.hash()) - h = murmurUpdate(h, l.semanticContext.hash()) + h = murmurUpdate(h, l.context.Hash()) + h = murmurUpdate(h, l.semanticContext.Hash()) h = murmurUpdate(h, f) - h = murmurUpdate(h, l.lexerActionExecutor.hash()) + h = murmurUpdate(h, l.lexerActionExecutor.Hash()) h = murmurFinish(h, 6) return h } -func (l *LexerATNConfig) equals(other interface{}) bool { - return l.gequals(other.(Collectable[ATNConfig])) -} - -func (l *LexerATNConfig) gequals(other Collectable[ATNConfig]) bool { +// Equals is the default comparison function for LexerATNConfig objects, it can be used directly or via +// the default comparator [ObjEqComparator]. +func (l *LexerATNConfig) Equals(other Collectable[ATNConfig]) bool { + if l == other { + return true + } var othert, ok = other.(*LexerATNConfig) if l == other { @@ -284,7 +284,7 @@ func (l *LexerATNConfig) gequals(other Collectable[ATNConfig]) bool { var b bool if l.lexerActionExecutor != nil { - b = !l.lexerActionExecutor.equals(othert.lexerActionExecutor) + b = !l.lexerActionExecutor.Equals(othert.lexerActionExecutor) } else { b = othert.lexerActionExecutor != nil } @@ -293,7 +293,7 @@ func (l *LexerATNConfig) gequals(other Collectable[ATNConfig]) bool { return false } - return l.BaseATNConfig.equals(othert.BaseATNConfig) + return l.BaseATNConfig.Equals(othert.BaseATNConfig) } func checkNonGreedyDecision(source *LexerATNConfig, target ATNState) bool { diff --git a/runtime/Go/antlr/atn_config_set.go b/runtime/Go/antlr/atn_config_set.go index f17b5c2b07..2b228cc377 100644 --- a/runtime/Go/antlr/atn_config_set.go +++ b/runtime/Go/antlr/atn_config_set.go @@ -7,18 +7,17 @@ package antlr import "fmt" type ATNConfigSet interface { - hash() int + Hash() int + Equals(o Collectable[ATNConfig]) bool Add(ATNConfig, *DoubleDict) bool AddAll([]ATNConfig) bool - GetStates() Set + GetStates() *JStore[ATNState, Comparator[ATNState]] GetPredicates() []SemanticContext GetItems() []ATNConfig OptimizeConfigs(interpreter *BaseATNSimulator) - Equals(other interface{}) bool - Length() int IsEmpty() bool Contains(ATNConfig) bool @@ -57,7 +56,7 @@ type BaseATNConfigSet struct { // effectively doubles the number of objects associated with ATNConfigs. All // keys are hashed by (s, i, _, pi), not including the context. Wiped out when // read-only because a set becomes a DFA state. - configLookup Set + configLookup *JStore[ATNConfig, Comparator[ATNConfig]] // configs is the added elements. configs []ATNConfig @@ -83,7 +82,7 @@ type BaseATNConfigSet struct { // readOnly is whether it is read-only. Do not // allow any code to manipulate the set if true because DFA states will point at - // sets and those must not change. It not protect other fields; conflictingAlts + // sets and those must not change. It not, protect other fields; conflictingAlts // in particular, which is assigned after readOnly. readOnly bool @@ -104,7 +103,7 @@ func (b *BaseATNConfigSet) Alts() *BitSet { func NewBaseATNConfigSet(fullCtx bool) *BaseATNConfigSet { return &BaseATNConfigSet{ cachedHash: -1, - configLookup: newArray2DHashSetWithCap(hashATNConfig, equalATNConfigs, 16, 2), + configLookup: NewJStore[ATNConfig, Comparator[ATNConfig]](&ATNConfigComparator[ATNConfig]{}), fullCtx: fullCtx, } } @@ -126,9 +125,11 @@ func (b *BaseATNConfigSet) Add(config ATNConfig, mergeCache *DoubleDict) bool { b.dipsIntoOuterContext = true } - existing := b.configLookup.Add(config).(ATNConfig) + existing, present := b.configLookup.Put(config) - if existing == config { + // The config was not already in the set + // + if !present { b.cachedHash = -1 b.configs = append(b.configs, config) // Track order here return true @@ -154,11 +155,14 @@ func (b *BaseATNConfigSet) Add(config ATNConfig, mergeCache *DoubleDict) bool { return true } -func (b *BaseATNConfigSet) GetStates() Set { - states := newArray2DHashSet(nil, nil) +func (b *BaseATNConfigSet) GetStates() *JStore[ATNState, Comparator[ATNState]] { + + // states uses the standard comparator provided by the ATNState instance + // + states := NewJStore[ATNState, Comparator[ATNState]](&ObjEqComparator[ATNState]{}) for i := 0; i < len(b.configs); i++ { - states.Add(b.configs[i].GetState()) + states.Put(b.configs[i].GetState()) } return states @@ -227,7 +231,7 @@ func (b *BaseATNConfigSet) Compare(bs *BaseATNConfigSet) bool { for _, c := range b.configs { found := false for _, c2 := range bs.configs { - if c.equals(c2) { + if c.Equals(c2) { found = true break } @@ -241,7 +245,8 @@ func (b *BaseATNConfigSet) Compare(bs *BaseATNConfigSet) bool { return true } -func (b *BaseATNConfigSet) Equals(other interface{}) bool { + +func (b *BaseATNConfigSet) Equals(other Collectable[ATNConfig]) bool { if b == other { return true } else if _, ok := other.(*BaseATNConfigSet); !ok { @@ -259,7 +264,7 @@ func (b *BaseATNConfigSet) Equals(other interface{}) bool { b.Compare(other2) } -func (b *BaseATNConfigSet) hash() int { +func (b *BaseATNConfigSet) Hash() int { if b.readOnly { if b.cachedHash == -1 { b.cachedHash = b.hashCodeConfigs() @@ -274,7 +279,7 @@ func (b *BaseATNConfigSet) hash() int { func (b *BaseATNConfigSet) hashCodeConfigs() int { h := 1 for _, config := range b.configs { - h = 31*h + config.hash() + h = 31*h + config.Hash() } return h } @@ -310,7 +315,7 @@ func (b *BaseATNConfigSet) Clear() { b.configs = make([]ATNConfig, 0) b.cachedHash = -1 - b.configLookup = newArray2DHashSet(nil, equalATNConfigs) + b.configLookup = NewJStore[ATNConfig, Comparator[ATNConfig]](&BaseATNConfigComparator[ATNConfig]{}) } func (b *BaseATNConfigSet) FullContext() bool { @@ -392,7 +397,8 @@ type OrderedATNConfigSet struct { func NewOrderedATNConfigSet() *OrderedATNConfigSet { b := NewBaseATNConfigSet(false) - b.configLookup = newArray2DHashSet(nil, nil) + // This set uses the standard Hash() and Equals() from ATNConfig + b.configLookup = NewJStore[ATNConfig, Comparator[ATNConfig]](&ObjEqComparator[ATNConfig]{}) return &OrderedATNConfigSet{BaseATNConfigSet: b} } @@ -402,7 +408,7 @@ func hashATNConfig(i interface{}) int { hash := 7 hash = 31*hash + o.GetState().GetStateNumber() hash = 31*hash + o.GetAlt() - hash = 31*hash + o.GetSemanticContext().hash() + hash = 31*hash + o.GetSemanticContext().Hash() return hash } @@ -430,5 +436,5 @@ func equalATNConfigs(a, b interface{}) bool { return false } - return ai.GetSemanticContext().equals(bi.GetSemanticContext()) + return ai.GetSemanticContext().Equals(bi.GetSemanticContext()) } diff --git a/runtime/Go/antlr/atn_state.go b/runtime/Go/antlr/atn_state.go index 3835bb2e93..55c9782236 100644 --- a/runtime/Go/antlr/atn_state.go +++ b/runtime/Go/antlr/atn_state.go @@ -49,7 +49,8 @@ type ATNState interface { AddTransition(Transition, int) String() string - hash() int + Hash() int + Equals(Collectable[ATNState]) bool } type BaseATNState struct { @@ -123,7 +124,7 @@ func (as *BaseATNState) SetNextTokenWithinRule(v *IntervalSet) { as.NextTokenWithinRule = v } -func (as *BaseATNState) hash() int { +func (as *BaseATNState) Hash() int { return as.stateNumber } @@ -131,7 +132,7 @@ func (as *BaseATNState) String() string { return strconv.Itoa(as.stateNumber) } -func (as *BaseATNState) equals(other interface{}) bool { +func (as *BaseATNState) Equals(other Collectable[ATNState]) bool { if ot, ok := other.(ATNState); ok { return as.stateNumber == ot.GetStateNumber() } diff --git a/runtime/Go/antlr/comparators.go b/runtime/Go/antlr/comparators.go new file mode 100644 index 0000000000..5c9e4684aa --- /dev/null +++ b/runtime/Go/antlr/comparators.go @@ -0,0 +1,133 @@ +package antlr + +// This file contains all the implementations of custom comparators used for generic collections when the +// Hash() and Equals() funcs supplied by the struct objects themselves need to be overridden. Normally, we would +// put the comparators in the source file for the struct themselves, but given the organization of this code is +// sorta kinda based upon the Java code, I found it confusing trying to find out which comparator was where and used by +// which instantiation of a collection. For instance, an Array2DHashSet in the Java source, when used with ATNConfig +// collections requires three different comparators depending on what the collection is being used for. Collecting - pun intended - +// all the comparators here, makes it much easier to see which implementation of hash and equals is used by which collection. +// It also makes it easy to verify that the Hash() and Equals() functions marry up with the Java implementations. + +// ObjEqComparator is the equivalent of the Java ObjectEqualityComparator, which is the default instance of +// Equality comparator. We do not have inheritance in Go, only interfaces, so we use generics to enforce some +// type safety and avoid having to implement this for every type that we want to perform comparison on. +// +// This comparator works by using the standard Hash() and Equals() methods of the type T that is being compared. Which +// allows us to use it in any collection instance that does nto require a special hash or equals implementation. +type ObjEqComparator[T Collectable[T]] struct{} + +// Equals2 delegates to the Equals() method of type T +func (c *ObjEqComparator[T]) Equals2(o1, o2 T) bool { + return o1.Equals(o2) +} + +// Hash1 delegates to the Hash() method of type T +func (c *ObjEqComparator[T]) Hash1(o T) int { + + return o.Hash() +} + +type SemCComparator[T Collectable[T]] struct{} + +// ATNConfigComparator is used as the compartor for the configLookup field of an ATNConfigSet +// and has a custom Equals() and Hash() implementation, because equality is not based on the +// standard Hash() and Equals() methods of the ATNConfig type. +type ATNConfigComparator[T Collectable[T]] struct { +} + +// Equals2 is a custom comparator for ATNConfigs specifically for configLookup +func (c *ATNConfigComparator[T]) Equals2(o1, o2 ATNConfig) bool { + + // Same pointer, must be equal, even if both nil + // + if o1 == o2 { + return true + + } + + // If either are nil, but not both, then the result is false + // + if o1 == nil || o2 == nil { + return false + } + + return o1.GetState().GetStateNumber() == o2.GetState().GetStateNumber() && + o1.GetAlt() == o2.GetAlt() && + o1.GetSemanticContext().Equals(o2.GetSemanticContext()) +} + +// Hash1 is custom hash implementation for ATNConfigs specifically for configLookup +func (c *ATNConfigComparator[T]) Hash1(o ATNConfig) int { + hash := 7 + hash = 31*hash + o.GetState().GetStateNumber() + hash = 31*hash + o.GetAlt() + hash = 31*hash + o.GetSemanticContext().Hash() + return hash +} + +// ATNAltConfigComparator is used as the comparator for mapping configs to Alt Bitsets +type ATNAltConfigComparator[T Collectable[T]] struct { +} + +// Equals2 is a custom comparator for ATNConfigs specifically for configLookup +func (c *ATNAltConfigComparator[T]) Equals2(o1, o2 ATNConfig) bool { + + // Same pointer, must be equal, even if both nil + // + if o1 == o2 { + return true + + } + + // If either are nil, but not both, then the result is false + // + if o1 == nil || o2 == nil { + return false + } + + return o1.GetState().GetStateNumber() == o2.GetState().GetStateNumber() && + o1.GetContext().Equals(o2.GetContext()) +} + +// Hash1 is custom hash implementation for ATNConfigs specifically for configLookup +func (c *ATNAltConfigComparator[T]) Hash1(o ATNConfig) int { + h := murmurInit(7) + h = murmurUpdate(h, o.GetState().GetStateNumber()) + h = murmurUpdate(h, o.GetContext().Hash()) + return murmurFinish(h, 2) +} + +// BaseATNConfigComparator is used as the comparator for the configLookup field of a BaseATNConfigSet +// and has a custom Equals() and Hash() implementation, because equality is not based on the +// standard Hash() and Equals() methods of the ATNConfig type. +type BaseATNConfigComparator[T Collectable[T]] struct { +} + +// Equals2 is a custom comparator for ATNConfigs specifically for baseATNConfigSet +func (c *BaseATNConfigComparator[T]) Equals2(o1, o2 ATNConfig) bool { + + // Same pointer, must be equal, even if both nil + // + if o1 == o2 { + return true + + } + + // If either are nil, but not both, then the result is false + // + if o1 == nil || o2 == nil { + return false + } + + return o1.GetState().GetStateNumber() == o2.GetState().GetStateNumber() && + o1.GetAlt() == o2.GetAlt() && + o1.GetSemanticContext().Equals(o2.GetSemanticContext()) +} + +// Hash1 is custom hash implementation for ATNConfigs specifically for configLookup, but in fact just +// delegates to the standard Hash() method of the ATNConfig type. +func (c *BaseATNConfigComparator[T]) Hash1(o ATNConfig) int { + + return o.Hash() +} diff --git a/runtime/Go/antlr/dfa.go b/runtime/Go/antlr/dfa.go index 26b622c200..260a5aa91b 100644 --- a/runtime/Go/antlr/dfa.go +++ b/runtime/Go/antlr/dfa.go @@ -13,11 +13,12 @@ type DFA struct { // states is all the DFA states. Use Map to get the old state back; Set can only // indicate whether it is there. Go maps implement key hash collisions and so on and are very // good, but the DFAState is an object and can't be used directly as the key as it can in say JAva - // amd C#, whereby if the hashcode is the same for two objects, then equals() is called against them + // amd C#, whereby if the hashcode is the same for two objects, then Equals() is called against them // to see if they really are the same object. // // - states *JStore[*DFAState, *DFAStateComparator[*DFAState]] + states *JStore[*DFAState, *ObjEqComparator[*DFAState]] + numstates int s0 *DFAState @@ -31,7 +32,7 @@ func NewDFA(atnStartState DecisionState, decision int) *DFA { dfa := &DFA{ atnStartState: atnStartState, decision: decision, - states: NewJStore[*DFAState, *DFAStateComparator[*DFAState]](&DFAStateComparator[*DFAState]{}), + states: NewJStore[*DFAState, *ObjEqComparator[*DFAState]](&ObjEqComparator[*DFAState]{}), } if s, ok := atnStartState.(*StarLoopEntryState); ok && s.precedenceRuleDecision { dfa.precedenceDfa = true @@ -94,7 +95,7 @@ func (d *DFA) getPrecedenceDfa() bool { // true or nil otherwise, and d.precedenceDfa is updated. func (d *DFA) setPrecedenceDfa(precedenceDfa bool) { if d.getPrecedenceDfa() != precedenceDfa { - d.states = NewJStore[*DFAState, *DFAStateComparator[*DFAState]](&DFAStateComparator[*DFAState]{}) + d.states = NewJStore[*DFAState, *ObjEqComparator[*DFAState]](&ObjEqComparator[*DFAState]{}) d.numstates = 0 if precedenceDfa { diff --git a/runtime/Go/antlr/dfa_state.go b/runtime/Go/antlr/dfa_state.go index eba0997380..758d459cb8 100644 --- a/runtime/Go/antlr/dfa_state.go +++ b/runtime/Go/antlr/dfa_state.go @@ -81,45 +81,6 @@ type DFAState struct { predicates []*PredPrediction } -type DFAStateComparator[T Collectable[T]] struct { -} - -// NB This is the same as normal DFAState -func (c *DFAStateComparator[T]) equals(o1, o2 T) bool { - return o1.gequals(o2) -} - -// NB This is the same as normal DFAState -func (c *DFAStateComparator[T]) hash(o T) int { - return o.hash() -} - -// equals returns true if this DFA state is equal to another DFA state. -// Two [DFAState] instances are equal if their ATN configuration sets -// are the same. This method is used to see if a state already exists. -// -// Because the number of alternatives and number of ATN configurations are -// finite, there is a finite number of DFA states that can be processed. -// This is necessary to show that the algorithm terminates. -// -// Cannot test the DFA state numbers here because in [ParserATNSimulator.addDFAState] -// we need to know if any other state exists that has this exact set of ATN configurations. The -// [stateNumber] is irrelevant. -func (d *DFAState) gequals(o Collectable[*DFAState]) bool { - if d == o { - return true - } - - return d.configs.Equals(o.(*DFAState).configs) -} -func (d *DFAState) equals(o interface{}) bool { - if d == o { - return true - } - - return d.configs.Equals(o.(*DFAState).configs) -} - func NewDFAState(stateNumber int, configs ATNConfigSet) *DFAState { if configs == nil { configs = NewBaseATNConfigSet(false) @@ -129,16 +90,16 @@ func NewDFAState(stateNumber int, configs ATNConfigSet) *DFAState { } // GetAltSet gets the set of all alts mentioned by all ATN configurations in d. -func (d *DFAState) GetAltSet() Set { - alts := newArray2DHashSet(nil, nil) +func (d *DFAState) GetAltSet() []int { + var alts []int if d.configs != nil { for _, c := range d.configs.GetItems() { - alts.Add(c.GetAlt()) + alts = append(alts, c.GetAlt()) } } - if alts.Len() == 0 { + if len(alts) == 0 { return nil } @@ -169,27 +130,6 @@ func (d *DFAState) setPrediction(v int) { d.prediction = v } -// equals returns whether d equals other. Two DFAStates are equal if their ATN -// configuration sets are the same. This method is used to see if a state -// already exists. -// -// Because the number of alternatives and number of ATN configurations are -// finite, there is a finite number of DFA states that can be processed. This is -// necessary to show that the algorithm terminates. -// -// Cannot test the DFA state numbers here because in -// ParserATNSimulator.addDFAState we need to know if any other state exists that -// has d exact set of ATN configurations. The stateNumber is irrelevant. -// func (d *DFAState) equals(other interface{}) bool { -// if d == other { -// return true -// } else if _, ok := other.(*DFAState); !ok { -// return false -// } -// -// return d.configs.Equals(other.(*DFAState).configs) -//} - func (d *DFAState) String() string { var s string if d.isAcceptState { @@ -203,17 +143,27 @@ func (d *DFAState) String() string { return fmt.Sprintf("%d:%s%s", d.stateNumber, fmt.Sprint(d.configs), s) } -func (d *DFAState) hash() int { +func (d *DFAState) Hash() int { h := murmurInit(7) - h = murmurUpdate(h, d.configs.hash()) + h = murmurUpdate(h, d.configs.Hash()) return murmurFinish(h, 1) } -func (d *DFAState) Equals(o *DFAState) bool { - +// Equals returns whether d equals other. Two DFAStates are equal if their ATN +// configuration sets are the same. This method is used to see if a state +// already exists. +// +// Because the number of alternatives and number of ATN configurations are +// finite, there is a finite number of DFA states that can be processed. This is +// necessary to show that the algorithm terminates. +// +// Cannot test the DFA state numbers here because in +// ParserATNSimulator.addDFAState we need to know if any other state exists that +// has d exact set of ATN configurations. The stateNumber is irrelevant. +func (d *DFAState) Equals(o Collectable[*DFAState]) bool { if d == o { return true } - return d.configs.Equals(o.configs) + return d.configs.Equals(o.(*DFAState).configs) } diff --git a/runtime/Go/antlr/jcollect.go b/runtime/Go/antlr/jcollect.go index bdbbf8241c..d7660cd915 100644 --- a/runtime/Go/antlr/jcollect.go +++ b/runtime/Go/antlr/jcollect.go @@ -3,16 +3,15 @@ package antlr import "sort" // Collectable is an interface that a struct should implement if it is to be -// usable as a key in this collection. Cannot use the intuitive equals function -// here because the non-generic runtime has already claimed it. +// usable as a key in these collections. type Collectable[T any] interface { - hash() int - gequals(other Collectable[T]) bool + Hash() int + Equals(other Collectable[T]) bool } type Comparator[T any] interface { - hash(o T) int - equals(T, T) bool + Hash1(o T) int + Equals2(T, T) bool } // JStore implements a container that allows the use of a struct to calculate the key @@ -20,7 +19,7 @@ type Comparator[T any] interface { // serve the needs of the ANTLR Go runtime. // // For ease of porting the logic of the runtime from the master target (Java), this collection -// operates in a similar way to Java, in that it can use any struct that supplies a hash() and equals() +// operates in a similar way to Java, in that it can use any struct that supplies a Hash() and Equals() // function as the key. The values are stored in a standard go map which internally is a form of hashmap // itself, the key for the go map is the hash supplied by the key object. The collection is able to deal with // hash conflicts by using a simple slice of values associated with the hash code indexed bucket. That isn't @@ -58,11 +57,11 @@ func NewJStore[T any, C Comparator[T]](comparator Comparator[T]) *JStore[T, C] { // If the given value is not present in the store, then the value is added to the store and returned as v and exists is set to false. func (s *JStore[T, C]) Put(value T) (v T, exists bool) { //nolint:ireturn - kh := s.comparator.hash(value) + kh := s.comparator.Hash1(value) - for _, v := range s.store[kh] { - if s.comparator.equals(value, v) { - return v, true + for _, v1 := range s.store[kh] { + if s.comparator.Equals2(value, v1) { + return v1, true } } s.store[kh] = append(s.store[kh], value) @@ -75,16 +74,23 @@ func (s *JStore[T, C]) Put(value T) (v T, exists bool) { //nolint:ireturn // generated using the object we are going to store. func (s *JStore[T, C]) Get(key T) (T, bool) { //nolint:ireturn - kh := s.comparator.hash(key) + kh := s.comparator.Hash1(key) for _, v := range s.store[kh] { - if s.comparator.equals(key, v) { + if s.comparator.Equals2(key, v) { return v, true } } return key, false } +// Contains returns true if the given key is present in the store +func (s *JStore[T, C]) Contains(key T) bool { //nolint:ireturn + + _, present := s.Get(key) + return present +} + func (s *JStore[T, C]) SortedSlice(less func(i, j T) bool) []T { vs := make([]T, 0, len(s.store)) for _, v := range s.store { @@ -97,10 +103,28 @@ func (s *JStore[T, C]) SortedSlice(less func(i, j T) bool) []T { return vs } +func (s *JStore[T, C]) Each(f func(T) bool) { + for _, e := range s.store { + for _, v := range e { + f(v) + } + } +} + func (s *JStore[T, C]) Len() int { return s.len } +func (s *JStore[T, C]) Values() []T { + vs := make([]T, 0, len(s.store)) + for _, e := range s.store { + for _, v := range e { + vs = append(vs, v) + } + } + return vs +} + type entry[K, V any] struct { key K val V @@ -120,7 +144,7 @@ func NewJMap[K, V any, C Comparator[K]](comparator Comparator[K]) *JMap[K, V, C] } func (m *JMap[K, V, C]) Put(key K, val V) { - kh := m.comparator.hash(key) + kh := m.comparator.Hash1(key) m.store[kh] = append(m.store[kh], &entry[K, V]{key, val}) m.len++ } @@ -138,9 +162,9 @@ func (m *JMap[K, V, C]) Values() []V { func (m *JMap[K, V, C]) Get(key K) (V, bool) { var none V - kh := m.comparator.hash(key) + kh := m.comparator.Hash1(key) for _, e := range m.store[kh] { - if m.comparator.equals(e.key, key) { + if m.comparator.Equals2(e.key, key) { return e.val, true } } @@ -152,9 +176,9 @@ func (m *JMap[K, V, C]) Len() int { } func (m *JMap[K, V, C]) Delete(key K) { - kh := m.comparator.hash(key) + kh := m.comparator.Hash1(key) for i, e := range m.store[kh] { - if m.comparator.equals(e.key, key) { + if m.comparator.Equals2(e.key, key) { m.store[kh] = append(m.store[kh][:i], m.store[kh][i+1:]...) m.len-- return diff --git a/runtime/Go/antlr/lexer_action.go b/runtime/Go/antlr/lexer_action.go index 5a325be137..e023781202 100644 --- a/runtime/Go/antlr/lexer_action.go +++ b/runtime/Go/antlr/lexer_action.go @@ -21,8 +21,8 @@ type LexerAction interface { getActionType() int getIsPositionDependent() bool execute(lexer Lexer) - hash() int - equals(other LexerAction) bool + Hash() int + Equals(other LexerAction) bool } type BaseLexerAction struct { @@ -51,15 +51,14 @@ func (b *BaseLexerAction) getIsPositionDependent() bool { return b.isPositionDependent } -func (b *BaseLexerAction) hash() int { +func (b *BaseLexerAction) Hash() int { return b.actionType } -func (b *BaseLexerAction) equals(other LexerAction) bool { +func (b *BaseLexerAction) Equals(other LexerAction) bool { return b == other } -// // Implements the {@code Skip} lexer action by calling {@link Lexer//Skip}. // //

The {@code Skip} command does not have any parameters, so l action is @@ -85,7 +84,8 @@ func (l *LexerSkipAction) String() string { return "skip" } -// Implements the {@code type} lexer action by calling {@link Lexer//setType} +// Implements the {@code type} lexer action by calling {@link Lexer//setType} +// // with the assigned type. type LexerTypeAction struct { *BaseLexerAction @@ -104,14 +104,14 @@ func (l *LexerTypeAction) execute(lexer Lexer) { lexer.SetType(l.thetype) } -func (l *LexerTypeAction) hash() int { +func (l *LexerTypeAction) Hash() int { h := murmurInit(0) h = murmurUpdate(h, l.actionType) h = murmurUpdate(h, l.thetype) return murmurFinish(h, 2) } -func (l *LexerTypeAction) equals(other LexerAction) bool { +func (l *LexerTypeAction) Equals(other LexerAction) bool { if l == other { return true } else if _, ok := other.(*LexerTypeAction); !ok { @@ -148,14 +148,14 @@ func (l *LexerPushModeAction) execute(lexer Lexer) { lexer.PushMode(l.mode) } -func (l *LexerPushModeAction) hash() int { +func (l *LexerPushModeAction) Hash() int { h := murmurInit(0) h = murmurUpdate(h, l.actionType) h = murmurUpdate(h, l.mode) return murmurFinish(h, 2) } -func (l *LexerPushModeAction) equals(other LexerAction) bool { +func (l *LexerPushModeAction) Equals(other LexerAction) bool { if l == other { return true } else if _, ok := other.(*LexerPushModeAction); !ok { @@ -245,14 +245,14 @@ func (l *LexerModeAction) execute(lexer Lexer) { lexer.SetMode(l.mode) } -func (l *LexerModeAction) hash() int { +func (l *LexerModeAction) Hash() int { h := murmurInit(0) h = murmurUpdate(h, l.actionType) h = murmurUpdate(h, l.mode) return murmurFinish(h, 2) } -func (l *LexerModeAction) equals(other LexerAction) bool { +func (l *LexerModeAction) Equals(other LexerAction) bool { if l == other { return true } else if _, ok := other.(*LexerModeAction); !ok { @@ -303,7 +303,7 @@ func (l *LexerCustomAction) execute(lexer Lexer) { lexer.Action(nil, l.ruleIndex, l.actionIndex) } -func (l *LexerCustomAction) hash() int { +func (l *LexerCustomAction) Hash() int { h := murmurInit(0) h = murmurUpdate(h, l.actionType) h = murmurUpdate(h, l.ruleIndex) @@ -311,13 +311,14 @@ func (l *LexerCustomAction) hash() int { return murmurFinish(h, 3) } -func (l *LexerCustomAction) equals(other LexerAction) bool { +func (l *LexerCustomAction) Equals(other LexerAction) bool { if l == other { return true } else if _, ok := other.(*LexerCustomAction); !ok { return false } else { - return l.ruleIndex == other.(*LexerCustomAction).ruleIndex && l.actionIndex == other.(*LexerCustomAction).actionIndex + return l.ruleIndex == other.(*LexerCustomAction).ruleIndex && + l.actionIndex == other.(*LexerCustomAction).actionIndex } } @@ -344,14 +345,14 @@ func (l *LexerChannelAction) execute(lexer Lexer) { lexer.SetChannel(l.channel) } -func (l *LexerChannelAction) hash() int { +func (l *LexerChannelAction) Hash() int { h := murmurInit(0) h = murmurUpdate(h, l.actionType) h = murmurUpdate(h, l.channel) return murmurFinish(h, 2) } -func (l *LexerChannelAction) equals(other LexerAction) bool { +func (l *LexerChannelAction) Equals(other LexerAction) bool { if l == other { return true } else if _, ok := other.(*LexerChannelAction); !ok { @@ -412,10 +413,10 @@ func (l *LexerIndexedCustomAction) execute(lexer Lexer) { l.lexerAction.execute(lexer) } -func (l *LexerIndexedCustomAction) hash() int { +func (l *LexerIndexedCustomAction) Hash() int { h := murmurInit(0) h = murmurUpdate(h, l.offset) - h = murmurUpdate(h, l.lexerAction.hash()) + h = murmurUpdate(h, l.lexerAction.Hash()) return murmurFinish(h, 2) } @@ -425,6 +426,7 @@ func (l *LexerIndexedCustomAction) equals(other LexerAction) bool { } else if _, ok := other.(*LexerIndexedCustomAction); !ok { return false } else { - return l.offset == other.(*LexerIndexedCustomAction).offset && l.lexerAction == other.(*LexerIndexedCustomAction).lexerAction + return l.offset == other.(*LexerIndexedCustomAction).offset && + l.lexerAction.Equals(other.(*LexerIndexedCustomAction).lexerAction) } } diff --git a/runtime/Go/antlr/lexer_action_executor.go b/runtime/Go/antlr/lexer_action_executor.go index 056941dd6e..6bae3a206f 100644 --- a/runtime/Go/antlr/lexer_action_executor.go +++ b/runtime/Go/antlr/lexer_action_executor.go @@ -4,6 +4,8 @@ package antlr +import "golang.org/x/exp/slices" + // Represents an executor for a sequence of lexer actions which traversed during // the Matching operation of a lexer rule (token). // @@ -12,8 +14,8 @@ package antlr // not cause bloating of the {@link DFA} created for the lexer.

type LexerActionExecutor struct { - lexerActions []LexerAction - cachedHash int + lexerActions []LexerAction + cachedHash int } func NewLexerActionExecutor(lexerActions []LexerAction) *LexerActionExecutor { @@ -30,7 +32,7 @@ func NewLexerActionExecutor(lexerActions []LexerAction) *LexerActionExecutor { // of the performance-critical {@link LexerATNConfig//hashCode} operation. l.cachedHash = murmurInit(57) for _, a := range lexerActions { - l.cachedHash = murmurUpdate(l.cachedHash, a.hash()) + l.cachedHash = murmurUpdate(l.cachedHash, a.Hash()) } return l @@ -151,14 +153,17 @@ func (l *LexerActionExecutor) execute(lexer Lexer, input CharStream, startIndex } } -func (l *LexerActionExecutor) hash() int { +func (l *LexerActionExecutor) Hash() int { if l == nil { + // TODO: Why is this here? l should not be nil return 61 } + + // TODO: This is created from the action itself when the struct is created - will this be an issue at some point? Java uses the runtime assign hashcode return l.cachedHash } -func (l *LexerActionExecutor) equals(other interface{}) bool { +func (l *LexerActionExecutor) Equals(other interface{}) bool { if l == other { return true } @@ -169,5 +174,13 @@ func (l *LexerActionExecutor) equals(other interface{}) bool { if othert == nil { return false } - return l.cachedHash == othert.cachedHash && &l.lexerActions == &othert.lexerActions + if l.cachedHash != othert.cachedHash { + return false + } + if len(l.lexerActions) != len(othert.lexerActions) { + return false + } + return slices.EqualFunc(l.lexerActions, othert.lexerActions, func(i, j LexerAction) bool { + return i.Equals(j) + }) } diff --git a/runtime/Go/antlr/lexer_atn_simulator.go b/runtime/Go/antlr/lexer_atn_simulator.go index 8906d02e6f..bdef6fc8bd 100644 --- a/runtime/Go/antlr/lexer_atn_simulator.go +++ b/runtime/Go/antlr/lexer_atn_simulator.go @@ -595,7 +595,7 @@ func (l *LexerATNSimulator) addDFAState(configs ATNConfigSet, suppressEdge bool) l.atn.stateMu.Lock() defer l.atn.stateMu.Unlock() - existing, present := dfa.states.Put(proposed) + existing, present := dfa.states.Get(proposed) if present { // This state was already present, so just return it. @@ -603,12 +603,12 @@ func (l *LexerATNSimulator) addDFAState(configs ATNConfigSet, suppressEdge bool) proposed = existing } else { - // The proposed state has already been added to the DFA. We still have the pointer, so - // we can modify it even though it is stored already. + // We need to add the new state // - proposed.stateNumber = dfa.states.Len() - 1 + proposed.stateNumber = dfa.states.Len() configs.SetReadOnly(true) proposed.configs = configs + dfa.states.Put(proposed) } if !suppressEdge { dfa.setS0(proposed) diff --git a/runtime/Go/antlr/ll1_analyzer.go b/runtime/Go/antlr/ll1_analyzer.go index 6ffb37de69..f4007ec323 100644 --- a/runtime/Go/antlr/ll1_analyzer.go +++ b/runtime/Go/antlr/ll1_analyzer.go @@ -14,14 +14,15 @@ func NewLL1Analyzer(atn *ATN) *LL1Analyzer { return la } -//* Special value added to the lookahead sets to indicate that we hit -// a predicate during analysis if {@code seeThruPreds==false}. -/// +// - Special value added to the lookahead sets to indicate that we hit +// a predicate during analysis if {@code seeThruPreds==false}. +// +// / const ( LL1AnalyzerHitPred = TokenInvalidType ) -//* +// * // Calculates the SLL(1) expected lookahead set for each outgoing transition // of an {@link ATNState}. The returned array has one element for each // outgoing transition in {@code s}. If the closure from transition @@ -38,7 +39,7 @@ func (la *LL1Analyzer) getDecisionLookahead(s ATNState) []*IntervalSet { look := make([]*IntervalSet, count) for alt := 0; alt < count; alt++ { look[alt] = NewIntervalSet() - lookBusy := newArray2DHashSet(nil, nil) + lookBusy := NewJStore[ATNConfig, Comparator[ATNConfig]](&ObjEqComparator[ATNConfig]{}) seeThruPreds := false // fail to get lookahead upon pred la.look1(s.GetTransitions()[alt].getTarget(), nil, BasePredictionContextEMPTY, look[alt], lookBusy, NewBitSet(), seeThruPreds, false) // Wipe out lookahead for la alternative if we found nothing @@ -50,7 +51,7 @@ func (la *LL1Analyzer) getDecisionLookahead(s ATNState) []*IntervalSet { return look } -//* +// * // Compute set of tokens that can follow {@code s} in the ATN in the // specified {@code ctx}. // @@ -67,7 +68,7 @@ func (la *LL1Analyzer) getDecisionLookahead(s ATNState) []*IntervalSet { // // @return The set of tokens that can follow {@code s} in the ATN in the // specified {@code ctx}. -/// +// / func (la *LL1Analyzer) Look(s, stopState ATNState, ctx RuleContext) *IntervalSet { r := NewIntervalSet() seeThruPreds := true // ignore preds get all lookahead @@ -75,7 +76,7 @@ func (la *LL1Analyzer) Look(s, stopState ATNState, ctx RuleContext) *IntervalSet if ctx != nil { lookContext = predictionContextFromRuleContext(s.GetATN(), ctx) } - la.look1(s, stopState, lookContext, r, newArray2DHashSet(nil, nil), NewBitSet(), seeThruPreds, true) + la.look1(s, stopState, lookContext, r, NewJStore[ATNConfig, Comparator[ATNConfig]](&ObjEqComparator[ATNConfig]{}), NewBitSet(), seeThruPreds, true) return r } @@ -109,14 +110,14 @@ func (la *LL1Analyzer) Look(s, stopState ATNState, ctx RuleContext) *IntervalSet // outermost context is reached. This parameter has no effect if {@code ctx} // is {@code nil}. -func (la *LL1Analyzer) look2(s, stopState ATNState, ctx PredictionContext, look *IntervalSet, lookBusy Set, calledRuleStack *BitSet, seeThruPreds, addEOF bool, i int) { +func (la *LL1Analyzer) look2(s, stopState ATNState, ctx PredictionContext, look *IntervalSet, lookBusy *JStore[ATNConfig, Comparator[ATNConfig]], calledRuleStack *BitSet, seeThruPreds, addEOF bool, i int) { returnState := la.atn.states[ctx.getReturnState(i)] la.look1(returnState, stopState, ctx.GetParent(i), look, lookBusy, calledRuleStack, seeThruPreds, addEOF) } -func (la *LL1Analyzer) look1(s, stopState ATNState, ctx PredictionContext, look *IntervalSet, lookBusy Set, calledRuleStack *BitSet, seeThruPreds, addEOF bool) { +func (la *LL1Analyzer) look1(s, stopState ATNState, ctx PredictionContext, look *IntervalSet, lookBusy *JStore[ATNConfig, Comparator[ATNConfig]], calledRuleStack *BitSet, seeThruPreds, addEOF bool) { c := NewBaseATNConfig6(s, 0, ctx) @@ -124,8 +125,11 @@ func (la *LL1Analyzer) look1(s, stopState ATNState, ctx PredictionContext, look return } - lookBusy.Add(c) + _, present := lookBusy.Put(c) + if present { + return + } if s == stopState { if ctx == nil { look.addOne(TokenEpsilon) @@ -198,7 +202,7 @@ func (la *LL1Analyzer) look1(s, stopState ATNState, ctx PredictionContext, look } } -func (la *LL1Analyzer) look3(stopState ATNState, ctx PredictionContext, look *IntervalSet, lookBusy Set, calledRuleStack *BitSet, seeThruPreds, addEOF bool, t1 *RuleTransition) { +func (la *LL1Analyzer) look3(stopState ATNState, ctx PredictionContext, look *IntervalSet, lookBusy *JStore[ATNConfig, Comparator[ATNConfig]], calledRuleStack *BitSet, seeThruPreds, addEOF bool, t1 *RuleTransition) { newContext := SingletonBasePredictionContextCreate(ctx, t1.followState.GetStateNumber()) diff --git a/runtime/Go/antlr/parser_atn_simulator.go b/runtime/Go/antlr/parser_atn_simulator.go index cd725167a5..35bee83113 100644 --- a/runtime/Go/antlr/parser_atn_simulator.go +++ b/runtime/Go/antlr/parser_atn_simulator.go @@ -570,7 +570,7 @@ func (p *ParserATNSimulator) computeReachSet(closure ATNConfigSet, t int, fullCt // if reach == nil { reach = NewBaseATNConfigSet(fullCtx) - closureBusy := newArray2DHashSet(nil, nil) + closureBusy := NewJStore[ATNConfig, Comparator[ATNConfig]](&ObjEqComparator[ATNConfig]{}) treatEOFAsEpsilon := t == TokenEOF amount := len(intermediate.configs) for k := 0; k < amount; k++ { @@ -663,7 +663,7 @@ func (p *ParserATNSimulator) computeStartState(a ATNState, ctx RuleContext, full for i := 0; i < len(a.GetTransitions()); i++ { target := a.GetTransitions()[i].getTarget() c := NewBaseATNConfig6(target, i+1, initialContext) - closureBusy := newArray2DHashSet(nil, nil) + closureBusy := NewJStore[ATNConfig, Comparator[ATNConfig]](&BaseATNConfigComparator[ATNConfig]{}) p.closure(c, configs, closureBusy, true, fullCtx, false) } return configs @@ -756,7 +756,7 @@ func (p *ParserATNSimulator) applyPrecedenceFilter(configs ATNConfigSet) ATNConf // (basically a graph subtraction algorithm). if !config.getPrecedenceFilterSuppressed() { context := statesFromAlt1[config.GetState().GetStateNumber()] - if context != nil && context.gequals(config.GetContext()) { + if context != nil && context.Equals(config.GetContext()) { // eliminated continue } @@ -966,13 +966,13 @@ func (p *ParserATNSimulator) evalSemanticContext(predPredictions []*PredPredicti return predictions } -func (p *ParserATNSimulator) closure(config ATNConfig, configs ATNConfigSet, closureBusy Set, collectPredicates, fullCtx, treatEOFAsEpsilon bool) { +func (p *ParserATNSimulator) closure(config ATNConfig, configs ATNConfigSet, closureBusy *JStore[ATNConfig, Comparator[ATNConfig]], collectPredicates, fullCtx, treatEOFAsEpsilon bool) { initialDepth := 0 p.closureCheckingStopState(config, configs, closureBusy, collectPredicates, fullCtx, initialDepth, treatEOFAsEpsilon) } -func (p *ParserATNSimulator) closureCheckingStopState(config ATNConfig, configs ATNConfigSet, closureBusy Set, collectPredicates, fullCtx bool, depth int, treatEOFAsEpsilon bool) { +func (p *ParserATNSimulator) closureCheckingStopState(config ATNConfig, configs ATNConfigSet, closureBusy *JStore[ATNConfig, Comparator[ATNConfig]], collectPredicates, fullCtx bool, depth int, treatEOFAsEpsilon bool) { if ParserATNSimulatorDebug { fmt.Println("closure(" + config.String() + ")") fmt.Println("configs(" + configs.String() + ")") @@ -1025,7 +1025,7 @@ func (p *ParserATNSimulator) closureCheckingStopState(config ATNConfig, configs } // Do the actual work of walking epsilon edges// -func (p *ParserATNSimulator) closureWork(config ATNConfig, configs ATNConfigSet, closureBusy Set, collectPredicates, fullCtx bool, depth int, treatEOFAsEpsilon bool) { +func (p *ParserATNSimulator) closureWork(config ATNConfig, configs ATNConfigSet, closureBusy *JStore[ATNConfig, Comparator[ATNConfig]], collectPredicates, fullCtx bool, depth int, treatEOFAsEpsilon bool) { state := config.GetState() // optimization if !state.GetEpsilonOnlyTransitions() { @@ -1060,7 +1060,8 @@ func (p *ParserATNSimulator) closureWork(config ATNConfig, configs ATNConfigSet, c.SetReachesIntoOuterContext(c.GetReachesIntoOuterContext() + 1) - if closureBusy.Add(c) != c { + _, present := closureBusy.Put(c) + if present { // avoid infinite recursion for right-recursive rules continue } @@ -1071,9 +1072,13 @@ func (p *ParserATNSimulator) closureWork(config ATNConfig, configs ATNConfigSet, fmt.Println("dips into outer ctx: " + c.String()) } } else { - if !t.getIsEpsilon() && closureBusy.Add(c) != c { - // avoid infinite recursion for EOF* and EOF+ - continue + + if !t.getIsEpsilon() { + _, present := closureBusy.Put(c) + if present { + // avoid infinite recursion for EOF* and EOF+ + continue + } } if _, ok := t.(*RuleTransition); ok { // latch when newDepth goes negative - once we step out of the entry context we can't return @@ -1490,18 +1495,19 @@ func (p *ParserATNSimulator) addDFAState(dfa *DFA, d *DFAState) *DFAState { if d == ATNSimulatorError { return d } - existing, present := dfa.states.Put(d) + existing, present := dfa.states.Get(d) if present { return existing } // The state was not present, so update it with configs // - d.stateNumber = dfa.states.Len() - 1 + d.stateNumber = dfa.states.Len() if !d.configs.ReadOnly() { d.configs.OptimizeConfigs(p.BaseATNSimulator) d.configs.SetReadOnly(true) } + dfa.states.Put(d) if ParserATNSimulatorDebug { fmt.Println("adding NewDFA state: " + d.String()) } diff --git a/runtime/Go/antlr/prediction_context.go b/runtime/Go/antlr/prediction_context.go index d388d7ca31..ded6a0e2ee 100644 --- a/runtime/Go/antlr/prediction_context.go +++ b/runtime/Go/antlr/prediction_context.go @@ -5,6 +5,7 @@ package antlr import ( + "golang.org/x/exp/slices" "strconv" ) @@ -26,10 +27,10 @@ var ( ) type PredictionContext interface { - hash() int + Hash() int + Equals(interface{}) bool GetParent(int) PredictionContext getReturnState(int) int - gequals(Collectable[PredictionContext]) bool length() int isEmpty() bool hasEmptyPath() bool @@ -53,7 +54,7 @@ func (b *BasePredictionContext) isEmpty() bool { func calculateHash(parent PredictionContext, returnState int) int { h := murmurInit(1) - h = murmurUpdate(h, parent.hash()) + h = murmurUpdate(h, parent.Hash()) h = murmurUpdate(h, returnState) return murmurFinish(h, 2) } @@ -158,32 +159,29 @@ func (b *BaseSingletonPredictionContext) getReturnState(index int) int { func (b *BaseSingletonPredictionContext) hasEmptyPath() bool { return b.returnState == BasePredictionContextEmptyReturnState } -func (b *BaseSingletonPredictionContext) equals(other interface{}) bool { - return b.gequals(other.(Collectable[PredictionContext])) + +func (b *BaseSingletonPredictionContext) Hash() int { + return b.cachedHash } -func (b *BaseSingletonPredictionContext) gequals(other Collectable[PredictionContext]) bool { +func (b *BaseSingletonPredictionContext) Equals(other interface{}) bool { if b == other { return true - } else if _, ok := other.(*BaseSingletonPredictionContext); !ok { + } + if _, ok := other.(*BaseSingletonPredictionContext); !ok { return false - } else if b.hash() != other.hash() { - return false // can't be same if hash is different } otherP := other.(*BaseSingletonPredictionContext) if b.returnState != otherP.getReturnState(0) { return false - } else if b.parentCtx == nil { + } + if b.parentCtx == nil { return otherP.parentCtx == nil } - return b.parentCtx.gequals(otherP.parentCtx) -} - -func (b *BaseSingletonPredictionContext) hash() int { - return b.cachedHash + return b.parentCtx.Equals(otherP.parentCtx) } func (b *BaseSingletonPredictionContext) String() string { @@ -217,7 +215,7 @@ func NewEmptyPredictionContext() *EmptyPredictionContext { p := new(EmptyPredictionContext) p.BaseSingletonPredictionContext = NewBaseSingletonPredictionContext(nil, BasePredictionContextEmptyReturnState) - + p.cachedHash = calculateEmptyHash() return p } @@ -232,11 +230,12 @@ func (e *EmptyPredictionContext) GetParent(index int) PredictionContext { func (e *EmptyPredictionContext) getReturnState(index int) int { return e.returnState } -func (e *EmptyPredictionContext) equals(other Collectable[PredictionContext]) bool { - return e == other + +func (e *EmptyPredictionContext) Hash() int { + return e.cachedHash } -func (e *EmptyPredictionContext) gequals(other Collectable[PredictionContext]) bool { +func (e *EmptyPredictionContext) Equals(other interface{}) bool { return e == other } @@ -259,7 +258,7 @@ func NewArrayPredictionContext(parents []PredictionContext, returnStates []int) hash := murmurInit(1) for _, parent := range parents { - hash = murmurUpdate(hash, parent.hash()) + hash = murmurUpdate(hash, parent.Hash()) } for _, returnState := range returnStates { @@ -303,22 +302,31 @@ func (a *ArrayPredictionContext) getReturnState(index int) int { return a.returnStates[index] } -func (a *ArrayPredictionContext) equals(other interface{}) bool { - return a.gequals(other.(*ArrayPredictionContext)) -} - -func (a *ArrayPredictionContext) gequals(other Collectable[PredictionContext]) bool { - if _, ok := other.(*ArrayPredictionContext); !ok { +// Equals is the default comparison function for ArrayPredictionContext when no specialized +// implementation is needed for a collection +func (a *ArrayPredictionContext) Equals(o interface{}) bool { + if a == o { + return true + } + other, ok := o.(*ArrayPredictionContext) + if !ok { return false - } else if a.cachedHash != other.hash() { + } + if a.cachedHash != other.Hash() { return false // can't be same if hash is different - } else { - otherP := other.(*ArrayPredictionContext) - return &a.returnStates == &otherP.returnStates && &a.parents == &otherP.parents } + + // Must compare the actual array elements and not just the array address + // + return slices.Equal(a.returnStates, other.returnStates) && + slices.EqualFunc(a.parents, other.parents, func(x, y PredictionContext) bool { + return x.Equals(y) + }) } -func (a *ArrayPredictionContext) hash() int { +// Hash is the default hash function for ArrayPredictionContext when no specialized +// implementation is needed for a collection +func (a *ArrayPredictionContext) Hash() int { return a.BasePredictionContext.cachedHash } @@ -431,11 +439,11 @@ func merge(a, b PredictionContext, rootIsWildcard bool, mergeCache *DoubleDict) // / func mergeSingletons(a, b *BaseSingletonPredictionContext, rootIsWildcard bool, mergeCache *DoubleDict) PredictionContext { if mergeCache != nil { - previous := mergeCache.Get(a.hash(), b.hash()) + previous := mergeCache.Get(a.Hash(), b.Hash()) if previous != nil { return previous.(PredictionContext) } - previous = mergeCache.Get(b.hash(), a.hash()) + previous = mergeCache.Get(b.Hash(), a.Hash()) if previous != nil { return previous.(PredictionContext) } @@ -444,7 +452,7 @@ func mergeSingletons(a, b *BaseSingletonPredictionContext, rootIsWildcard bool, rootMerge := mergeRoot(a, b, rootIsWildcard) if rootMerge != nil { if mergeCache != nil { - mergeCache.set(a.hash(), b.hash(), rootMerge) + mergeCache.set(a.Hash(), b.Hash(), rootMerge) } return rootMerge } @@ -464,7 +472,7 @@ func mergeSingletons(a, b *BaseSingletonPredictionContext, rootIsWildcard bool, // Newjoined parent so create Newsingleton pointing to it, a' spc := SingletonBasePredictionContextCreate(parent, a.returnState) if mergeCache != nil { - mergeCache.set(a.hash(), b.hash(), spc) + mergeCache.set(a.Hash(), b.Hash(), spc) } return spc } @@ -486,7 +494,7 @@ func mergeSingletons(a, b *BaseSingletonPredictionContext, rootIsWildcard bool, parents := []PredictionContext{singleParent, singleParent} apc := NewArrayPredictionContext(parents, payloads) if mergeCache != nil { - mergeCache.set(a.hash(), b.hash(), apc) + mergeCache.set(a.Hash(), b.Hash(), apc) } return apc } @@ -502,7 +510,7 @@ func mergeSingletons(a, b *BaseSingletonPredictionContext, rootIsWildcard bool, } apc := NewArrayPredictionContext(parents, payloads) if mergeCache != nil { - mergeCache.set(a.hash(), b.hash(), apc) + mergeCache.set(a.Hash(), b.Hash(), apc) } return apc } @@ -589,11 +597,11 @@ func mergeRoot(a, b SingletonPredictionContext, rootIsWildcard bool) PredictionC // / func mergeArrays(a, b *ArrayPredictionContext, rootIsWildcard bool, mergeCache *DoubleDict) PredictionContext { if mergeCache != nil { - previous := mergeCache.Get(a.hash(), b.hash()) + previous := mergeCache.Get(a.Hash(), b.Hash()) if previous != nil { return previous.(PredictionContext) } - previous = mergeCache.Get(b.hash(), a.hash()) + previous = mergeCache.Get(b.Hash(), a.Hash()) if previous != nil { return previous.(PredictionContext) } @@ -614,7 +622,7 @@ func mergeArrays(a, b *ArrayPredictionContext, rootIsWildcard bool, mergeCache * payload := a.returnStates[i] // $+$ = $ bothDollars := payload == BasePredictionContextEmptyReturnState && aParent == nil && bParent == nil - axAX := (aParent != nil && bParent != nil && aParent == bParent) // ax+ax + axAX := aParent != nil && bParent != nil && aParent == bParent // ax+ax // -> // ax if bothDollars || axAX { @@ -657,7 +665,7 @@ func mergeArrays(a, b *ArrayPredictionContext, rootIsWildcard bool, mergeCache * if k == 1 { // for just one merged element, return singleton top pc := SingletonBasePredictionContextCreate(mergedParents[0], mergedReturnStates[0]) if mergeCache != nil { - mergeCache.set(a.hash(), b.hash(), pc) + mergeCache.set(a.Hash(), b.Hash(), pc) } return pc } @@ -671,20 +679,20 @@ func mergeArrays(a, b *ArrayPredictionContext, rootIsWildcard bool, mergeCache * // TODO: track whether this is possible above during merge sort for speed if M == a { if mergeCache != nil { - mergeCache.set(a.hash(), b.hash(), a) + mergeCache.set(a.Hash(), b.Hash(), a) } return a } if M == b { if mergeCache != nil { - mergeCache.set(a.hash(), b.hash(), b) + mergeCache.set(a.Hash(), b.Hash(), b) } return b } combineCommonParents(mergedParents) if mergeCache != nil { - mergeCache.set(a.hash(), b.hash(), M) + mergeCache.set(a.Hash(), b.Hash(), M) } return M } diff --git a/runtime/Go/antlr/prediction_mode.go b/runtime/Go/antlr/prediction_mode.go index eca8ddc1a1..79b466d859 100644 --- a/runtime/Go/antlr/prediction_mode.go +++ b/runtime/Go/antlr/prediction_mode.go @@ -460,29 +460,7 @@ func PredictionModeGetAlts(altsets []*BitSet) *BitSet { return all } -type ATNComparator[T Collectable[T]] struct { -} - -// NB This is the same as normal DFAState -func (a *ATNComparator[T]) equals(o1, o2 ATNConfig) bool { - if o1 == o2 { - return true - } - if o1 != nil && o2 != nil { - return o1.GetState().GetStateNumber() == o2.GetState().GetStateNumber() - } - return false -} - -// NB This is the same as normal DFAState -func (a *ATNComparator[T]) hash(o ATNConfig) int { - h := murmurInit(7) - h = murmurUpdate(h, o.GetState().GetStateNumber()) - h = murmurUpdate(h, o.GetContext().hash()) - return murmurFinish(h, 2) -} - -// This func gets the conflicting alt subsets from a configuration set. +// PredictionModegetConflictingAltSubsets gets the conflicting alt subsets from a configuration set. // For each configuration {@code c} in {@code configs}: // //
@@ -490,7 +468,7 @@ func (a *ATNComparator[T]) hash(o ATNConfig) int {
 // alt and not pred
 // 
func PredictionModegetConflictingAltSubsets(configs ATNConfigSet) []*BitSet { - configToAlts := NewJMap[ATNConfig, *BitSet, *ATNComparator[ATNConfig]](&ATNComparator[ATNConfig]{}) + configToAlts := NewJMap[ATNConfig, *BitSet, *ATNAltConfigComparator[ATNConfig]](&ATNAltConfigComparator[ATNConfig]{}) for _, c := range configs.GetItems() { @@ -505,7 +483,7 @@ func PredictionModegetConflictingAltSubsets(configs ATNConfigSet) []*BitSet { return configToAlts.Values() } -// Get a map from state to alt subset from a configuration set. For each +// PredictionModeGetStateToAltMap gets a map from state to alt subset from a configuration set. For each // configuration {@code c} in {@code configs}: // //
diff --git a/runtime/Go/antlr/semantic_context.go b/runtime/Go/antlr/semantic_context.go
index 4c54429fd6..d4f3c93fd2 100644
--- a/runtime/Go/antlr/semantic_context.go
+++ b/runtime/Go/antlr/semantic_context.go
@@ -18,12 +18,12 @@ import (
 //
 
 type SemanticContext interface {
-	comparable
+	Equals(other Collectable[SemanticContext]) bool
+	Hash() int
 
 	evaluate(parser Recognizer, outerContext RuleContext) bool
 	evalPrecedence(parser Recognizer, outerContext RuleContext) SemanticContext
 
-	hash() int
 	String() string
 }
 
@@ -95,7 +95,7 @@ func (p *Predicate) evaluate(parser Recognizer, outerContext RuleContext) bool {
 	return parser.Sempred(localctx, p.ruleIndex, p.predIndex)
 }
 
-func (p *Predicate) equals(other interface{}) bool {
+func (p *Predicate) Equals(other Collectable[SemanticContext]) bool {
 	if p == other {
 		return true
 	} else if _, ok := other.(*Predicate); !ok {
@@ -107,7 +107,7 @@ func (p *Predicate) equals(other interface{}) bool {
 	}
 }
 
-func (p *Predicate) hash() int {
+func (p *Predicate) Hash() int {
 	h := murmurInit(0)
 	h = murmurUpdate(h, p.ruleIndex)
 	h = murmurUpdate(h, p.predIndex)
@@ -151,17 +151,22 @@ func (p *PrecedencePredicate) compareTo(other *PrecedencePredicate) int {
 	return p.precedence - other.precedence
 }
 
-func (p *PrecedencePredicate) equals(other interface{}) bool {
-	if p == other {
-		return true
-	} else if _, ok := other.(*PrecedencePredicate); !ok {
+func (p *PrecedencePredicate) Equals(other Collectable[SemanticContext]) bool {
+
+	var op *PrecedencePredicate
+	var ok bool
+	if op, ok = other.(*PrecedencePredicate); !ok {
 		return false
-	} else {
-		return p.precedence == other.(*PrecedencePredicate).precedence
 	}
+
+	if p == op {
+		return true
+	}
+
+	return p.precedence == other.(*PrecedencePredicate).precedence
 }
 
-func (p *PrecedencePredicate) hash() int {
+func (p *PrecedencePredicate) Hash() int {
 	h := uint32(1)
 	h = 31*h + uint32(p.precedence)
 	return int(h)
@@ -171,10 +176,10 @@ func (p *PrecedencePredicate) String() string {
 	return "{" + strconv.Itoa(p.precedence) + ">=prec}?"
 }
 
-func PrecedencePredicatefilterPrecedencePredicates(set Set) []*PrecedencePredicate {
+func PrecedencePredicatefilterPrecedencePredicates(set *JStore[SemanticContext, Comparator[SemanticContext]]) []*PrecedencePredicate {
 	result := make([]*PrecedencePredicate, 0)
 
-	set.Each(func(v interface{}) bool {
+	set.Each(func(v SemanticContext) bool {
 		if c2, ok := v.(*PrecedencePredicate); ok {
 			result = append(result, c2)
 		}
@@ -193,21 +198,21 @@ type AND struct {
 
 func NewAND(a, b SemanticContext) *AND {
 
-	operands := newArray2DHashSet(nil, nil)
+	operands := NewJStore[SemanticContext, Comparator[SemanticContext]](&ObjEqComparator[SemanticContext]{})
 	if aa, ok := a.(*AND); ok {
 		for _, o := range aa.opnds {
-			operands.Add(o)
+			operands.Put(o)
 		}
 	} else {
-		operands.Add(a)
+		operands.Put(a)
 	}
 
 	if ba, ok := b.(*AND); ok {
 		for _, o := range ba.opnds {
-			operands.Add(o)
+			operands.Put(o)
 		}
 	} else {
-		operands.Add(b)
+		operands.Put(b)
 	}
 	precedencePredicates := PrecedencePredicatefilterPrecedencePredicates(operands)
 	if len(precedencePredicates) > 0 {
@@ -220,7 +225,7 @@ func NewAND(a, b SemanticContext) *AND {
 			}
 		}
 
-		operands.Add(reduced)
+		operands.Put(reduced)
 	}
 
 	vs := operands.Values()
@@ -235,14 +240,15 @@ func NewAND(a, b SemanticContext) *AND {
 	return and
 }
 
-func (a *AND) equals(other interface{}) bool {
+func (a *AND) Equals(other Collectable[SemanticContext]) bool {
 	if a == other {
 		return true
-	} else if _, ok := other.(*AND); !ok {
+	}
+	if _, ok := other.(*AND); !ok {
 		return false
 	} else {
 		for i, v := range other.(*AND).opnds {
-			if !a.opnds[i].equals(v) {
+			if !a.opnds[i].Equals(v) {
 				return false
 			}
 		}
@@ -250,13 +256,11 @@ func (a *AND) equals(other interface{}) bool {
 	}
 }
 
-//
 // {@inheritDoc}
 //
 // 

// The evaluation of predicates by a context is short-circuiting, but // unordered.

-// func (a *AND) evaluate(parser Recognizer, outerContext RuleContext) bool { for i := 0; i < len(a.opnds); i++ { if !a.opnds[i].evaluate(parser, outerContext) { @@ -304,18 +308,18 @@ func (a *AND) evalPrecedence(parser Recognizer, outerContext RuleContext) Semant return result } -func (a *AND) hash() int { +func (a *AND) Hash() int { h := murmurInit(37) // Init with a value different from OR for _, op := range a.opnds { - h = murmurUpdate(h, op.hash()) + h = murmurUpdate(h, op.Hash()) } return murmurFinish(h, len(a.opnds)) } -func (a *OR) hash() int { +func (a *OR) Hash() int { h := murmurInit(41) // Init with a value different from AND for _, op := range a.opnds { - h = murmurUpdate(h, op.hash()) + h = murmurUpdate(h, op.Hash()) } return murmurFinish(h, len(a.opnds)) } @@ -345,21 +349,21 @@ type OR struct { func NewOR(a, b SemanticContext) *OR { - operands := newArray2DHashSet(nil, nil) + operands := NewJStore[SemanticContext, Comparator[SemanticContext]](&ObjEqComparator[SemanticContext]{}) if aa, ok := a.(*OR); ok { for _, o := range aa.opnds { - operands.Add(o) + operands.Put(o) } } else { - operands.Add(a) + operands.Put(a) } if ba, ok := b.(*OR); ok { for _, o := range ba.opnds { - operands.Add(o) + operands.Put(o) } } else { - operands.Add(b) + operands.Put(b) } precedencePredicates := PrecedencePredicatefilterPrecedencePredicates(operands) if len(precedencePredicates) > 0 { @@ -372,7 +376,7 @@ func NewOR(a, b SemanticContext) *OR { } } - operands.Add(reduced) + operands.Put(reduced) } vs := operands.Values() @@ -388,14 +392,14 @@ func NewOR(a, b SemanticContext) *OR { return o } -func (o *OR) equals(other interface{}) bool { +func (o *OR) Equals(other Collectable[SemanticContext]) bool { if o == other { return true } else if _, ok := other.(*OR); !ok { return false } else { for i, v := range other.(*OR).opnds { - if !o.opnds[i].equals(v) { + if !o.opnds[i].Equals(v) { return false } } @@ -406,7 +410,6 @@ func (o *OR) equals(other interface{}) bool { //

// The evaluation of predicates by o context is short-circuiting, but // unordered.

-// func (o *OR) evaluate(parser Recognizer, outerContext RuleContext) bool { for i := 0; i < len(o.opnds); i++ { if o.opnds[i].evaluate(parser, outerContext) { diff --git a/runtime/Go/antlr/utils.go b/runtime/Go/antlr/utils.go index 8d8a029782..af89099152 100644 --- a/runtime/Go/antlr/utils.go +++ b/runtime/Go/antlr/utils.go @@ -47,28 +47,25 @@ func (s *IntStack) Push(e int) { *s = append(*s, e) } -func standardEqualsFunction(a interface{}, b interface{}) bool { - - ac, oka := a.(comparable) - bc, okb := b.(comparable) +type comparable interface { + Equals(other Collectable[any]) bool +} - if !oka || !okb { - panic("Not Comparable") - } +func standardEqualsFunction(a Collectable[any], b Collectable[any]) bool { - return ac.equals(bc) + return a.Equals(b) } func standardHashFunction(a interface{}) int { if h, ok := a.(hasher); ok { - return h.hash() + return h.Hash() } panic("Not Hasher") } type hasher interface { - hash() int + Hash() int } const bitsPerWord = 64 diff --git a/runtime/Go/antlr/utils_set.go b/runtime/Go/antlr/utils_set.go index 0d4eac698d..c9bd6751e3 100644 --- a/runtime/Go/antlr/utils_set.go +++ b/runtime/Go/antlr/utils_set.go @@ -8,8 +8,6 @@ const ( _loadFactor = 0.75 ) -var _ Set = (*array2DHashSet)(nil) - type Set interface { Add(value interface{}) (added interface{}) Len() int @@ -20,9 +18,9 @@ type Set interface { } type array2DHashSet struct { - buckets [][]interface{} + buckets [][]Collectable[any] hashcodeFunction func(interface{}) int - equalsFunction func(interface{}, interface{}) bool + equalsFunction func(Collectable[any], Collectable[any]) bool n int // How many elements in set threshold int // when to expand @@ -61,11 +59,11 @@ func (as *array2DHashSet) Values() []interface{} { return values } -func (as *array2DHashSet) Contains(value interface{}) bool { +func (as *array2DHashSet) Contains(value Collectable[any]) bool { return as.Get(value) != nil } -func (as *array2DHashSet) Add(value interface{}) interface{} { +func (as *array2DHashSet) Add(value Collectable[any]) interface{} { if as.n > as.threshold { as.expand() } @@ -98,7 +96,7 @@ func (as *array2DHashSet) expand() { b := as.getBuckets(o) bucketLength := newBucketLengths[b] - var newBucket []interface{} + var newBucket []Collectable[any] if bucketLength == 0 { // new bucket newBucket = as.createBucket(as.initialBucketCapacity) @@ -107,7 +105,7 @@ func (as *array2DHashSet) expand() { newBucket = newTable[b] if bucketLength == len(newBucket) { // expand - newBucketCopy := make([]interface{}, len(newBucket)<<1) + newBucketCopy := make([]Collectable[any], len(newBucket)<<1) copy(newBucketCopy[:bucketLength], newBucket) newBucket = newBucketCopy newTable[b] = newBucket @@ -124,7 +122,7 @@ func (as *array2DHashSet) Len() int { return as.n } -func (as *array2DHashSet) Get(o interface{}) interface{} { +func (as *array2DHashSet) Get(o Collectable[any]) interface{} { if o == nil { return nil } @@ -147,7 +145,7 @@ func (as *array2DHashSet) Get(o interface{}) interface{} { return nil } -func (as *array2DHashSet) innerAdd(o interface{}) interface{} { +func (as *array2DHashSet) innerAdd(o Collectable[any]) interface{} { b := as.getBuckets(o) bucket := as.buckets[b] @@ -178,7 +176,7 @@ func (as *array2DHashSet) innerAdd(o interface{}) interface{} { // full bucket, expand and add to end oldLength := len(bucket) - bucketCopy := make([]interface{}, oldLength<<1) + bucketCopy := make([]Collectable[any], oldLength<<1) copy(bucketCopy[:oldLength], bucket) bucket = bucketCopy as.buckets[b] = bucket @@ -187,22 +185,22 @@ func (as *array2DHashSet) innerAdd(o interface{}) interface{} { return o } -func (as *array2DHashSet) getBuckets(value interface{}) int { +func (as *array2DHashSet) getBuckets(value Collectable[any]) int { hash := as.hashcodeFunction(value) return hash & (len(as.buckets) - 1) } -func (as *array2DHashSet) createBuckets(cap int) [][]interface{} { - return make([][]interface{}, cap) +func (as *array2DHashSet) createBuckets(cap int) [][]Collectable[any] { + return make([][]Collectable[any], cap) } -func (as *array2DHashSet) createBucket(cap int) []interface{} { - return make([]interface{}, cap) +func (as *array2DHashSet) createBucket(cap int) []Collectable[any] { + return make([]Collectable[any], cap) } func newArray2DHashSetWithCap( hashcodeFunction func(interface{}) int, - equalsFunction func(interface{}, interface{}) bool, + equalsFunction func(Collectable[any], Collectable[any]) bool, initCap int, initBucketCap int, ) *array2DHashSet { @@ -231,7 +229,7 @@ func newArray2DHashSetWithCap( func newArray2DHashSet( hashcodeFunction func(interface{}) int, - equalsFunction func(interface{}, interface{}) bool, + equalsFunction func(Collectable[any], Collectable[any]) bool, ) *array2DHashSet { return newArray2DHashSetWithCap(hashcodeFunction, equalsFunction, _initalCapacity, _initalBucketCapacity) }