Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 15 additions & 18 deletions runtime/Cpp/runtime/src/ParserRuleContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,39 +77,36 @@ void ParserRuleContext::removeLastChild() {
}
}

tree::TerminalNode* ParserRuleContext::getToken(size_t ttype, size_t i) {
tree::TerminalNode* ParserRuleContext::getToken(size_t ttype, size_t i) const {
if (i >= children.size()) {
return nullptr;
}

size_t j = 0; // what token with ttype have we found?
for (auto *o : children) {
if (o->getTreeType() == ParseTreeType::TERMINAL || o->getTreeType() == ParseTreeType::ERROR) {
tree::TerminalNode *tnode = downCast<tree::TerminalNode *>(o);
Token *symbol = tnode->getSymbol();
for (auto *child : children) {
if (TerminalNode::is(child)) {
tree::TerminalNode *typedChild = downCast<tree::TerminalNode*>(child);
Token *symbol = typedChild->getSymbol();
if (symbol->getType() == ttype) {
if (j++ == i) {
return tnode;
return typedChild;
}
}
}
}

return nullptr;
}

std::vector<tree::TerminalNode *> ParserRuleContext::getTokens(size_t ttype) {
std::vector<tree::TerminalNode *> tokens;
for (auto &o : children) {
if (o->getTreeType() == ParseTreeType::TERMINAL || o->getTreeType() == ParseTreeType::ERROR) {
tree::TerminalNode *tnode = downCast<tree::TerminalNode *>(o);
Token *symbol = tnode->getSymbol();
std::vector<tree::TerminalNode *> ParserRuleContext::getTokens(size_t ttype) const {
std::vector<tree::TerminalNode*> tokens;
for (auto *child : children) {
if (TerminalNode::is(child)) {
tree::TerminalNode *typedChild = downCast<tree::TerminalNode*>(child);
Token *symbol = typedChild->getSymbol();
if (symbol->getType() == ttype) {
tokens.push_back(tnode);
tokens.push_back(typedChild);
}
}
}

return tokens;
}

Expand All @@ -124,11 +121,11 @@ misc::Interval ParserRuleContext::getSourceInterval() {
return misc::Interval(start->getTokenIndex(), stop->getTokenIndex());
}

Token* ParserRuleContext::getStart() {
Token* ParserRuleContext::getStart() const {
return start;
}

Token* ParserRuleContext::getStop() {
Token* ParserRuleContext::getStop() const {
return stop;
}

Expand Down
40 changes: 20 additions & 20 deletions runtime/Cpp/runtime/src/ParserRuleContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ namespace antlr4 {

ParserRuleContext();
ParserRuleContext(ParserRuleContext *parent, size_t invokingStateNumber);
virtual ~ParserRuleContext() {}

/** COPY a ctx (I'm deliberately not using copy constructor) to avoid
* confusion with creating node with parent. Does not copy children
Expand All @@ -88,38 +87,39 @@ namespace antlr4 {
/// Used by enterOuterAlt to toss out a RuleContext previously added as
/// we entered a rule. If we have # label, we will need to remove
/// generic ruleContext object.
virtual void removeLastChild();
void removeLastChild();

virtual tree::TerminalNode* getToken(size_t ttype, std::size_t i);
tree::TerminalNode* getToken(size_t ttype, std::size_t i) const;

virtual std::vector<tree::TerminalNode *> getTokens(size_t ttype);
std::vector<tree::TerminalNode*> getTokens(size_t ttype) const;

template<typename T>
T* getRuleContext(size_t i) {
if (children.empty()) {
return nullptr;
}

T* getRuleContext(size_t i) const {
static_assert(std::is_base_of_v<RuleContext, T>, "T must be derived from RuleContext");
size_t j = 0; // what element have we found with ctxType?
for (auto &child : children) {
if (antlrcpp::is<T *>(child)) {
if (j++ == i) {
return dynamic_cast<T *>(child);
for (auto *child : children) {
if (RuleContext::is(child)) {
if (auto *typedChild = dynamic_cast<T*>(child); typedChild != nullptr) {
if (j++ == i) {
return typedChild;
}
}
}
}
return nullptr;
}

template<typename T>
std::vector<T *> getRuleContexts() {
std::vector<T *> contexts;
std::vector<T*> getRuleContexts() const {
static_assert(std::is_base_of_v<RuleContext, T>, "T must be derived from RuleContext");
std::vector<T*> contexts;
for (auto *child : children) {
if (antlrcpp::is<T *>(child)) {
contexts.push_back(dynamic_cast<T *>(child));
if (RuleContext::is(child)) {
if (auto *typedChild = dynamic_cast<T*>(child); typedChild != nullptr) {
contexts.push_back(typedChild);
}
}
}

return contexts;
}

Expand All @@ -130,14 +130,14 @@ namespace antlr4 {
* Note that the range from start to stop is inclusive, so for rules that do not consume anything
* (for example, zero length or error productions) this token may exceed stop.
*/
virtual Token *getStart();
Token* getStart() const;

/**
* Get the final token in this context.
* Note that the range from start to stop is inclusive, so for rules that do not consume anything
* (for example, zero length or error productions) this token may precede start.
*/
virtual Token *getStop();
Token* getStop() const;

/// <summary>
/// Used for rule context info debugging during parse-time, not so much for ATN debugging </summary>
Expand Down
4 changes: 4 additions & 0 deletions runtime/Cpp/runtime/src/RuleContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ namespace antlr4 {
*/
class ANTLR4CPP_PUBLIC RuleContext : public tree::ParseTree {
public:
static bool is(const tree::ParseTree &parseTree) { return parseTree.getTreeType() == tree::ParseTreeType::RULE; }

static bool is(const tree::ParseTree *parseTree) { return parseTree != nullptr && is(*parseTree); }

/// What state invoked the rule associated with this context?
/// The "return address" is the followState of invokingState
/// If parent is null, this should be -1 and this context object represents the start rule.
Expand Down
54 changes: 27 additions & 27 deletions runtime/Cpp/runtime/src/atn/ATNDeserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ namespace {
*/
void markPrecedenceDecisions(const ATN &atn) {
for (ATNState *state : atn.states) {
if (!is<StarLoopEntryState*>(state)) {
if (!StarLoopEntryState::is(state)) {
continue;
}

Expand All @@ -92,8 +92,8 @@ namespace {
*/
if (atn.ruleToStartState[state->ruleIndex]->isLeftRecursiveRule) {
ATNState *maybeLoopEndState = state->transitions[state->transitions.size() - 1]->target;
if (is<LoopEndState *>(maybeLoopEndState)) {
if (maybeLoopEndState->epsilonOnlyTransitions && is<RuleStopState*>(maybeLoopEndState->transitions[0]->target)) {
if (LoopEndState::is(maybeLoopEndState)) {
if (maybeLoopEndState->epsilonOnlyTransitions && RuleStopState::is(maybeLoopEndState->transitions[0]->target)) {
downCast<StarLoopEntryState*>(state)->isPrecedenceDecision = true;
}
}
Expand Down Expand Up @@ -291,7 +291,7 @@ std::unique_ptr<ATN> ATNDeserializer::deserialize(const std::vector<int32_t>& da
if (stype == ATNStateType::LOOP_END) { // special case
int loopBackStateNumber = data[p++];
loopBackStateNumbers.push_back({ downCast<LoopEndState*>(s), loopBackStateNumber });
} else if (is<BlockStartState*>(s)) {
} else if (BlockStartState::is(s)) {
int endStateNumber = data[p++];
endStateNumbers.push_back({ downCast<BlockStartState*>(s), endStateNumber });
}
Expand Down Expand Up @@ -340,7 +340,7 @@ std::unique_ptr<ATN> ATNDeserializer::deserialize(const std::vector<int32_t>& da

atn->ruleToStopState.resize(nrules);
for (ATNState *state : atn->states) {
if (!is<RuleStopState*>(state)) {
if (!RuleStopState::is(state)) {
continue;
}

Expand Down Expand Up @@ -389,7 +389,7 @@ std::unique_ptr<ATN> ATNDeserializer::deserialize(const std::vector<int32_t>& da
for (ATNState *state : atn->states) {
for (size_t i = 0; i < state->transitions.size(); i++) {
const Transition *t = state->transitions[i].get();
if (!is<const RuleTransition*>(t)) {
if (!RuleTransition::is(t)) {
continue;
}

Expand All @@ -407,7 +407,7 @@ std::unique_ptr<ATN> ATNDeserializer::deserialize(const std::vector<int32_t>& da
}

for (ATNState *state : atn->states) {
if (is<BlockStartState*>(state)) {
if (BlockStartState::is(state)) {
BlockStartState *startState = downCast<BlockStartState*>(state);

// we need to know the end state to set its start state
Expand All @@ -423,19 +423,19 @@ std::unique_ptr<ATN> ATNDeserializer::deserialize(const std::vector<int32_t>& da
startState->endState->startState = downCast<BlockStartState*>(state);
}

if (is<PlusLoopbackState*>(state)) {
if (PlusLoopbackState::is(state)) {
PlusLoopbackState *loopbackState = downCast<PlusLoopbackState*>(state);
for (size_t i = 0; i < loopbackState->transitions.size(); i++) {
ATNState *target = loopbackState->transitions[i]->target;
if (is<PlusBlockStartState*>(target)) {
if (PlusBlockStartState::is(target)) {
(downCast<PlusBlockStartState*>(target))->loopBackState = loopbackState;
}
}
} else if (is<StarLoopbackState*>(state)) {
} else if (StarLoopbackState::is(state)) {
StarLoopbackState *loopbackState = downCast<StarLoopbackState*>(state);
for (size_t i = 0; i < loopbackState->transitions.size(); i++) {
ATNState *target = loopbackState->transitions[i]->target;
if (is<StarLoopEntryState *>(target)) {
if (StarLoopEntryState::is(target)) {
downCast<StarLoopEntryState*>(target)->loopBackState = loopbackState;
}
}
Expand Down Expand Up @@ -506,16 +506,16 @@ std::unique_ptr<ATN> ATNDeserializer::deserialize(const std::vector<int32_t>& da
continue;
}

if (!is<StarLoopEntryState*>(state)) {
if (!StarLoopEntryState::is(state)) {
continue;
}

ATNState *maybeLoopEndState = state->transitions[state->transitions.size() - 1]->target;
if (!is<LoopEndState*>(maybeLoopEndState)) {
if (!LoopEndState::is(maybeLoopEndState)) {
continue;
}

if (maybeLoopEndState->epsilonOnlyTransitions && is<RuleStopState*>(maybeLoopEndState->transitions[0]->target)) {
if (maybeLoopEndState->epsilonOnlyTransitions && RuleStopState::is(maybeLoopEndState->transitions[0]->target)) {
endState = state;
break;
}
Expand Down Expand Up @@ -578,52 +578,52 @@ void ATNDeserializer::verifyATN(const ATN &atn) const {

checkCondition(state->epsilonOnlyTransitions || state->transitions.size() <= 1);

if (is<PlusBlockStartState*>(state)) {
if (PlusBlockStartState::is(state)) {
checkCondition((downCast<PlusBlockStartState*>(state))->loopBackState != nullptr);
}

if (is<StarLoopEntryState*>(state)) {
if (StarLoopEntryState::is(state)) {
StarLoopEntryState *starLoopEntryState = downCast<StarLoopEntryState*>(state);
checkCondition(starLoopEntryState->loopBackState != nullptr);
checkCondition(starLoopEntryState->transitions.size() == 2);

if (is<StarBlockStartState*>(starLoopEntryState->transitions[0]->target)) {
if (StarBlockStartState::is(starLoopEntryState->transitions[0]->target)) {
checkCondition(downCast<LoopEndState*>(starLoopEntryState->transitions[1]->target) != nullptr);
checkCondition(!starLoopEntryState->nonGreedy);
} else if (is<LoopEndState*>(starLoopEntryState->transitions[0]->target)) {
checkCondition(is<StarBlockStartState*>(starLoopEntryState->transitions[1]->target));
} else if (LoopEndState::is(starLoopEntryState->transitions[0]->target)) {
checkCondition(StarBlockStartState::is(starLoopEntryState->transitions[1]->target));
checkCondition(starLoopEntryState->nonGreedy);
} else {
throw IllegalStateException();
}
}

if (is<StarLoopbackState*>(state)) {
if (StarLoopbackState::is(state)) {
checkCondition(state->transitions.size() == 1);
checkCondition(is<StarLoopEntryState*>(state->transitions[0]->target));
checkCondition(StarLoopEntryState::is(state->transitions[0]->target));
}

if (is<LoopEndState*>(state)) {
if (LoopEndState::is(state)) {
checkCondition((downCast<LoopEndState*>(state))->loopBackState != nullptr);
}

if (is<RuleStartState*>(state)) {
if (RuleStartState::is(state)) {
checkCondition((downCast<RuleStartState*>(state))->stopState != nullptr);
}

if (is<BlockStartState*>(state)) {
if (BlockStartState::is(state)) {
checkCondition((downCast<BlockStartState*>(state))->endState != nullptr);
}

if (is<BlockEndState*>(state)) {
if (BlockEndState::is(state)) {
checkCondition((downCast<BlockEndState*>(state))->startState != nullptr);
}

if (is<DecisionState*>(state)) {
if (DecisionState::is(state)) {
DecisionState *decisionState = downCast<DecisionState*>(state);
checkCondition(decisionState->transitions.size() <= 1 || decisionState->decision >= 0);
} else {
checkCondition(state->transitions.size() <= 1 || is<RuleStopState*>(state));
checkCondition(state->transitions.size() <= 1 || RuleStopState::is(state));
}
}
}
4 changes: 4 additions & 0 deletions runtime/Cpp/runtime/src/atn/ActionTransition.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ namespace atn {

class ANTLR4CPP_PUBLIC ActionTransition final : public Transition {
public:
static bool is(const Transition &transition) { return transition.getTransitionType() == TransitionType::ACTION; }

static bool is(const Transition *transition) { return transition != nullptr && is(*transition); }

const size_t ruleIndex;
const size_t actionIndex;
const bool isCtxDependent; // e.g., $i ref in action
Expand Down
4 changes: 4 additions & 0 deletions runtime/Cpp/runtime/src/atn/ArrayPredictionContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ namespace atn {

class ANTLR4CPP_PUBLIC ArrayPredictionContext final : public PredictionContext {
public:
static bool is(const PredictionContext &predictionContext) { return predictionContext.getContextType() == PredictionContextType::ARRAY; }

static bool is(const PredictionContext *predictionContext) { return predictionContext != nullptr && is(*predictionContext); }

/// Parent can be empty only if full ctx mode and we make an array
/// from EMPTY and non-empty. We merge EMPTY by using null parent and
/// returnState == EMPTY_RETURN_STATE.
Expand Down
4 changes: 4 additions & 0 deletions runtime/Cpp/runtime/src/atn/AtomTransition.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ namespace atn {
/// TODO: make all transitions sets? no, should remove set edges.
class ANTLR4CPP_PUBLIC AtomTransition final : public Transition {
public:
static bool is(const Transition &transition) { return transition.getTransitionType() == TransitionType::ATOM; }

static bool is(const Transition *transition) { return transition != nullptr && is(*transition); }

/// The token type or character value; or, signifies special label.
/// TODO: rename this to label
const size_t _label;
Expand Down
4 changes: 4 additions & 0 deletions runtime/Cpp/runtime/src/atn/BasicBlockStartState.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ namespace atn {

class ANTLR4CPP_PUBLIC BasicBlockStartState final : public BlockStartState {
public:
static bool is(const ATNState &atnState) { return atnState.getStateType() == ATNStateType::BLOCK_START; }

static bool is(const ATNState *atnState) { return atnState != nullptr && is(*atnState); }

BasicBlockStartState() : BlockStartState(ATNStateType::BLOCK_START) {}
};

Expand Down
4 changes: 4 additions & 0 deletions runtime/Cpp/runtime/src/atn/BasicState.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ namespace atn {

class ANTLR4CPP_PUBLIC BasicState final : public ATNState {
public:
static bool is(const ATNState &atnState) { return atnState.getStateType() == ATNStateType::BASIC; }

static bool is(const ATNState *atnState) { return atnState != nullptr && is(*atnState); }

BasicState() : ATNState(ATNStateType::BASIC) {}
};

Expand Down
4 changes: 4 additions & 0 deletions runtime/Cpp/runtime/src/atn/BlockEndState.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ namespace atn {
/// Terminal node of a simple {@code (a|b|c)} block.
class ANTLR4CPP_PUBLIC BlockEndState final : public ATNState {
public:
static bool is(const ATNState &atnState) { return atnState.getStateType() == ATNStateType::BLOCK_END; }

static bool is(const ATNState *atnState) { return atnState != nullptr && is(*atnState); }

BlockStartState *startState = nullptr;

BlockEndState() : ATNState(ATNStateType::BLOCK_END) {}
Expand Down
Loading