diff --git a/runtime/Cpp/runtime/src/Parser.cpp b/runtime/Cpp/runtime/src/Parser.cpp index bcd36b44fb..ec74ca89e2 100755 --- a/runtime/Cpp/runtime/src/Parser.cpp +++ b/runtime/Cpp/runtime/src/Parser.cpp @@ -38,7 +38,7 @@ struct BypassAltsAtnCache final { /// bypass alternatives. /// /// - std::map, std::unique_ptr> map; + std::map, std::unique_ptr, std::less<>> map; }; BypassAltsAtnCache* getBypassAltsAtnCache() { @@ -227,9 +227,8 @@ TokenFactory* Parser::getTokenFactory() { return _input->getTokenSource()->getTokenFactory(); } - const atn::ATN& Parser::getATNWithBypassAlts() { - const std::vector &serializedAtn = getSerializedATN(); + auto serializedAtn = getSerializedATN(); if (serializedAtn.empty()) { throw UnsupportedOperationException("The current parser does not support an ATN with bypass alternatives."); } @@ -244,15 +243,16 @@ const atn::ATN& Parser::getATNWithBypassAlts() { } } + std::unique_lock lock(cache->mutex); + auto existing = cache->map.find(serializedAtn); + if (existing != cache->map.end()) { + return *existing->second; + } atn::ATNDeserializationOptions deserializationOptions; deserializationOptions.setGenerateRuleBypassTransitions(true); atn::ATNDeserializer deserializer(deserializationOptions); auto atn = deserializer.deserialize(serializedAtn); - - { - std::unique_lock lock(cache->mutex); - return *cache->map.insert(std::make_pair(serializedAtn, std::move(atn))).first->second; - } + return *cache->map.insert(std::make_pair(std::vector(serializedAtn.begin(), serializedAtn.end()), std::move(atn))).first->second; } tree::pattern::ParseTreePattern Parser::compileParseTreePattern(const std::string &pattern, int patternRuleIndex) { diff --git a/runtime/Cpp/runtime/src/Recognizer.h b/runtime/Cpp/runtime/src/Recognizer.h index 28abfc8741..849075b360 100755 --- a/runtime/Cpp/runtime/src/Recognizer.h +++ b/runtime/Cpp/runtime/src/Recognizer.h @@ -7,6 +7,7 @@ #include "ProxyErrorListener.h" #include "support/Casts.h" +#include "atn/SerializedATNView.h" namespace antlr4 { @@ -53,7 +54,7 @@ namespace antlr4 { /// For interpreters, we don't know their serialized ATN despite having /// created the interpreter from it. /// - virtual const std::vector& getSerializedATN() const { + virtual atn::SerializedATNView getSerializedATN() const { throw "there is no serialized ATN"; } diff --git a/runtime/Cpp/runtime/src/antlr4-runtime.h b/runtime/Cpp/runtime/src/antlr4-runtime.h index 722df6d075..85022cc5f3 100644 --- a/runtime/Cpp/runtime/src/antlr4-runtime.h +++ b/runtime/Cpp/runtime/src/antlr4-runtime.h @@ -108,6 +108,7 @@ #include "atn/RuleStopState.h" #include "atn/RuleTransition.h" #include "atn/SemanticContext.h" +#include "atn/SerializedATNView.h" #include "atn/SetTransition.h" #include "atn/SingletonPredictionContext.h" #include "atn/StarBlockStartState.h" diff --git a/runtime/Cpp/runtime/src/atn/ATNDeserializer.cpp b/runtime/Cpp/runtime/src/atn/ATNDeserializer.cpp index c415aef552..1b169bbd2c 100755 --- a/runtime/Cpp/runtime/src/atn/ATNDeserializer.cpp +++ b/runtime/Cpp/runtime/src/atn/ATNDeserializer.cpp @@ -221,12 +221,12 @@ namespace { return s; } - ssize_t readUnicodeInt32(const std::vector& data, int& p) { + ssize_t readUnicodeInt32(SerializedATNView data, int& p) { return static_cast(data[p++]); } void deserializeSets( - const std::vector& data, + SerializedATNView data, int& p, std::vector& sets) { size_t nsets = data[p++]; @@ -255,7 +255,7 @@ ATNDeserializer::ATNDeserializer() : ATNDeserializer(ATNDeserializationOptions:: ATNDeserializer::ATNDeserializer(ATNDeserializationOptions deserializationOptions) : _deserializationOptions(std::move(deserializationOptions)) {} -std::unique_ptr ATNDeserializer::deserialize(const std::vector& data) const { +std::unique_ptr ATNDeserializer::deserialize(SerializedATNView data) const { int p = 0; int version = data[p++]; if (version != SERIALIZED_VERSION) { diff --git a/runtime/Cpp/runtime/src/atn/ATNDeserializer.h b/runtime/Cpp/runtime/src/atn/ATNDeserializer.h index 2442d4b7bd..9be49159fd 100755 --- a/runtime/Cpp/runtime/src/atn/ATNDeserializer.h +++ b/runtime/Cpp/runtime/src/atn/ATNDeserializer.h @@ -6,6 +6,7 @@ #pragma once #include "atn/ATNDeserializationOptions.h" +#include "atn/SerializedATNView.h" #include "atn/LexerAction.h" #include "atn/Transition.h" @@ -20,7 +21,7 @@ namespace atn { explicit ATNDeserializer(ATNDeserializationOptions deserializationOptions); - std::unique_ptr deserialize(const std::vector &input) const; + std::unique_ptr deserialize(SerializedATNView input) const; void verifyATN(const ATN &atn) const; private: diff --git a/runtime/Cpp/runtime/src/atn/SerializedATNView.h b/runtime/Cpp/runtime/src/atn/SerializedATNView.h new file mode 100644 index 0000000000..a723589bc3 --- /dev/null +++ b/runtime/Cpp/runtime/src/atn/SerializedATNView.h @@ -0,0 +1,101 @@ +/* Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. + * Use of this file is governed by the BSD 3-clause license that + * can be found in the LICENSE.txt file in the project root. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include "antlr4-common.h" +#include "misc/MurmurHash.h" + +namespace antlr4 { +namespace atn { + + class ANTLR4CPP_PUBLIC SerializedATNView final { + public: + using value_type = int32_t; + using size_type = size_t; + using difference_type = ptrdiff_t; + using reference = int32_t&; + using const_reference = const int32_t&; + using pointer = int32_t*; + using const_pointer = const int32_t*; + using iterator = const_pointer; + using const_iterator = const_pointer; + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = std::reverse_iterator; + + SerializedATNView() = default; + + SerializedATNView(const_pointer data, size_type size) : _data(data), _size(size) {} + + SerializedATNView(const std::vector &serializedATN) : _data(serializedATN.data()), _size(serializedATN.size()) {} + + SerializedATNView(const SerializedATNView&) = default; + + SerializedATNView& operator=(const SerializedATNView&) = default; + + const_iterator begin() const { return data(); } + + const_iterator cbegin() const { return data(); } + + const_iterator end() const { return data() + size(); } + + const_iterator cend() const { return data() + size(); } + + const_reverse_iterator rbegin() const { return const_reverse_iterator(end()); } + + const_reverse_iterator crbegin() const { return const_reverse_iterator(cend()); } + + const_reverse_iterator rend() const { return const_reverse_iterator(begin()); } + + const_reverse_iterator crend() const { return const_reverse_iterator(cbegin()); } + + bool empty() const { return size() == 0; } + + const_pointer data() const { return _data; } + + size_type size() const { return _size; } + + size_type size_bytes() const { return size() * sizeof(value_type); } + + const_reference operator[](size_type index) const { return _data[index]; } + + private: + const_pointer _data = nullptr; + size_type _size = 0; + }; + + inline bool operator==(const SerializedATNView &lhs, const SerializedATNView &rhs) { + return (lhs.data() == rhs.data() && lhs.size() == rhs.size()) || + (lhs.size() == rhs.size() && std::memcmp(lhs.data(), rhs.data(), lhs.size_bytes()) == 0); + } + + inline bool operator!=(const SerializedATNView &lhs, const SerializedATNView &rhs) { + return !operator==(lhs, rhs); + } + + inline bool operator<(const SerializedATNView &lhs, const SerializedATNView &rhs) { + int diff = std::memcmp(lhs.data(), rhs.data(), std::min(lhs.size_bytes(), rhs.size_bytes())); + return diff < 0 || (diff == 0 && lhs.size() < rhs.size()); + } + +} // namespace atn +} // namespace antlr4 + +namespace std { + + template <> + struct hash<::antlr4::atn::SerializedATNView> { + size_t operator()(const ::antlr4::atn::SerializedATNView &serializedATNView) const { + return ::antlr4::misc::MurmurHash::hashCode(serializedATNView.data(), serializedATNView.size()); + } + }; + +} // namespace std diff --git a/runtime/Cpp/runtime/src/misc/MurmurHash.cpp b/runtime/Cpp/runtime/src/misc/MurmurHash.cpp index 3a4fd1869f..73562cd9bd 100755 --- a/runtime/Cpp/runtime/src/misc/MurmurHash.cpp +++ b/runtime/Cpp/runtime/src/misc/MurmurHash.cpp @@ -5,6 +5,7 @@ #include #include +#include #include "misc/MurmurHash.h" @@ -62,6 +63,23 @@ size_t MurmurHash::update(size_t hash, size_t value) { return hash; } +size_t MurmurHash::update(size_t hash, const void *data, size_t size) { + size_t value; + const uint8_t *bytes = static_cast(data); + while (size >= sizeof(size_t)) { + std::memcpy(&value, bytes, sizeof(size_t)); + hash = update(hash, value); + bytes += sizeof(size_t); + size -= sizeof(size_t); + } + if (size != 0) { + value = 0; + std::memcpy(&value, bytes, size); + hash = update(hash, value); + } + return hash; +} + size_t MurmurHash::finish(size_t hash, size_t entryCount) { hash ^= entryCount * 8; hash ^= hash >> 33; diff --git a/runtime/Cpp/runtime/src/misc/MurmurHash.h b/runtime/Cpp/runtime/src/misc/MurmurHash.h index 9074592922..940ee67155 100755 --- a/runtime/Cpp/runtime/src/misc/MurmurHash.h +++ b/runtime/Cpp/runtime/src/misc/MurmurHash.h @@ -6,6 +6,7 @@ #pragma once #include +#include #include "antlr4-common.h" @@ -47,6 +48,13 @@ namespace misc { return update(hash, value != nullptr ? value->hashCode() : 0); } + static size_t update(size_t hash, const void *data, size_t size); + + template + static size_t update(size_t hash, const T *data, size_t size) { + return update(hash, static_cast(data), size * sizeof(std::remove_reference_t)); + } + /// /// Apply the final computation steps to the intermediate value {@code hash} /// to form the final result of the MurmurHash 3 hash function. @@ -63,7 +71,7 @@ namespace misc { /// the seed for the MurmurHash algorithm /// the hash code of the data template // where T is C array type - static size_t hashCode(const std::vector> &data, size_t seed) { + static size_t hashCode(const std::vector> &data, size_t seed = DEFAULT_SEED) { size_t hash = initialize(seed); for (auto &entry : data) { hash = update(hash, entry); @@ -71,6 +79,17 @@ namespace misc { return finish(hash, data.size()); } + static size_t hashCode(const void *data, size_t size, size_t seed = DEFAULT_SEED) { + size_t hash = initialize(seed); + hash = update(hash, data, size); + return finish(hash, size); + } + + template + static size_t hashCode(const T *data, size_t size, size_t seed = DEFAULT_SEED) { + return hashCode(static_cast(data), size * sizeof(std::remove_reference_t), seed); + } + private: MurmurHash() = delete; diff --git a/runtime/Cpp/runtime/src/tree/xpath/XPathLexer.cpp b/runtime/Cpp/runtime/src/tree/xpath/XPathLexer.cpp index b648f6c085..48318f9a28 100644 --- a/runtime/Cpp/runtime/src/tree/xpath/XPathLexer.cpp +++ b/runtime/Cpp/runtime/src/tree/xpath/XPathLexer.cpp @@ -33,7 +33,7 @@ struct XPathLexerStaticData final { const std::vector literalNames; const std::vector symbolicNames; const antlr4::dfa::Vocabulary vocabulary; - std::vector serializedATN; + antlr4::atn::SerializedATNView serializedATN; std::unique_ptr atn; }; @@ -61,7 +61,7 @@ void xpathLexerInitialize() { "STRING" } ); - static const int32_t serializedATNSegment0[] = { + static const int32_t serializedATNSegment[] = { 0x4, 0x0, 0x8, 0x32, 0x6, -1, 0x2, 0x0, 0x7, 0x0, 0x2, 0x1, 0x7, 0x1, 0x2, 0x2, 0x7, 0x2, 0x2, 0x3, 0x7, 0x3, 0x2, 0x4, 0x7, 0x4, 0x2, 0x5, 0x7, 0x5, 0x2, 0x6, 0x7, 0x6, 0x2, 0x7, 0x7, 0x7, 0x1, @@ -102,12 +102,7 @@ void xpathLexerInitialize() { 0x1, 0x0, 0x0, 0x0, 0x4, 0x0, 0x1e, 0x25, 0x2d, 0x1, 0x1, 0x4, 0x0, }; - size_t serializedATNSize = 0; - serializedATNSize += sizeof(serializedATNSegment0) / sizeof(serializedATNSegment0[0]); - staticData->serializedATN.reserve(serializedATNSize); - - staticData->serializedATN.insert(staticData->serializedATN.end(), serializedATNSegment0, - serializedATNSegment0 + sizeof(serializedATNSegment0) / sizeof(serializedATNSegment0[0])); + staticData->serializedATN = antlr4::atn::SerializedATNView(serializedATNSegment, sizeof(serializedATNSegment) / sizeof(serializedATNSegment[0])); atn::ATNDeserializer deserializer; staticData->atn = deserializer.deserialize(staticData->serializedATN); @@ -151,7 +146,7 @@ const dfa::Vocabulary& XPathLexer::getVocabulary() const { return xpathLexerStaticData->vocabulary; } -const std::vector& XPathLexer::getSerializedATN() const { +antlr4::atn::SerializedATNView XPathLexer::getSerializedATN() const { return xpathLexerStaticData->serializedATN; } diff --git a/runtime/Cpp/runtime/src/tree/xpath/XPathLexer.h b/runtime/Cpp/runtime/src/tree/xpath/XPathLexer.h index bd6711077e..6926d2161e 100644 --- a/runtime/Cpp/runtime/src/tree/xpath/XPathLexer.h +++ b/runtime/Cpp/runtime/src/tree/xpath/XPathLexer.h @@ -28,7 +28,7 @@ class XPathLexer : public antlr4::Lexer { virtual const antlr4::dfa::Vocabulary& getVocabulary() const override; - virtual const std::vector& getSerializedATN() const override; + virtual antlr4::atn::SerializedATNView getSerializedATN() const override; virtual const antlr4::atn::ATN& getATN() const override; diff --git a/tool/resources/org/antlr/v4/tool/templates/codegen/Cpp/Cpp.stg b/tool/resources/org/antlr/v4/tool/templates/codegen/Cpp/Cpp.stg index 9ac3fea418..b36bde81b2 100644 --- a/tool/resources/org/antlr/v4/tool/templates/codegen/Cpp/Cpp.stg +++ b/tool/resources/org/antlr/v4/tool/templates/codegen/Cpp/Cpp.stg @@ -79,8 +79,9 @@ public: const antlr4::dfa::Vocabulary& getVocabulary() const override; - virtual const std::vector\& getSerializedATN() const override; - virtual const antlr4::atn::ATN& getATN() const override; + antlr4::atn::SerializedATNView getSerializedATN() const override; + + const antlr4::atn::ATN& getATN() const override; void action(antlr4::RuleContext *context, size_t ruleIndex, size_t actionIndex) override; @@ -138,7 +139,7 @@ struct StaticData final { const std::vector\ literalNames; const std::vector\ symbolicNames; const antlr4::dfa::Vocabulary vocabulary; - std::vector\ serializedATN; + antlr4::atn::SerializedATNView serializedATN; std::unique_ptr\ atn; }; @@ -199,7 +200,7 @@ const dfa::Vocabulary& ::getVocabulary() const { return LexerStaticData->vocabulary; } -const std::vector\& ::getSerializedATN() const { +antlr4::atn::SerializedATNView ::getSerializedATN() const { return LexerStaticData->serializedATN; } @@ -307,7 +308,7 @@ public: const antlr4::dfa::Vocabulary& getVocabulary() const override; - const std::vector\& getSerializedATN() const override; + antlr4::atn::SerializedATNView getSerializedATN() const override; @@ -356,7 +357,7 @@ struct StaticData final { const std::vector\ literalNames; const std::vector\ symbolicNames; const antlr4::dfa::Vocabulary vocabulary; - std::vector\ serializedATN; + antlr4::atn::SerializedATNView serializedATN; std::unique_ptr\ atn; }; @@ -407,7 +408,7 @@ const dfa::Vocabulary& ::getVocabulary() const { return ParserStaticData->vocabulary; } -const std::vector\& ::getSerializedATN() const { +antlr4::atn::SerializedATNView ::getSerializedATN() const { return ParserStaticData->serializedATN; } @@ -441,9 +442,7 @@ SerializedATN(model) ::= << static const int32_t serializedATNSegment[] = { }; separator=",", wrap> }; -staticData->serializedATN.reserve(sizeof(serializedATNSegment) / sizeof(serializedATNSegment[0])); -staticData->serializedATN.insert(staticData->serializedATN.end(), serializedATNSegment, - serializedATNSegment + sizeof(serializedATNSegment) / sizeof(serializedATNSegment[0])); +staticData->serializedATN = antlr4::atn::SerializedATNView(serializedATNSegment, sizeof(serializedATNSegment) / sizeof(serializedATNSegment[0])); antlr4::atn::ATNDeserializer deserializer; staticData->atn = deserializer.deserialize(staticData->serializedATN);