dehardcode test string (NVIDIA#8865)
* dehardcode test string

Signed-off-by: Jimmy Zhang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update nemo_utils.py

Signed-off-by: JimmyZhang12 <[email protected]>

---------

Signed-off-by: Jimmy Zhang <[email protected]>
Signed-off-by: JimmyZhang12 <[email protected]>
Co-authored-by: Jimmy Zhang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Onur Yilmaz <[email protected]>
4 people authored May 3, 2024
1 parent e19dd7a commit f60d183
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions nemo/export/trt_llm/nemo_utils.py
@@ -193,7 +193,9 @@ def nemo_llm_to_model_config(
     return model_configs, tokenizer
 
 
-def to_word_list_format(word_dict: List[List[str]], tokenizer=None):
+def to_word_list_format(
+    word_dict: List[List[str]], tokenizer=None, ref_str="<extra_id_1>",
+):
     '''
     format of word_dict
         len(word_dict) should be same to batch_size
@@ -207,10 +209,9 @@ def to_word_list_format(word_dict: List[List[str]], tokenizer=None):
 
     flat_ids = []
     offsets = []
-    # We use a similar trick as in NeMo to deal with the fact that the encoding of a single word
-    # can't always be trusted. See
+    # The encoding of a single word can't always be trusted. See
     # https://github.com/NVIDIA/NeMo/blob/bb575b72fd0be51ae10cc77d9f89ddb9e9d3b96d/nemo/collections/nlp/modules/common/text_generation_strategy.py#L229
-    ids_ref = tokenizer.encode("<extra_id_1>")
+    ids_ref = tokenizer.encode(ref_str)
     for word_dict_item in word_dict:
         item_flat_ids = []
         item_offsets = []
@@ -220,7 +221,7 @@ def to_word_list_format(word_dict: List[List[str]], tokenizer=None):
 
         words = list(csv.reader(word_dict_item))[0]
         for word in words:
-            ids = tokenizer.encode(f"<extra_id_1>{word}")
+            ids = tokenizer.encode(f"{ref_str}{word}")
             if ids[0 : len(ids_ref)] == ids_ref:
                 # It worked! We can obtain the token(s) associated to `word` by stripping the prefix tokens.
                 ids = ids[len(ids_ref) :]
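
For context, a minimal standalone sketch of the prefix-stripping trick this commit parameterizes. The encode_word helper is illustrative, not part of the NeMo API; it assumes any tokenizer exposing an encode() method that returns a list of ids (e.g. a Hugging Face tokenizer), and the fallback branch is a hypothetical choice, not taken from the diff:

    from typing import List

    def encode_word(word: str, tokenizer, ref_str: str = "<extra_id_1>") -> List[int]:
        """Token ids for `word` as it would appear after other text.

        Encoding a single word in isolation can't always be trusted, so the
        word is encoded glued to a reference string and the reference tokens
        are stripped off the front.
        """
        ids_ref = tokenizer.encode(ref_str)          # tokens of the reference string alone
        ids = tokenizer.encode(f"{ref_str}{word}")   # tokens of reference string + word
        if ids[: len(ids_ref)] == ids_ref:
            # The reference prefix survived intact; the remainder encodes `word`.
            return ids[len(ids_ref):]
        # Fallback (hypothetical, not in the diff): trust the direct encoding.
        return tokenizer.encode(word)

Exposing ref_str as a parameter, rather than hardcoding "<extra_id_1>", presumably lets callers whose tokenizer has no <extra_id_1> token substitute any string their tokenizer encodes cleanly.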
