defaults changed (NVIDIA#7600)
* defaults changed

Signed-off-by: arendu <[email protected]>

* typo

Signed-off-by: arendu <[email protected]>

* update

Signed-off-by: arendu <[email protected]>

---------

Signed-off-by: arendu <[email protected]>
Signed-off-by: Sasha Meister <[email protected]>
arendu authored and ssh-meister committed Oct 5, 2023
1 parent deb80c4 commit 620c011
Showing 1 changed file with 11 additions and 19 deletions.
30 changes: 11 additions & 19 deletions scripts/metric_calculation/peft_metric_calc.py
@@ -21,26 +21,18 @@


"""
This script can be used to calculate exact match and F1 scores for many different tasks, not just squad.
Example command for T5 Preds
```
python squad_metric_calc.py \
--ground-truth squad_test_gt.jsonl \
--preds squad_preds_t5.txt
```
This script can be used to calculate exact match and F1 scores for many different tasks.
The file "squad_test_predictions.jsonl" is assumed to be generated by the
`examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py` script
Example command for GPT Preds
```
python squad_metric_calc.py \
--ground-truth squad_test_gt.jsonl \
--preds squad_preds_gpt.txt \
--split-string "answer:"
python peft_metric_calc.py \
--pred_file squad_test_predictions.jsonl \
--label_field "original_answers"
```
In this case, the prediction file will be split on "answer: " when looking for the LM's predicted answer.
"""

@@ -92,21 +84,21 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
def main():
parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument(
'--pred-file',
'--pred_file',
type=str,
help="Text file with test set prompts + model predictions. Prediction file can be made by running NeMo/examples/nlp/language_modeling/megatron_gpt_prompt_learning_eval.py",
)
parser.add_argument(
'--pred-field',
'--pred_field',
type=str,
help="The field in the json file that contains the prediction tokens",
default="pred",
)
parser.add_argument(
'--ground-truth-field',
'--label_field',
type=str,
help="The field in the json file that contains the ground truth tokens",
default="original_answers",
default="label",
)

args = parser.parse_args()
@@ -120,7 +112,7 @@ def main():
pred_line = json.loads(preds[i])

pred_answer = pred_line[args.pred_field]
true_answers = pred_line[args.ground_truth_field]
true_answers = pred_line[args.label_field]
if not isinstance(true_answers, list):
true_answers = [true_answers]
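
The hunk above only shows the field lookup, so here is a minimal sketch of how one line of the prediction file might be scored end to end. It assumes the collapsed parts of the script follow the usual SQuAD-style definitions of exact match and token-level F1; `normalize_answer`, `exact_match_score`, `f1_score` and `score_line` are hypothetical names used for illustration and are not taken from the diff.

```python
import collections
import json
import re
import string


def normalize_answer(text):
    """Lowercase, drop punctuation and articles, collapse whitespace (SQuAD-style, assumed)."""
    text = text.lower()
    punctuation = set(string.punctuation)
    text = "".join(ch for ch in text if ch not in punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match_score(prediction, ground_truth):
    # True when the two answers are identical after normalization.
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def f1_score(prediction, ground_truth):
    # Token-level F1 over the multiset overlap of normalized tokens.
    pred_tokens = normalize_answer(prediction).split()
    true_tokens = normalize_answer(ground_truth).split()
    common = collections.Counter(pred_tokens) & collections.Counter(true_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(true_tokens)
    return 2 * precision * recall / (precision + recall)


def score_line(json_line, pred_field="pred", label_field="label"):
    """Score one JSONL record; field names mirror the new argparse defaults."""
    record = json.loads(json_line)
    pred_answer = record[pred_field]
    true_answers = record[label_field]
    # A single reference answer is wrapped in a list, as in the hunk above.
    if not isinstance(true_answers, list):
        true_answers = [true_answers]
    # Take the best score over all reference answers.
    em = max(float(exact_match_score(pred_answer, t)) for t in true_answers)
    f1 = max(f1_score(pred_answer, t) for t in true_answers)
    return em, f1
```

Under these assumptions, a record such as `{"pred": "The Eiffel Tower", "label": "eiffel tower"}` counts as an exact match, because casing and the leading article are removed before comparison.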

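
A short sketch of what the flag rename means for callers, using only the three options visible in the diff. The parser below is a stand-in for illustration, not the script's actual `main()`. `argparse` derives the attribute name from the long option string, so `--label_field` is read as `args.label_field`, and the default field name is now `label` unless the caller overrides it.

```python
import argparse

# Stand-in parser that mirrors the renamed options and new defaults from the diff.
parser = argparse.ArgumentParser(description="Flag naming sketch")
parser.add_argument("--pred_file", type=str)
parser.add_argument("--pred_field", type=str, default="pred")
parser.add_argument("--label_field", type=str, default="label")

# Equivalent to the example command in the updated docstring.
args = parser.parse_args(
    ["--pred_file", "squad_test_predictions.jsonl", "--label_field", "original_answers"]
)

print(args.pred_file, args.pred_field, args.label_field)
```

Because the default for `--label_field` changed from `original_answers` to `label`, a prediction file that stores references under `original_answers` needs the flag passed explicitly, as the docstring example does.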