From 30f4376df10122d838042f4e34d30e20396b34a8 Mon Sep 17 00:00:00 2001 From: bofenghuang Date: Wed, 22 Feb 2023 13:57:54 +0100 Subject: [PATCH 1/6] Override the decoding parameters of Seq2SeqTrainer --- .../pytorch/summarization/run_summarization.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/examples/pytorch/summarization/run_summarization.py b/examples/pytorch/summarization/run_summarization.py index e0b7fc214ec6..238cd0b5229e 100755 --- a/examples/pytorch/summarization/run_summarization.py +++ b/examples/pytorch/summarization/run_summarization.py @@ -639,6 +639,10 @@ def compute_metrics(eval_preds): result["gen_len"] = np.mean(prediction_lens) return result + # Override the decoding parameters of Seq2SeqTrainer + training_args.generation_max_length = training_args.generation_max_length if training_args.generation_max_length is not None else data_args.val_max_target_length + training_args.generation_num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams + # Initialize our Trainer trainer = Seq2SeqTrainer( model=model, @@ -672,15 +676,9 @@ def compute_metrics(eval_preds): # Evaluation results = {} - max_length = ( - training_args.generation_max_length - if training_args.generation_max_length is not None - else data_args.val_max_target_length - ) - num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams if training_args.do_eval: logger.info("*** Evaluate ***") - metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval") + metrics = trainer.evaluate(metric_key_prefix="eval") max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) @@ -691,7 +689,7 @@ def compute_metrics(eval_preds): logger.info("*** Predict ***") predict_results = trainer.predict( - predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams + predict_dataset, metric_key_prefix="predict" ) metrics = predict_results.metrics max_predict_samples = ( From 70b80e91ca98cfeaa94957df69af70a4bee54b9d Mon Sep 17 00:00:00 2001 From: bofenghuang Date: Wed, 22 Feb 2023 14:00:03 +0100 Subject: [PATCH 2/6] Fix quality --- .../pytorch/summarization/run_summarization.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/pytorch/summarization/run_summarization.py b/examples/pytorch/summarization/run_summarization.py index 238cd0b5229e..2645f677f1c6 100755 --- a/examples/pytorch/summarization/run_summarization.py +++ b/examples/pytorch/summarization/run_summarization.py @@ -640,8 +640,14 @@ def compute_metrics(eval_preds): return result # Override the decoding parameters of Seq2SeqTrainer - training_args.generation_max_length = training_args.generation_max_length if training_args.generation_max_length is not None else data_args.val_max_target_length - training_args.generation_num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams + training_args.generation_max_length = ( + training_args.generation_max_length + if training_args.generation_max_length is not None + else data_args.val_max_target_length + ) + training_args.generation_num_beams = ( + data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams + ) # Initialize our Trainer trainer = Seq2SeqTrainer( @@ -688,9 +694,7 @@ def compute_metrics(eval_preds): if training_args.do_predict: logger.info("*** Predict ***") - predict_results = trainer.predict( - predict_dataset, metric_key_prefix="predict" - ) + predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict") metrics = predict_results.metrics max_predict_samples = ( data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) From 366a72e6f95310e56dfd0991bfdf1d4c8e7b16b8 Mon Sep 17 00:00:00 2001 From: bofenghuang Date: Wed, 22 Feb 2023 14:07:42 +0100 Subject: [PATCH 3/6] Fix max_length parameter --- .../pytorch/summarization/run_summarization_no_trainer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/pytorch/summarization/run_summarization_no_trainer.py b/examples/pytorch/summarization/run_summarization_no_trainer.py index 8f669be72c58..2fbe66ab0a6d 100644 --- a/examples/pytorch/summarization/run_summarization_no_trainer.py +++ b/examples/pytorch/summarization/run_summarization_no_trainer.py @@ -164,10 +164,9 @@ def parse_args(): parser.add_argument( "--max_length", type=int, - default=128, + default=None, help=( - "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated," - " sequences shorter will be padded if `--pad_to_max_lengh` is passed." + "The maximum target length to use when predicting with the generate method." ), ) parser.add_argument( @@ -671,7 +670,7 @@ def postprocess_text(preds, labels): args.val_max_target_length = args.max_target_length gen_kwargs = { - "max_length": args.val_max_target_length if args is not None else config.max_length, + "max_length": args.max_length if args.max_length is not None else args.val_max_target_length, "num_beams": args.num_beams, } for step, batch in enumerate(eval_dataloader): From fb7bc23b4e6256ce33b95af4ad20ca3ff8442c1d Mon Sep 17 00:00:00 2001 From: bofenghuang Date: Wed, 22 Feb 2023 14:08:59 +0100 Subject: [PATCH 4/6] Fix quality --- .../pytorch/summarization/run_summarization_no_trainer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/pytorch/summarization/run_summarization_no_trainer.py b/examples/pytorch/summarization/run_summarization_no_trainer.py index 2fbe66ab0a6d..4522107e15ab 100644 --- a/examples/pytorch/summarization/run_summarization_no_trainer.py +++ b/examples/pytorch/summarization/run_summarization_no_trainer.py @@ -165,9 +165,7 @@ def parse_args(): "--max_length", type=int, default=None, - help=( - "The maximum target length to use when predicting with the generate method." - ), + help=("The maximum target length to use when predicting with the generate method."), ) parser.add_argument( "--num_beams", From 2596d1bb0094576d02c442da8727c70a6162c4c0 Mon Sep 17 00:00:00 2001 From: bofenghuang Date: Fri, 24 Feb 2023 10:15:16 +0100 Subject: [PATCH 5/6] Remove redundant parameter max_length --- .../pytorch/summarization/run_summarization_no_trainer.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/examples/pytorch/summarization/run_summarization_no_trainer.py b/examples/pytorch/summarization/run_summarization_no_trainer.py index 4522107e15ab..8b3fa2b8b27a 100644 --- a/examples/pytorch/summarization/run_summarization_no_trainer.py +++ b/examples/pytorch/summarization/run_summarization_no_trainer.py @@ -161,12 +161,6 @@ def parse_args(): "param of ``model.generate``, which is used during ``evaluate`` and ``predict``." ), ) - parser.add_argument( - "--max_length", - type=int, - default=None, - help=("The maximum target length to use when predicting with the generate method."), - ) parser.add_argument( "--num_beams", type=int, @@ -668,7 +662,7 @@ def postprocess_text(preds, labels): args.val_max_target_length = args.max_target_length gen_kwargs = { - "max_length": args.max_length if args.max_length is not None else args.val_max_target_length, + "max_length": args.val_max_target_length, "num_beams": args.num_beams, } for step, batch in enumerate(eval_dataloader): From dd0b5a8667528c3e61cee6ec44365c55b8d85b87 Mon Sep 17 00:00:00 2001 From: bofenghuang Date: Fri, 24 Feb 2023 10:27:39 +0100 Subject: [PATCH 6/6] Separate the preprocess of train and validation to use different max_target_length --- .../run_summarization_no_trainer.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/examples/pytorch/summarization/run_summarization_no_trainer.py b/examples/pytorch/summarization/run_summarization_no_trainer.py index 8b3fa2b8b27a..b16a3fd06900 100644 --- a/examples/pytorch/summarization/run_summarization_no_trainer.py +++ b/examples/pytorch/summarization/run_summarization_no_trainer.py @@ -464,6 +464,9 @@ def main(): f"--summary_column' value '{args.summary_column}' needs to be one of: {', '.join(column_names)}" ) + if args.val_max_target_length is None: + args.val_max_target_length = args.max_target_length + # Temporarily set max_target_length for training. max_target_length = args.max_target_length padding = "max_length" if args.pad_to_max_length else False @@ -488,7 +491,7 @@ def preprocess_function(examples): return model_inputs with accelerator.main_process_first(): - processed_datasets = raw_datasets.map( + train_dataset = raw_datasets["train"].map( preprocess_function, batched=True, num_proc=args.preprocessing_num_workers, @@ -497,8 +500,16 @@ def preprocess_function(examples): desc="Running tokenizer on dataset", ) - train_dataset = processed_datasets["train"] - eval_dataset = processed_datasets["validation"] + # Temporarily set max_target_length for validation. + max_target_length = args.val_max_target_length + eval_dataset = raw_datasets["validation"].map( + preprocess_function, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not args.overwrite_cache, + desc="Running tokenizer on dataset", + ) # Log a few random samples from the training set: for index in random.sample(range(len(train_dataset)), 1): @@ -658,8 +669,6 @@ def postprocess_text(preds, labels): break model.eval() - if args.val_max_target_length is None: - args.val_max_target_length = args.max_target_length gen_kwargs = { "max_length": args.val_max_target_length,