Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions nemo_skills/dataset/mrcr/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def count_n_tokens(messages: list[dict]) -> int:
def write_data_to_file(output_file, data, max_context_window, needles_subset):
with open(output_file, "wt", encoding="utf-8") as fout:
for idx, entry in tqdm(enumerate(data), desc=f"Writing {output_file.name}"):
messages = json.loads(entry["prompt"])
messages = json.loads(entry.pop("prompt"))

if entry['n_needles'] not in needles_subset:
print(f"Skipping {idx} because it has {entry['n_needles']} needle")
Expand All @@ -56,7 +56,7 @@ def write_data_to_file(output_file, data, max_context_window, needles_subset):
print(f"Skipping {idx} because it has {n_tokens} tokens")
continue

entry['messages'] = entry.pop('prompt')
entry['messages'] = messages
entry['expected_answer'] = entry.pop('answer')
entry['n_tokens'] = n_tokens
json.dump(entry, fout)
Expand Down