diff --git a/examples/seq2seq/run_eval.py b/examples/seq2seq/run_eval.py index f36ca7101be3..cb94fbcde36d 100644 --- a/examples/seq2seq/run_eval.py +++ b/examples/seq2seq/run_eval.py @@ -15,9 +15,9 @@ logger = getLogger(__name__) try: - from .utils import calculate_bleu, calculate_rouge, use_task_specific_params + from .utils import calculate_bleu, calculate_rouge, parse_numeric_cl_kwargs, use_task_specific_params except ImportError: - from utils import calculate_bleu, calculate_rouge, use_task_specific_params + from utils import calculate_bleu, calculate_rouge, parse_numeric_cl_kwargs, use_task_specific_params DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" @@ -36,7 +36,6 @@ def generate_summaries_or_translations( device: str = DEFAULT_DEVICE, fp16=False, task="summarization", - decoder_start_token_id=None, **generate_kwargs, ) -> Dict: """Save model.generate results to , and return how long it took.""" @@ -59,7 +58,6 @@ def generate_summaries_or_translations( summaries = model.generate( input_ids=batch.input_ids, attention_mask=batch.attention_mask, - decoder_start_token_id=decoder_start_token_id, **generate_kwargs, ) dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False) @@ -77,30 +75,20 @@ def run_generate(): parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.") parser.add_argument("input_path", type=str, help="like cnn_dm/test.source") parser.add_argument("save_path", type=str, help="where to save summaries") - - parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt") - parser.add_argument( - "--score_path", - type=str, - required=False, - default="metrics.json", - help="where to save the rouge score in json format", - ) + parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target") + parser.add_argument("--score_path", type=str, required=False, 
default="metrics.json", help="where to save metrics") parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.") - parser.add_argument("--task", type=str, default="summarization", help="typically translation or summarization") + parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics") parser.add_argument("--bs", type=int, default=8, required=False, help="batch size") - parser.add_argument( - "--decoder_start_token_id", - type=int, - default=None, - required=False, - help="Defaults to using config", - ) parser.add_argument( "--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all." ) parser.add_argument("--fp16", action="store_true") - args = parser.parse_args() + # Unspecified args like --num_beams 2 --decoder_start_token_id 4 are passed to model.generate + args, rest = parser.parse_known_args() + parsed = parse_numeric_cl_kwargs(rest) + if parsed: + print(f"parsed the following generate kwargs: {parsed}") examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()] if args.n_obs > 0: examples = examples[: args.n_obs] @@ -115,7 +103,7 @@ def run_generate(): device=args.device, fp16=args.fp16, task=args.task, - decoder_start_token_id=args.decoder_start_token_id, + **parsed, ) if args.reference_path is None: return diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py index 5121e165aa7d..3f4ff2a31d19 100644 --- a/examples/seq2seq/test_seq2seq_examples.py +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -300,6 +300,10 @@ def test_run_eval(model): score_path, "--task", task, + "--num_beams", + "2", + "--length_penalty", + "2.0", ] with patch.object(sys, "argv", testargs): run_generate() diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py index 2cee41657445..604cc6907366 100644 --- a/examples/seq2seq/utils.py 
+++ b/examples/seq2seq/utils.py @@ -5,7 +5,7 @@ import pickle from logging import getLogger from pathlib import Path -from typing import Callable, Dict, Iterable, List +from typing import Callable, Dict, Iterable, List, Union import git import numpy as np @@ -309,3 +309,23 @@ def assert_not_all_frozen(model): model_grads: List[bool] = list(grad_status(model)) npars = len(model_grads) assert any(model_grads), f"none of {npars} weights require grad" + + + +# CLI Parsing utils + + + +def parse_numeric_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float]]: + """Parse leftover argv tokens as ``--name value`` pairs (two separate tokens each; ``--name=value`` is not supported) into a dict; each value must parse as int or float.""" + result = {} + assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}" + num_pairs = len(unparsed_args) // 2 + for pair_num in range(num_pairs): + i = 2 * pair_num + assert unparsed_args[i].startswith("--") + try: + value = int(unparsed_args[i + 1]) + except ValueError: + value = float(unparsed_args[i + 1]) # this can raise another informative ValueError + + result[unparsed_args[i][2:]] = value + return result