dehardcode test string #8865

Merged: 9 commits, May 3, 2024
11 changes: 6 additions & 5 deletions nemo/export/trt_llm/nemo_utils.py
@@ -193,7 +193,9 @@ def nemo_llm_to_model_config(
     return model_configs, tokenizer


-def to_word_list_format(word_dict: List[List[str]], tokenizer=None):
+def to_word_list_format(
+    word_dict: List[List[str]], tokenizer=None, ref_str="<extra_id_1>",
+):
     '''
     format of word_dict
         len(word_dict) should be same to batch_size
@@ -207,10 +207,9 @@ def to_word_list_format(word_dict: List[List[str]], tokenizer=None):

     flat_ids = []
     offsets = []
-    # We use a similar trick as in NeMo to deal with the fact that the encoding of a single word
-    # can't always be trusted. See
+    # The encoding of a single word can't always be trusted. See
     # https://github.com/NVIDIA/NeMo/blob/bb575b72fd0be51ae10cc77d9f89ddb9e9d3b96d/nemo/collections/nlp/modules/common/text_generation_strategy.py#L229
ids_ref = tokenizer.encode("<extra_id_1>")
ids_ref = tokenizer.encode(ref_str)
for word_dict_item in word_dict:
item_flat_ids = []
item_offsets = []
@@ -220,7 +221,7 @@ def to_word_list_format(word_dict: List[List[str]], tokenizer=None):

         words = list(csv.reader(word_dict_item))[0]
         for word in words:
-            ids = tokenizer.encode(f"<extra_id_1>{word}")
+            ids = tokenizer.encode(f"{ref_str}{word}")
             if ids[0 : len(ids_ref)] == ids_ref:
                 # It worked! We can obtain the token(s) associated to `word` by stripping the prefix tokens.
                 ids = ids[len(ids_ref) :]
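
For context on the trick this diff parameterizes: BPE-style tokenizers can encode a word differently in isolation than when it follows other text, so the code encodes the reference string plus the word and strips the reference string's tokens off the front. Below is a minimal standalone sketch of that trick, assuming a HuggingFace-style tokenizer; the `gpt2` checkpoint and the `encode_in_context` helper name are illustrative and not part of this PR.

```python
# Sketch of the prefix-encoding trick referenced in the diff above.
# Assumes a HuggingFace-style tokenizer; "gpt2" is only an example model.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")


def encode_in_context(word: str, ref_str: str = "<extra_id_1>") -> list[int]:
    """Return token ids for `word` as it would appear after other text.

    Encoding `word` alone can differ from encoding it mid-sequence
    (e.g. leading-space merges in BPE vocabularies), so we encode
    `ref_str + word` and strip the tokens of `ref_str` from the front.
    """
    ids_ref = tokenizer.encode(ref_str)
    ids = tokenizer.encode(f"{ref_str}{word}")
    if ids[: len(ids_ref)] == ids_ref:
        # The reference prefix tokenized cleanly; drop it to keep only
        # the tokens contributed by `word`.
        return ids[len(ids_ref):]
    # The prefix merged with the word, so its tokens can't be stripped
    # safely; fall back to encoding the word on its own.
    return tokenizer.encode(word)


print(encode_in_context("hello"))
```

Making `ref_str` a parameter, as this PR does, lets callers pick a reference string suited to their tokenizer's vocabulary instead of the hardcoded `<extra_id_1>` sentinel.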