4 files changed, +40 -19 lines changed

include/pytorch/tokenizers

@@ -25,6 +25,8 @@
 #include <pytorch/tokenizers/string_integer_map.h>
 #include <pytorch/tokenizers/tokenizer.h>

+#include "re2/re2.h"
+
 namespace tokenizers {
 namespace detail {

@@ -104,6 +106,25 @@ static Result<TokenMap> buildTokenMap(
   return buildTokenMap(std::move(pairs));
 }

+inline Result<std::unique_ptr<IRegex>> build_special_token_regex(
+    const TokenMap& special_token_map) {
+  std::string special_pattern;
+  const std::size_t count = special_token_map.size();
+
+  for (std::size_t i = 0; i < count; ++i) {
+    const auto& [token, _] = special_token_map.getElement(i);
+    if (!special_pattern.empty()) {
+      special_pattern += "|";
+    }
+    special_pattern += re2::RE2::QuoteMeta(std::string(token));
+  }
+
+  if (special_pattern.empty()) {
+    return static_cast<std::unique_ptr<IRegex>>(nullptr);
+  }
+  return create_regex(special_pattern);
+}
+
 class BPETokenizerBase : public Tokenizer {
  public:
   Result<std::vector<uint64_t>>
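For reference, the helper added above escapes each special token with re2::RE2::QuoteMeta and joins the results into one alternation, returning a null regex when there are no special tokens. Below is a minimal standalone sketch of the same idea, using re2 directly rather than the repo's IRegex/TokenMap types; the token values and variable names are illustrative only.

// Standalone sketch (not part of this PR): build an alternation of escaped
// special tokens and find the first occurrence in a piece of text.
#include <iostream>
#include <string>
#include <vector>

#include "re2/re2.h"

int main() {
  // Hypothetical special tokens; real ones come from the tokenizer files.
  std::vector<std::string> special_tokens = {"<|endoftext|>", "<|pad|>"};

  std::string pattern;
  for (const auto& token : special_tokens) {
    if (!pattern.empty()) {
      pattern += "|";
    }
    // QuoteMeta escapes regex metacharacters such as '|' so the token
    // matches literally.
    pattern += re2::RE2::QuoteMeta(token);
  }

  // Wrap the alternation in a capture group so PartialMatch can report
  // which token matched.
  re2::RE2 special_re("(" + pattern + ")");
  std::string matched;
  if (re2::RE2::PartialMatch("hello<|endoftext|>world", special_re, &matched)) {
    std::cout << "found special token: " << matched << "\n";
  }
  return 0;
}
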
@@ -69,6 +69,12 @@ Error HFTokenizer::load(const std::string& path) {
         special_tokens,
         [](const auto& it) -> std::string { return it.at("content"); },
         [](const auto& it) -> std::uint64_t { return it.at("id"); }));
+
+    // Create special token regex to help later with encoding.
+    special_token_regex_ =
+        TK_UNWRAP(detail::build_special_token_regex(special_token_map));
+
+    // Store for future use.
     special_token_map_.emplace(std::move(special_token_map));
   } catch (const json::out_of_range& e) {
     fprintf(stderr, "Could not parse special tokens: %s\n", e.what());
@@ -142,8 +148,15 @@ Error HFTokenizer::load(const std::string& path) {

   // Pull out the token strings
   try {
-    const std::string bos_token = parsed_config_json.at("bos_token");
-    const std::string eos_token = parsed_config_json.at("eos_token");
+    const std::string bos_token = parsed_config_json.contains("bos_token") &&
+            !parsed_config_json["bos_token"].is_null()
+        ? parsed_config_json["bos_token"].get<std::string>()
+        : "";
+
+    const std::string eos_token = parsed_config_json.contains("eos_token") &&
+            !parsed_config_json["eos_token"].is_null()
+        ? parsed_config_json["eos_token"].get<std::string>()
+        : "";
     const auto bos_res = special_token_map_->tryGetInteger(bos_token);
     const auto eos_res = special_token_map_->tryGetInteger(eos_token);
     if (!bos_res) {
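The reworked bos_token / eos_token lookup tolerates configs where those keys are missing or explicitly null, falling back to an empty string instead of throwing. A small self-contained sketch of the same pattern with nlohmann::json follows; the config contents here are invented for illustration.

// Sketch only: defensive extraction of optional string fields from JSON.
#include <iostream>
#include <string>

#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
  // Example config: bos_token is explicitly null, eos_token is present.
  json config = json::parse(R"({"bos_token": null, "eos_token": "</s>"})");

  // at() throws if the key is missing, and converting a null value to
  // std::string throws as well, so check both cases and fall back to "".
  const std::string bos_token =
      config.contains("bos_token") && !config["bos_token"].is_null()
      ? config["bos_token"].get<std::string>()
      : "";
  const std::string eos_token =
      config.contains("eos_token") && !config["eos_token"].is_null()
      ? config["eos_token"].get<std::string>()
      : "";

  std::cout << "bos: '" << bos_token << "' eos: '" << eos_token << "'\n";
  return 0;
}
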
@@ -32,7 +32,6 @@
 #include <fstream>
 #include <limits>
 #include <unordered_set>
-#include "re2/re2.h"

 namespace tokenizers {

@@ -47,21 +46,6 @@ static Result<std::unique_ptr<IRegex>> _create_regex(
   return create_regex(pattern);
 }

-static Result<std::unique_ptr<IRegex>> _build_special_token_regex(
-    const std::vector<std::pair<std::string, std::uint64_t>>& special_encoder) {
-  std::string special_pattern;
-  for (const auto& ele : special_encoder) {
-    if (!special_pattern.empty()) {
-      special_pattern += "|";
-    }
-    special_pattern += re2::RE2::QuoteMeta(ele.first);
-  }
-  if (special_pattern.empty()) {
-    return static_cast<std::unique_ptr<IRegex>>(nullptr);
-  }
-  return _create_regex(special_pattern);
-}
-
 static Result<std::pair<std::string, uint64_t>> _parse(
     const std::string& line) {
   // Tiktoken format
@@ -153,7 +137,7 @@ Error Tiktoken::load(const std::string& path) {

   _regex = TK_UNWRAP(_create_regex(_pattern));
   special_token_regex_ =
-      TK_UNWRAP(_build_special_token_regex(special_token_map));
+      TK_UNWRAP(detail::build_special_token_regex(TokenMap(special_token_map)));

   // initialize vocab_size, bos_tok, eos_tok
   vocab_size_ = token_map_->size() + special_token_map_->size();

@@ -77,6 +77,9 @@ def define_common_targets():
         exported_deps = [
             ":headers",
         ],
+        exported_external_deps = [
+            "re2",
+        ],
         visibility = [
             "//pytorch/tokenizers/...",
         ],