Conversation

@sgugger (Collaborator) commented Oct 27, 2020

What does this PR do?

This PR adds an example of fine-tuning a causal language model (or training one from scratch) using the 🤗 Datasets library. It supports loading a dataset either by its name (from the hub) or from local files. A test that trains on a small text file is added.
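
For illustration, an invocation along these lines should work (gpt2 and wikitext are placeholder choices here; the flags follow the usual example-script arguments):

python run_clm.py \
    --model_name_or_path gpt2 \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-clm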

@thomwolf (Member) left a comment:

Very nice!

Just a few comments and suggestions.

from transformers.trainer_utils import is_main_process


logger = logging.getLogger(__name__)

Member:

Should we use the transformers library's logging? (cc @LysandreJik)

Collaborator Author:

No, for the script, we should use the regular one. @LysandreJik had a very long explanation of why, which I don't remember.

Member:

Here it is.

The gist of it is that, IMO, the transformers logging utility should only be used to control the logging of the transformers module, not of the user's script directly, as it is not made for that and would lead to very weird behavior.

In my opinion, the logging setup in a user script should include both:

import logging
from transformers import logging as hf_logging

hf_logging.set_verbosity_xxx()
logger = logging.getLogger(__name__)

# then do stuff with the logger without worrying about the HF logging which has already been managed before
logger.warning("xxx")
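
To make that concrete, here is a minimal self-contained sketch (the basicConfig call and the message strings are illustrative, not from the PR):

import logging

from transformers import logging as hf_logging

# Script-side logging: the standard library handles the script's own messages.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Library-side logging: the transformers utility only controls messages
# emitted by the transformers module itself.
hf_logging.set_verbosity_info()

logger.info("script-level message, handled by the stdlib logger")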

def tokenize_function(examples):
    return tokenizer(examples[text_column_name])

tokenized_datasets = datasets.map(

Member:

In the two calls to map (here and below) it could be nice to add a reference to multi-processing with num_proc
(and maybe a link to the doc: https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map)
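
For example (a sketch; num_proc=4 and remove_columns are illustrative choices, not from the PR):

tokenized_datasets = datasets.map(
    tokenize_function,
    batched=True,
    num_proc=4,  # illustrative: tokenize with 4 worker processes in parallel
    remove_columns=[text_column_name],
)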

Collaborator Author:

We could do the same thing in the run_glue script too, in passing.

@LysandreJik (Member) left a comment:

Great, LGTM!

Comment on lines 2 to 3
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.

Member:

Do we need the copyright to Google AI and NVIDIA? Are there some snippets taken from their codebases?

Collaborator Author:

No, it's a bad copy-paste.

Comment on lines 16 to 18
"""
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL) on a text file or a dataset.
"""

Member:

From experience, users will understand that only GPT, GPT-2 and CTRL are supported by that script. I would put (GPT, GPT-2, CTRL, ...) instead, and provide a link:

Suggested change:
- """
- Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL) on a text file or a dataset.
- """
+ """
+ Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
+ Find the full list of model architectures that can be fine-tuned by this script on the documentation:
+ https://huggingface.co/transformers/model_doc/auto.html#transformers.AutoModelWithLMHead
+ """

But that might be a bit too much. Maybe adding a README would be simpler.

Collaborator Author:

Not AutoModelWithLMHead, just AutoModelForCausalLM, but I can add that.

Member:

Maybe the link is overkill; I just had an issue with (GPT, GPT-2, CTRL), which seems to imply that only those three models are supported.

Member:

This works too, but it shows checkpoints, whereas this script can also train from scratch, so showing architectures would probably be better.

Collaborator Author:

No, this link shows all kinds of LMs. The script will only work with a model that can be loaded with AutoModelForCausalLM (since it uses that class).
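
For illustration, the two loading paths (gpt2 is just a placeholder name):

from transformers import AutoConfig, AutoModelForCausalLM

# Fine-tuning: load pretrained weights from a checkpoint.
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Training from scratch: build the architecture from a config, with freshly
# initialized weights.
config = AutoConfig.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_config(config)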

Member:

(the other one, AutoModelWithLMHead, is the deprecated one; will remove soon)

Collaborator Author:

Great!


Comment on lines +157 to +164
logger.warning(
    f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
    + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
# Set the verbosity to info of the Transformers logger (on main process only):
if is_main_process(training_args.local_rank):
    transformers.utils.logging.set_verbosity_info()
logger.info("Training/evaluation parameters %s", training_args)

Member:

That's exactly what I'm talking about :)

if data_args.block_size <= 0:
    block_size = tokenizer.max_len
else:
    block_size = min(data_args.block_size, tokenizer.max_len)

Member:

Should we print a warning here to tell the user their block_size isn't going to be used if it's larger than the tokenizer's max length?

Collaborator Author:

I can add that.
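
Something along these lines, perhaps (a sketch; the warning text is mine):

if data_args.block_size <= 0:
    block_size = tokenizer.max_len
else:
    if data_args.block_size > tokenizer.max_len:
        logger.warning(
            f"The block_size passed ({data_args.block_size}) is larger than the maximum length "
            f"supported by the model ({tokenizer.max_len}). Using block_size={tokenizer.max_len}."
        )
    block_size = min(data_args.block_size, tokenizer.max_len)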

@LysandreJik (Member) left a comment:

LGTM

@sgugger sgugger merged commit 47dfa65 into master Oct 28, 2020
@sgugger sgugger deleted the run_clm_script branch October 28, 2020 14:39