Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 11 additions & 23 deletions examples/seq2seq/run_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
logger = getLogger(__name__)

try:
from .utils import calculate_bleu, calculate_rouge, use_task_specific_params
from .utils import calculate_bleu, calculate_rouge, parse_numeric_cl_kwargs, use_task_specific_params
except ImportError:
from utils import calculate_bleu, calculate_rouge, use_task_specific_params
from utils import calculate_bleu, calculate_rouge, parse_numeric_cl_kwargs, use_task_specific_params

DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

Expand All @@ -36,7 +36,6 @@ def generate_summaries_or_translations(
device: str = DEFAULT_DEVICE,
fp16=False,
task="summarization",
decoder_start_token_id=None,
**generate_kwargs,
) -> Dict:
"""Save model.generate results to <out_file>, and return how long it took."""
Expand All @@ -59,7 +58,6 @@ def generate_summaries_or_translations(
summaries = model.generate(
input_ids=batch.input_ids,
attention_mask=batch.attention_mask,
decoder_start_token_id=decoder_start_token_id,
**generate_kwargs,
)
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
Expand All @@ -77,30 +75,20 @@ def run_generate():
parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.")
parser.add_argument("input_path", type=str, help="like cnn_dm/test.source")
parser.add_argument("save_path", type=str, help="where to save summaries")

parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
parser.add_argument(
"--score_path",
type=str,
required=False,
default="metrics.json",
help="where to save the rouge score in json format",
)
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target")
parser.add_argument("--score_path", type=str, required=False, default="metrics.json", help="where to save metrics")
parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
parser.add_argument("--task", type=str, default="summarization", help="typically translation or summarization")
parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
parser.add_argument(
"--decoder_start_token_id",
type=int,
default=None,
required=False,
help="Defaults to using config",
)
parser.add_argument(
"--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all."
)
parser.add_argument("--fp16", action="store_true")
args = parser.parse_args()
# Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate
args, rest = parser.parse_known_args()
parsed = parse_numeric_cl_kwargs(rest)
if parsed:
print(f"parsed the following generate kwargs: {parsed}")
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
if args.n_obs > 0:
examples = examples[: args.n_obs]
Expand All @@ -115,7 +103,7 @@ def run_generate():
device=args.device,
fp16=args.fp16,
task=args.task,
decoder_start_token_id=args.decoder_start_token_id,
**parsed,
)
if args.reference_path is None:
return
Expand Down
4 changes: 4 additions & 0 deletions examples/seq2seq/test_seq2seq_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,10 @@ def test_run_eval(model):
score_path,
"--task",
task,
"--num_beams",
"2",
"--length_penalty",
"2.0",
]
with patch.object(sys, "argv", testargs):
run_generate()
Expand Down
22 changes: 21 additions & 1 deletion examples/seq2seq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pickle
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List
from typing import Callable, Dict, Iterable, List, Union

import git
import numpy as np
Expand Down Expand Up @@ -309,3 +309,23 @@ def assert_not_all_frozen(model):
model_grads: List[bool] = list(grad_status(model))
npars = len(model_grads)
assert any(model_grads), f"none of {npars} weights require grad"


# CLI Parsing utils


def parse_numeric_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float]]:
"""Parse an argv list of unspecified command line args to a dict. Assumes all values are numeric."""
result = {}
assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}"
num_pairs = len(unparsed_args) // 2
for pair_num in range(num_pairs):
i = 2 * pair_num
assert unparsed_args[i].startswith("--")
try:
value = int(unparsed_args[i + 1])
except ValueError:
value = float(unparsed_args[i + 1]) # this can raise another informative ValueError

result[unparsed_args[i][2:]] = value
return result