
Conversation

@sgugger
Collaborator

@sgugger sgugger commented Oct 19, 2020

What does this PR do?

This PR cleans up the run_glue.py script to use the Datasets library. Along the way, it adds a few fixes in Trainer. The script supports all GLUE tasks as well as custom user tasks (passed along with a training and validation file in CSV or JSON format). It has been tested on the following setups:

  • single GPU
  • multi-GPU with DataParallel
  • multi-GPU with DistributedDataParallel
  • TPU

The README has been updated to reflect the changes. There is just one breaking change from before: data_dir is no longer an accepted argument, since Datasets takes care of downloading the data files.
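
As an illustration, loading a GLUE task through the Datasets library is as simple as the sketch below (the task name and the printed fields are only an example, not the exact code of the script):

from datasets import load_dataset

# The library downloads and caches the data itself, which is why the old
# data_dir argument is no longer needed.
datasets = load_dataset("glue", "mrpc")

print(datasets)              # a DatasetDict with train / validation / test splits
print(datasets["train"][0])  # one example, with its sentence columns and label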

if self.control.should_evaluate:
    metrics = self.evaluate()
    self._report_to_hp_search(trial, epoch, metrics)
    self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
Collaborator Author


Moving this to the end of evaluate, otherwise that event is not fired when we call trainer.evaluate() independently.

Member

@LysandreJik LysandreJik left a comment


This is a fantastic change, it makes this example's code so much easier to read imo.

Comment on lines -34 to -36
glue_compute_metrics,
glue_output_modes,
glue_tasks_num_labels,
Member


This is very nice

Comment on lines 141 to 143
if is_main_process(training_args.local_rank):
    logging.set_verbosity_info()
logger.info(f"Training/evaluation parameters {training_args}")
Member


Have you confirmed this actually works? It seems to me that you're setting the default verbosity level of the root logger (so the loggers of transformers and every file contained in it), but the logger of the current file isn't a child of this logger (it's in examples/, not in src/transformers/), so it doesn't look like it'll be impacted by that change.

I would argue you would still need to change the current logger's default verbosity to info if you want to see the line logger.info(f"Training/evaluation parameters {training_args}") being printed.

Collaborator Author


It does work and I can see all the info being printed on my screen.

Collaborator Author


Ok, this should be working now.
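
For reference, a rough sketch of how an example script can wire up both loggers; this is an assumption about the shape of the fix, not necessarily the exact code that landed:

import logging

from transformers.trainer_utils import is_main_process
from transformers.utils import logging as hf_logging

logger = logging.getLogger(__name__)

def setup_logging(local_rank: int) -> None:
    # The example script lives outside the transformers package, so its own logger
    # is not touched by transformers' verbosity helpers; configure it explicitly.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        level=logging.INFO if is_main_process(local_rank) else logging.WARN,
    )
    # Raise the transformers library loggers to INFO on the main process only.
    if is_main_process(local_rank):
        hf_logging.set_verbosity_info()

setup_logging(local_rank=-1)  # -1 here means no distributed training
logger.info("Training/evaluation parameters would be logged from here")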

Member

@thomwolf thomwolf left a comment


Looks really cool!

A few user experience comments :)

"sentence2_key": sentence2_key,
"max_length": data_args.max_seq_length,
}
datasets = datasets.map(preprocess_function, batched=True, fn_kwargs=encode_kwargs)
Member


I'd prefer to have preprocess_function as a closure, written right here with the arguments, instead of defining it above with kwargs.

It spares the reader a scroll up and down to see what's happening but I understand this is a matter of personal taste.
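
For illustration, the closure variant could look roughly like this, assuming tokenizer, sentence1_key/sentence2_key and data_args are already defined in the enclosing scope of the script (a sketch, not exact code):

# Closure variant: preprocess_function captures the tokenizer and column names
# from the enclosing scope, so no fn_kwargs is needed in map().
def preprocess_function(examples):
    args = (
        (examples[sentence1_key],)
        if sentence2_key is None
        else (examples[sentence1_key], examples[sentence2_key])
    )
    return tokenizer(*args, max_length=data_args.max_seq_length, truncation=True)

datasets = datasets.map(preprocess_function, batched=True)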

Collaborator Author


We lose the preprocessing caching for some reason when doing that.

# Get the metric function
metric = load_metric("glue", data_args.task_name)

def compute_metrics(p: EvalPrediction):
Member


Add a comment for the reader

Collaborator Author


Done.
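
As an illustration of the function being commented here, a compute_metrics for a classification task could look roughly like this (a sketch that assumes the metric object loaded just above; not necessarily the exact code of the script):

import numpy as np
from transformers import EvalPrediction

def compute_metrics(p: EvalPrediction):
    # Trainer hands over raw predictions (logits); turn them into class ids
    # before passing them to the GLUE metric loaded above.
    preds = np.argmax(p.predictions, axis=1)
    return metric.compute(predictions=preds, references=p.label_ids)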

return glue_compute_metrics(task_name, preds, p.label_ids)

return compute_metrics_fn
datasets = load_dataset("glue", data_args.task_name)
Member


Add a detailed multi-line comment explaining how the user can also easily load their own dataset from JSON or CSV files (with mock examples), and linking to the relevant page of the datasets library.

Collaborator Author


Done. I kept the examples basic since there is a link to the datasets documentation.
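
For reference, loading a custom dataset from CSV or JSON files with the Datasets library looks roughly like this (file names are placeholders):

from datasets import load_dataset

# CSV files with a header row; each split gets its own file.
datasets = load_dataset("csv", data_files={"train": "train.csv", "validation": "dev.csv"})

# JSON lines files work the same way:
# datasets = load_dataset("json", data_files={"train": "train.json", "validation": "dev.json"})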

test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]

# Get the metric function
metric = load_metric("glue", data_args.task_name)
Member


Here we should also think about a user who would like to train on their own classification CSV dataset.

I think we should probably have a few "f1", "accuracy" metrics in datasets for such use cases. What do you think @LysandreJik @lhoestq @sgugger ?

Collaborator Author


I think all basic metrics provided by scikit-learn should be available in datasets, yes.
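
Assuming such generic metrics are (or become) available in datasets under those names, usage would look like this small sketch:

from datasets import load_metric

accuracy = load_metric("accuracy")
f1 = load_metric("f1")

preds = [0, 1, 1, 0]
labels = [0, 1, 0, 0]
print(accuracy.compute(predictions=preds, references=labels))
print(f1.compute(predictions=preds, references=labels))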


train_dataset = datasets["train"]
eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]
Member


logger.info() a few dataset samples?

Collaborator Author


Done
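
For illustration, logging a few random training samples can be as simple as the sketch below (it assumes train_dataset and logger from the surrounding script):

import random

# Log a handful of random samples so the user can sanity-check the preprocessing.
for index in random.sample(range(len(train_dataset)), 3):
    logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")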

Member

@julien-c julien-c left a comment


nice job

@julien-c
Member

Should we start thinking about automating the creation of the metadata block for the resulting model card?

here for instance we'd already have this info:

---
datasets:
- mrpc
metrics:
- f1
finetuned_from: bert-base-cased
---

@sgugger
Collaborator Author

sgugger commented Oct 21, 2020

We could think of something like that and add a blank model card to be completed by the user in the final checkpoint. We could also include the results of the last evaluation if there is one.
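
A hypothetical helper along those lines could simply write the metadata block and leave the body blank for the user to complete; everything below is an assumption for illustration, not an existing transformers feature:

import os

def write_model_card_stub(output_dir, dataset, metrics, finetuned_from):
    # Hypothetical helper: writes a model card stub whose metadata block looks
    # like the one above, leaving the rest for the user to fill in.
    lines = ["---", "datasets:", f"- {dataset}", "metrics:"]
    lines += [f"- {m}" for m in metrics]
    lines += [f"finetuned_from: {finetuned_from}", "---", ""]
    with open(os.path.join(output_dir, "README.md"), "w") as f:
        f.write("\n".join(lines))

write_model_card_stub(".", "mrpc", ["f1"], "bert-base-cased")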

@sgugger sgugger changed the title from "[WIP] New run glue script" to "New run glue script" on Oct 21, 2020
Member

@LysandreJik LysandreJik left a comment


I think this is great. LGTM!

Member

@thomwolf thomwolf left a comment


Looks great, I added a few proposals to make it a bit simpler to read (imo).

Comment on lines +177 to 178
# Set seed before initializing model.
set_seed(training_args.seed)
Member


Note that we also have a set_seed method in the datasets library.

@sgugger sgugger merged commit 2e5052d into master Oct 22, 2020
@sgugger sgugger deleted the new_run_glue branch October 22, 2020 15:42
@LysandreJik LysandreJik mentioned this pull request Oct 28, 2020