Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions examples/pytorch/summarization/run_summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,16 @@ def compute_metrics(eval_preds):
result["gen_len"] = np.mean(prediction_lens)
return result

# Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = (
training_args.generation_max_length
if training_args.generation_max_length is not None
else data_args.val_max_target_length
)
training_args.generation_num_beams = (
data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
)

# Initialize our Trainer
trainer = Seq2SeqTrainer(
model=model,
Expand Down Expand Up @@ -672,15 +682,9 @@ def compute_metrics(eval_preds):

# Evaluation
results = {}
max_length = (
training_args.generation_max_length
if training_args.generation_max_length is not None
else data_args.val_max_target_length
)
num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
if training_args.do_eval:
logger.info("*** Evaluate ***")
metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
metrics = trainer.evaluate(metric_key_prefix="eval")
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

Expand All @@ -690,9 +694,7 @@ def compute_metrics(eval_preds):
if training_args.do_predict:
logger.info("*** Predict ***")

predict_results = trainer.predict(
predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
)
predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict")
metrics = predict_results.metrics
max_predict_samples = (
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,8 @@ def parse_args():
parser.add_argument(
"--max_length",
type=int,
default=128,
help=(
"The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
" sequences shorter will be padded if `--pad_to_max_lengh` is passed."
),
default=None,
help=("The maximum target length to use when predicting with the generate method."),
Comment thread
bofenghuang marked this conversation as resolved.
Outdated
)
parser.add_argument(
"--num_beams",
Expand Down Expand Up @@ -671,7 +668,7 @@ def postprocess_text(preds, labels):
args.val_max_target_length = args.max_target_length

gen_kwargs = {
"max_length": args.val_max_target_length if args is not None else config.max_length,
"max_length": args.max_length if args.max_length is not None else args.val_max_target_length,
Comment thread
bofenghuang marked this conversation as resolved.
Outdated
"num_beams": args.num_beams,
}
for step, batch in enumerate(eval_dataloader):
Expand Down