diff --git a/examples/models/llama2/runner/runner.cpp b/examples/models/llama2/runner/runner.cpp
index af7c25ec67d..45b7975333a 100644
--- a/examples/models/llama2/runner/runner.cpp
+++ b/examples/models/llama2/runner/runner.cpp
@@ -81,14 +81,16 @@ Error Runner::load() {
   if (tokenizer_->bos_tok() != bos_id_) {
     ET_LOG(
         Error,
-        "Tokenizer's BOS id %d does not match model's BOS id %d, will override tokenizer's BOS.",
+        "Tokenizer's BOS id %" PRIu64
+        " does not match model's BOS id %d, will override tokenizer's BOS.",
         tokenizer_->bos_tok(),
         bos_id_);
   }
   if (tokenizer_->eos_tok() != eos_id_) {
     ET_LOG(
         Error,
-        "Tokenizer's EOS id %d does not match model's EOS id %d, will override tokenizer's EOS.",
+        "Tokenizer's EOS id %" PRIu64
+        " does not match model's EOS id %d, will override tokenizer's EOS.",
         tokenizer_->eos_tok(),
         eos_id_);
   }
@@ -227,20 +229,18 @@ Error Runner::generate(
   stats_.inference_start_ms = util::time_in_ms();
   shouldStop_ = false;

-  // encode the (string) prompt into tokens sequence
-  int num_prompt_tokens = 0;
-  // max # of prompt tokens: len(prompt) + '\0', ?BOS, ?EOS
-  int* prompt_tokens = new int[prompt.size() + 1 + n_bos_ + n_eos_];
-
   // Set the sequence length to the max seq length if not provided
   seq_len = (seq_len > 0 && seq_len <= max_seq_len_) ? seq_len : max_seq_len_;

-  tokenizer_->encode(
-      prompt.c_str(),
-      n_bos_,
-      append_eos_ ? n_eos_ : 0,
-      prompt_tokens,
-      &num_prompt_tokens);
+  Result<std::vector<uint64_t>> encode_res =
+      tokenizer_->encode(prompt, n_bos_, append_eos_ ? n_eos_ : 0);
+
+  ET_CHECK_OK_OR_RETURN_ERROR(
+      encode_res.error(), "Failed to encode prompt %s", prompt.c_str());
+
+  // encode the (string) prompt into tokens sequence
+  std::vector<uint64_t> prompt_tokens = encode_res.get();
+  int num_prompt_tokens = prompt_tokens.size();

   ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
   ET_CHECK_MSG(
@@ -303,13 +303,13 @@ Error Runner::generate(
   // Print the prompt for consistent output between single token prefill and
   // batch prefill.
-  int prev = prompt_tokens[0];
-  int cur;
+  uint64_t prev = prompt_tokens[0];
+  uint64_t cur;
   for (int i = 1; i < num_prompt_tokens; i++) {
     cur = prompt_tokens[i];
     auto piece_res = tokenizer_->decode(prev, cur);
     ET_CHECK_OK_OR_RETURN_ERROR(piece_res.error());
-    util::safe_printf(piece_res.get());
+    util::safe_printf(piece_res.get().c_str());
     fflush(stdout);
     prev = cur;
   }

@@ -361,7 +361,7 @@ Error Runner::generate(
     // print the token as string, decode it with the Tokenizer object
     auto piece_res = tokenizer_->decode(prev_token, cur_token);
     ET_CHECK(piece_res.ok());
-    const char* piece = piece_res.get();
+    const char* piece = piece_res.get().c_str();

     // same as printf("%s", piece), but skips "unsafe" bytes
     util::safe_printf(piece);
@@ -396,7 +396,6 @@ Error Runner::generate(
     stats_callback(stats_);
   }

-  delete[] prompt_tokens;
   return Error::Ok;
 }

diff --git a/examples/models/llama2/tokenizer/test/test_tokenizer.cpp b/examples/models/llama2/tokenizer/test/test_tokenizer.cpp
index 95fb2be7829..787f008568c 100644
--- a/examples/models/llama2/tokenizer/test/test_tokenizer.cpp
+++ b/examples/models/llama2/tokenizer/test/test_tokenizer.cpp
@@ -9,6 +9,7 @@
 #include <executorch/examples/models/llama2/tokenizer/tokenizer.h>
 #include 
 #include <gtest/gtest.h>
+#include <vector>

 using namespace ::testing;

@@ -28,8 +29,8 @@ class TokenizerExtensionTest : public Test {
 };

 TEST_F(TokenizerExtensionTest, EncodeWithoutLoadFails) {
-  Error error = tokenizer_->encode("hello world", 0, 0, nullptr, nullptr);
-  EXPECT_EQ(error, Error::NotSupported);
+  Result<std::vector<uint64_t>> res = tokenizer_->encode("hello world", 0, 0);
+  EXPECT_EQ(res.error(), Error::NotSupported);
 }

 TEST_F(TokenizerExtensionTest, DecodeWithoutLoadFails) {
diff --git a/examples/models/llama2/tokenizer/tokenizer.cpp b/examples/models/llama2/tokenizer/tokenizer.cpp
index b380cc675b4..40fc3d5683e 100644
--- a/examples/models/llama2/tokenizer/tokenizer.cpp
+++ b/examples/models/llama2/tokenizer/tokenizer.cpp
@@ -23,7 +23,7 @@ static int compare_tokens(const void* a, const void* b) {
   return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
 }

-Tokenizer::Tokenizer(int32_t vocab_size, int32_t bos_tok, int32_t eos_tok)
+Tokenizer::Tokenizer(int32_t vocab_size, uint64_t bos_tok, uint64_t eos_tok)
     : initialized_(false),
       vocab_size_(vocab_size),
       bos_tok_(bos_tok),
@@ -142,10 +142,10 @@ Tokenizer::~Tokenizer() {
  *
  * @param prev_token The previous token.
  * @param token The current token.
- * @return Result<const char*> A pointer to the string representation of the
+ * @return Result<std::string> A pointer to the string representation of the
  * token.
  */
-Result<const char*> Tokenizer::decode(int32_t prev_token, int32_t token) {
+Result<std::string> Tokenizer::decode(uint64_t prev_token, uint64_t token) {
   if (!initialized_) {
     ET_LOG(Error, "Tokenizer not initialized");
     return Error::NotSupported;
@@ -162,7 +162,8 @@ Result<const char*> Tokenizer::decode(int32_t prev_token, int32_t token) {
   if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
     piece = (char*)byte_pieces_ + byte_val * 2;
   }
-  return piece;
+  std::string res(piece);
+  return res;
 }

 static int32_t
@@ -183,14 +184,10 @@ str_lookup(const char* str, TokenIndex* sorted_vocab, int32_t vocab_size) {
  * @param eos The number of EOS to append to the token list.
  * @param tokens The output tokens.
  * @param n_tokens The number of tokens.
- * @return Error
+ * @return Result<std::vector<uint64_t>>
  */
-Error Tokenizer::encode(
-    const char* text,
-    int8_t bos,
-    int8_t eos,
-    int32_t* tokens,
-    int32_t* n_tokens) {
+Result<std::vector<uint64_t>>
+Tokenizer::encode(const std::string& text, int8_t bos, int8_t eos) {
   if (!initialized_) {
     ET_LOG(Error, "Tokenizer not initialized");
     return Error::NotSupported;
@@ -198,8 +195,8 @@ Error Tokenizer::encode(
   // encode the string text (input) into an upper-bound preallocated tokens[]
   // array bos != 0 means prepend the BOS token (=1), eos != 0 means append the
   // EOS token (=2)
-  if (text == nullptr) {
-    ET_LOG(Error, "cannot encode null text");
+  if (text.empty()) {
+    ET_LOG(Error, "cannot encode empty text");
     return Error::InvalidArgument;
   }

@@ -210,12 +207,12 @@ Error Tokenizer::encode(
   size_t str_len = 0;

   // start at 0 tokens
-  *n_tokens = 0;
+  std::vector<uint64_t> tokens;

   // add optional BOS token, if desired
   if (bos > 0) {
     while (bos--) {
-      tokens[(*n_tokens)++] = bos_tok_;
+      tokens.push_back(bos_tok_);
     }
   } else {
     ET_LOG(Error, "bos %d should be >= 0", bos);
@@ -230,7 +227,7 @@ Error Tokenizer::encode(
   const char* space = " ";
   if (text[0] != '\0') {
     int dummy_prefix = str_lookup(space, sorted_vocab_.get(), vocab_size_);
-    tokens[(*n_tokens)++] = dummy_prefix;
+    tokens.push_back(dummy_prefix);
   }

   // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
   // U+0000 U+007F    0xxxxxxx
   // U+0080 U+07FF    110xxxxx 10xxxxxx
   // U+0800 U+FFFF    1110xxxx 10xxxxxx 10xxxxxx
   // U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx

   // process the raw (UTF-8) byte sequence of the input string
-  for (const char* c = text; *c != '\0'; c++) {
+  for (const char* c = text.c_str(); *c != '\0'; c++) {
     // reset buffer if the current byte is ASCII or a leading byte
     // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the
     // rest 0x80 is 10000000 in UTF-8, all continuation bytes start with "10" in
@@ -271,13 +268,13 @@ Error Tokenizer::encode(
     int id = str_lookup(str_buffer, sorted_vocab_.get(), vocab_size_);
     if (id != -1) {
       // we found this codepoint in vocab, add it as a token
-      tokens[(*n_tokens)++] = id;
+      tokens.push_back(id);
     } else {
       // byte_fallback encoding: just encode each byte as a token
       // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
       // so the individual bytes only start at index 3
       for (int i = 0; i < str_len; i++) {
-        tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
+        tokens.push_back((unsigned char)str_buffer[i] + 3);
       }
     }
     str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
@@ -290,7 +287,7 @@ Error Tokenizer::encode(
     int best_id = -1;
     int best_idx = -1;

-    for (int i = 0; i < (*n_tokens - 1); i++) {
+    for (int i = 0; i < tokens.size() - 1; i++) {
       // check if we can merge the pair (tokens[i], tokens[i+1])
       snprintf(
           str_buffer,
@@ -314,16 +311,16 @@ Error Tokenizer::encode(
     // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
     tokens[best_idx] = best_id;
     // delete token at position best_idx+1, shift the entire sequence back 1
-    for (int i = best_idx + 1; i < (*n_tokens - 1); i++) {
+    for (int i = best_idx + 1; i < tokens.size() - 1; i++) {
       tokens[i] = tokens[i + 1];
     }
-    (*n_tokens)--; // token length decreased
+    tokens.pop_back(); // token length decreased
   }

   // add optional EOS (=2) token, if desired
   if (eos >= 0) {
     while (eos--) {
-      tokens[(*n_tokens)++] = eos_tok_;
+      tokens.push_back(eos_tok_);
     }
   } else {
     ET_LOG(Error, "eos %d should be >= 0", eos);
@@ -331,7 +328,7 @@ Error Tokenizer::encode(
   }

   delete[] str_buffer;
-  return Error::Ok;
+  return Result(tokens);
 }

 } // namespace executor
diff --git a/examples/models/llama2/tokenizer/tokenizer.h b/examples/models/llama2/tokenizer/tokenizer.h
index 0edc4671b17..6b03278eace 100644
--- a/examples/models/llama2/tokenizer/tokenizer.h
+++ b/examples/models/llama2/tokenizer/tokenizer.h
@@ -16,6 +16,7 @@
 #include 
 #include 
 #include 
+#include <vector>
 #include 
 #include 

@@ -32,37 +33,33 @@ struct TokenIndex {

 class Tokenizer {
  public:
-  explicit Tokenizer(int32_t vocab_size, int32_t bos_tok, int32_t eos_tok);
+  explicit Tokenizer(int32_t vocab_size, uint64_t bos_tok, uint64_t eos_tok);
   ~Tokenizer();

   Error load(const std::string& tokenizer_path);

-  Error encode(
-      const char* text,
-      int8_t bos,
-      int8_t eos,
-      int32_t* tokens,
-      int32_t* n_tokens);
+  Result<std::vector<uint64_t>>
+  encode(const std::string& input, int8_t bos, int8_t eos);

-  Result<const char*> decode(int prev_token, int token);
+  Result<std::string> decode(uint64_t prev_token, uint64_t token);

   // getters
   int32_t vocab_size() const {
     return vocab_size_;
   }

-  int32_t bos_tok() const {
+  uint64_t bos_tok() const {
     return bos_tok_;
   }

-  int32_t eos_tok() const {
+  uint64_t eos_tok() const {
     return eos_tok_;
   }

  private:
   bool initialized_;
   const int32_t vocab_size_;
-  int32_t bos_tok_, eos_tok_;
+  uint64_t bos_tok_, eos_tok_;
   std::unique_ptr<char*[]> vocab_;
   std::unique_ptr<float[]> vocab_scores_;
   std::unique_ptr<TokenIndex[]> sorted_vocab_;