Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 29 additions & 29 deletions runtime/Go/antlr/atn_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,14 @@ import (
"fmt"
)

type comparable interface {
equals(other interface{}) bool
}

// ATNConfig is a tuple: (ATN state, predicted alt, syntactic, semantic
// context). The syntactic context is a graph-structured stack node whose
// path(s) to the root is the rule invocation(s) chain used to arrive at the
// state. The semantic context is the tree of semantic predicates encountered
// before reaching an ATN state.
type ATNConfig interface {
comparable

gequals(other Collectable[ATNConfig]) bool

hash() int
Equals(o Collectable[ATNConfig]) bool
Hash() int

GetState() ATNState
GetAlt() int
Expand Down Expand Up @@ -136,15 +129,17 @@ func (b *BaseATNConfig) GetReachesIntoOuterContext() int {
func (b *BaseATNConfig) SetReachesIntoOuterContext(v int) {
b.reachesIntoOuterContext = v
}
func (b *BaseATNConfig) equals(o interface{}) bool {
return b.gequals(o.(Collectable[ATNConfig]))
}

// Equals is the default comparison function for an ATNConfig when no specialist implementation is required
// for a collection.
//
// An ATN configuration is equal to another if both have the same state, they
// predict the same alternative, and syntactic/semantic contexts are the same.
func (b *BaseATNConfig) gequals(o Collectable[ATNConfig]) bool {
func (b *BaseATNConfig) Equals(o Collectable[ATNConfig]) bool {
if b == o {
return true
} else if o == nil {
return false
}

var other, ok = o.(*BaseATNConfig)
Expand All @@ -158,30 +153,32 @@ func (b *BaseATNConfig) gequals(o Collectable[ATNConfig]) bool {
if b.context == nil {
equal = other.context == nil
} else {
equal = b.context.gequals(other.context)
equal = b.context.Equals(other.context)
}

var (
nums = b.state.GetStateNumber() == other.state.GetStateNumber()
alts = b.alt == other.alt
cons = b.semanticContext.equals(other.semanticContext)
cons = b.semanticContext.Equals(other.semanticContext)
sups = b.precedenceFilterSuppressed == other.precedenceFilterSuppressed
)

return nums && alts && cons && sups && equal
}

func (b *BaseATNConfig) hash() int {
// Hash is the default hash function for BaseATNConfig, when no specialist hash function
// is required for a collection
func (b *BaseATNConfig) Hash() int {
var c int
if b.context != nil {
c = b.context.hash()
c = b.context.Hash()
}

h := murmurInit(7)
h = murmurUpdate(h, b.state.GetStateNumber())
h = murmurUpdate(h, b.alt)
h = murmurUpdate(h, c)
h = murmurUpdate(h, b.semanticContext.hash())
h = murmurUpdate(h, b.semanticContext.Hash())
return murmurFinish(h, 4)
}

Expand Down Expand Up @@ -248,7 +245,9 @@ func NewLexerATNConfig1(state ATNState, alt int, context PredictionContext) *Lex
return &LexerATNConfig{BaseATNConfig: NewBaseATNConfig5(state, alt, context, SemanticContextNone)}
}

func (l *LexerATNConfig) hash() int {
// Hash is the default hash function for LexerATNConfig objects, it can be used directly or via
// the default comparator [ObjEqComparator].
func (l *LexerATNConfig) Hash() int {
var f int
if l.passedThroughNonGreedyDecision {
f = 1
Expand All @@ -258,19 +257,20 @@ func (l *LexerATNConfig) hash() int {
h := murmurInit(7)
h = murmurUpdate(h, l.state.GetStateNumber())
h = murmurUpdate(h, l.alt)
h = murmurUpdate(h, l.context.hash())
h = murmurUpdate(h, l.semanticContext.hash())
h = murmurUpdate(h, l.context.Hash())
h = murmurUpdate(h, l.semanticContext.Hash())
h = murmurUpdate(h, f)
h = murmurUpdate(h, l.lexerActionExecutor.hash())
h = murmurUpdate(h, l.lexerActionExecutor.Hash())
h = murmurFinish(h, 6)
return h
}

func (l *LexerATNConfig) equals(other interface{}) bool {
return l.gequals(other.(Collectable[ATNConfig]))
}

func (l *LexerATNConfig) gequals(other Collectable[ATNConfig]) bool {
// Equals is the default comparison function for LexerATNConfig objects, it can be used directly or via
// the default comparator [ObjEqComparator].
func (l *LexerATNConfig) Equals(other Collectable[ATNConfig]) bool {
if l == other {
return true
}
var othert, ok = other.(*LexerATNConfig)

if l == other {
Expand All @@ -284,7 +284,7 @@ func (l *LexerATNConfig) gequals(other Collectable[ATNConfig]) bool {
var b bool

if l.lexerActionExecutor != nil {
b = !l.lexerActionExecutor.equals(othert.lexerActionExecutor)
b = !l.lexerActionExecutor.Equals(othert.lexerActionExecutor)
} else {
b = othert.lexerActionExecutor != nil
}
Expand All @@ -293,7 +293,7 @@ func (l *LexerATNConfig) gequals(other Collectable[ATNConfig]) bool {
return false
}

return l.BaseATNConfig.equals(othert.BaseATNConfig)
return l.BaseATNConfig.Equals(othert.BaseATNConfig)
}

func checkNonGreedyDecision(source *LexerATNConfig, target ATNState) bool {
Expand Down
46 changes: 26 additions & 20 deletions runtime/Go/antlr/atn_config_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,17 @@ package antlr
import "fmt"

type ATNConfigSet interface {
hash() int
Hash() int
Equals(o Collectable[ATNConfig]) bool
Add(ATNConfig, *DoubleDict) bool
AddAll([]ATNConfig) bool

GetStates() Set
GetStates() *JStore[ATNState, Comparator[ATNState]]
GetPredicates() []SemanticContext
GetItems() []ATNConfig

OptimizeConfigs(interpreter *BaseATNSimulator)

Equals(other interface{}) bool

Length() int
IsEmpty() bool
Contains(ATNConfig) bool
Expand Down Expand Up @@ -57,7 +56,7 @@ type BaseATNConfigSet struct {
// effectively doubles the number of objects associated with ATNConfigs. All
// keys are hashed by (s, i, _, pi), not including the context. Wiped out when
// read-only because a set becomes a DFA state.
configLookup Set
configLookup *JStore[ATNConfig, Comparator[ATNConfig]]

// configs is the added elements.
configs []ATNConfig
Expand All @@ -83,7 +82,7 @@ type BaseATNConfigSet struct {

// readOnly is whether it is read-only. Do not
// allow any code to manipulate the set if true because DFA states will point at
// sets and those must not change. It not protect other fields; conflictingAlts
// sets and those must not change. It not, protect other fields; conflictingAlts
// in particular, which is assigned after readOnly.
readOnly bool

Expand All @@ -104,7 +103,7 @@ func (b *BaseATNConfigSet) Alts() *BitSet {
func NewBaseATNConfigSet(fullCtx bool) *BaseATNConfigSet {
return &BaseATNConfigSet{
cachedHash: -1,
configLookup: newArray2DHashSetWithCap(hashATNConfig, equalATNConfigs, 16, 2),
configLookup: NewJStore[ATNConfig, Comparator[ATNConfig]](&ATNConfigComparator[ATNConfig]{}),
fullCtx: fullCtx,
}
}
Expand All @@ -126,9 +125,11 @@ func (b *BaseATNConfigSet) Add(config ATNConfig, mergeCache *DoubleDict) bool {
b.dipsIntoOuterContext = true
}

existing := b.configLookup.Add(config).(ATNConfig)
existing, present := b.configLookup.Put(config)

if existing == config {
// The config was not already in the set
//
if !present {
b.cachedHash = -1
b.configs = append(b.configs, config) // Track order here
return true
Expand All @@ -154,11 +155,14 @@ func (b *BaseATNConfigSet) Add(config ATNConfig, mergeCache *DoubleDict) bool {
return true
}

func (b *BaseATNConfigSet) GetStates() Set {
states := newArray2DHashSet(nil, nil)
func (b *BaseATNConfigSet) GetStates() *JStore[ATNState, Comparator[ATNState]] {

// states uses the standard comparator provided by the ATNState instance
//
states := NewJStore[ATNState, Comparator[ATNState]](&ObjEqComparator[ATNState]{})

for i := 0; i < len(b.configs); i++ {
states.Add(b.configs[i].GetState())
states.Put(b.configs[i].GetState())
}

return states
Expand Down Expand Up @@ -227,7 +231,7 @@ func (b *BaseATNConfigSet) Compare(bs *BaseATNConfigSet) bool {
for _, c := range b.configs {
found := false
for _, c2 := range bs.configs {
if c.equals(c2) {
if c.Equals(c2) {
found = true
break
}
Expand All @@ -241,7 +245,8 @@ func (b *BaseATNConfigSet) Compare(bs *BaseATNConfigSet) bool {
return true
}

func (b *BaseATNConfigSet) Equals(other interface{}) bool {

func (b *BaseATNConfigSet) Equals(other Collectable[ATNConfig]) bool {
if b == other {
return true
} else if _, ok := other.(*BaseATNConfigSet); !ok {
Expand All @@ -259,7 +264,7 @@ func (b *BaseATNConfigSet) Equals(other interface{}) bool {
b.Compare(other2)
}

func (b *BaseATNConfigSet) hash() int {
func (b *BaseATNConfigSet) Hash() int {
if b.readOnly {
if b.cachedHash == -1 {
b.cachedHash = b.hashCodeConfigs()
Expand All @@ -274,7 +279,7 @@ func (b *BaseATNConfigSet) hash() int {
func (b *BaseATNConfigSet) hashCodeConfigs() int {
h := 1
for _, config := range b.configs {
h = 31*h + config.hash()
h = 31*h + config.Hash()
}
return h
}
Expand Down Expand Up @@ -310,7 +315,7 @@ func (b *BaseATNConfigSet) Clear() {

b.configs = make([]ATNConfig, 0)
b.cachedHash = -1
b.configLookup = newArray2DHashSet(nil, equalATNConfigs)
b.configLookup = NewJStore[ATNConfig, Comparator[ATNConfig]](&BaseATNConfigComparator[ATNConfig]{})
}

func (b *BaseATNConfigSet) FullContext() bool {
Expand Down Expand Up @@ -392,7 +397,8 @@ type OrderedATNConfigSet struct {
func NewOrderedATNConfigSet() *OrderedATNConfigSet {
b := NewBaseATNConfigSet(false)

b.configLookup = newArray2DHashSet(nil, nil)
// This set uses the standard Hash() and Equals() from ATNConfig
b.configLookup = NewJStore[ATNConfig, Comparator[ATNConfig]](&ObjEqComparator[ATNConfig]{})

return &OrderedATNConfigSet{BaseATNConfigSet: b}
}
Expand All @@ -402,7 +408,7 @@ func hashATNConfig(i interface{}) int {
hash := 7
hash = 31*hash + o.GetState().GetStateNumber()
hash = 31*hash + o.GetAlt()
hash = 31*hash + o.GetSemanticContext().hash()
hash = 31*hash + o.GetSemanticContext().Hash()
return hash
}

Expand Down Expand Up @@ -430,5 +436,5 @@ func equalATNConfigs(a, b interface{}) bool {
return false
}

return ai.GetSemanticContext().equals(bi.GetSemanticContext())
return ai.GetSemanticContext().Equals(bi.GetSemanticContext())
}
7 changes: 4 additions & 3 deletions runtime/Go/antlr/atn_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ type ATNState interface {
AddTransition(Transition, int)

String() string
hash() int
Hash() int
Equals(Collectable[ATNState]) bool
}

type BaseATNState struct {
Expand Down Expand Up @@ -123,15 +124,15 @@ func (as *BaseATNState) SetNextTokenWithinRule(v *IntervalSet) {
as.NextTokenWithinRule = v
}

func (as *BaseATNState) hash() int {
func (as *BaseATNState) Hash() int {
return as.stateNumber
}

func (as *BaseATNState) String() string {
return strconv.Itoa(as.stateNumber)
}

func (as *BaseATNState) equals(other interface{}) bool {
func (as *BaseATNState) Equals(other Collectable[ATNState]) bool {
if ot, ok := other.(ATNState); ok {
return as.stateNumber == ot.GetStateNumber()
}
Expand Down
Loading