diff --git a/examples/language-modeling/run_mlm.py b/examples/language-modeling/run_mlm.py
index b2e722609f..9bbf3fdcb0 100644
--- a/examples/language-modeling/run_mlm.py
+++ b/examples/language-modeling/run_mlm.py
@@ -31,6 +31,7 @@
 import datasets
 import evaluate
+import torch
 import transformers
 from datasets import load_dataset
 from transformers import (
@@ -141,6 +142,16 @@ class ModelArguments:
             )
         },
     )
+    torch_dtype: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
+                "dtype will be automatically derived from the model's weights."
+            ),
+            "choices": ["auto", "bfloat16", "float32"],
+        },
+    )
     low_cpu_mem_usage: bool = field(
         default=False,
         metadata={
@@ -441,6 +452,11 @@ def main():
         )
 
     if model_args.model_name_or_path:
+        torch_dtype = (
+            model_args.torch_dtype
+            if model_args.torch_dtype in ["auto", None]
+            else getattr(torch, model_args.torch_dtype)
+        )
         model = AutoModelForMaskedLM.from_pretrained(
             model_args.model_name_or_path,
             from_tf=bool(".ckpt" in model_args.model_name_or_path),
@@ -449,6 +465,7 @@
             revision=model_args.model_revision,
             token=model_args.token,
             trust_remote_code=model_args.trust_remote_code,
+            torch_dtype=torch_dtype,
             low_cpu_mem_usage=model_args.low_cpu_mem_usage,
         )
     else:
diff --git a/tests/example_diff/run_mlm.txt b/tests/example_diff/run_mlm.txt
index 3024e5d22c..d93476af72 100644
--- a/tests/example_diff/run_mlm.txt
+++ b/tests/example_diff/run_mlm.txt
@@ -5,23 +5,23 @@
 ---
 > Training the library models for masked language modeling (BERT, ALBERT, RoBERTa...) on a text file or a dataset.
 > Here is the full list of checkpoints on the hub that can be trained by this script:
-35,36d33
+36,37d34
 < from datasets import load_dataset
 <
-37a35
+38a36
 > from datasets import load_dataset
-46,49d43
+47,50d44
 <     Trainer,
 <     TrainingArguments,
 <     is_torch_xla_available,
 <     set_seed,
-54a49,50
+55a50,51
 > from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments
 > from optimum.habana.utils import set_seed
-56,57d51
+57,58d52
 < # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
 < check_min_version("4.40.0.dev0")
-59c53,59
+60c54,60
 < require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
 ---
 > try:
@@ -31,7 +31,7 @@
 >     def check_optimum_habana_min_version(*a, **b):
 >         return ()
 >
-61a62,69
+62a63,70
 >
 > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks.
 > check_min_version("4.38.0")
@@ -40,19 +40,23 @@
 > require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
 >
 >
-141c149
+144c152
+<             "choices": ["auto", "bfloat16", "float16", "float32"],
+---
+>             "choices": ["auto", "bfloat16", "float32"],
+152c160
 <                 "set True will benefit LLM loading time and RAM consumption."
 ---
 >                 "Setting it to True will benefit LLM loading time and RAM consumption."
-226c234
+237c245
 <     streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
 ---
 >     streaming: bool = field(default=False, metadata={"help": "Enable streaming mode."})
-250c258
+261c269
 <     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
 ---
 >     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, GaudiTrainingArguments))
-288a297,303
+299a308,314
 >     gaudi_config = GaudiConfig.from_pretrained(
 >         training_args.gaudi_config_name,
 >         cache_dir=model_args.cache_dir,
@@ -60,24 +64,24 @@
 >         use_auth_token=True if model_args.use_auth_token else None,
 >     )
 >
-289a305
+300a316
 >     mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast
-291,292c307,309
+302,303c318,320
 <         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
 <         + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
 ---
 >         f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
 >         + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
 >         + f"mixed-precision training: {mixed_precision}"
-294d310
+305d321
 <     # Set the verbosity to info of the Transformers logger (on main process only):
-616c632
+633c649
 <     trainer = Trainer(
 ---
 >     trainer = GaudiTrainer(
-617a634
+634a651
 >         gaudi_config=gaudi_config,
-623,626c640,641
+640,643c657,658
 <         compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
 <         preprocess_logits_for_metrics=preprocess_logits_for_metrics
 <         if training_args.do_eval and not is_torch_xla_available()
@@ -85,7 +89,7 @@
 ---
 >         compute_metrics=compute_metrics if training_args.do_eval else None,
 >         preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None,
-640,643c655,661
+657,660c672,678
 <         max_train_samples = (
 <             data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
 <         )
@@ -98,9 +102,9 @@
 >             data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
 >         )
 >         metrics["train_samples"] = min(max_train_samples, len(train_dataset))
-652d669
+669d686
 <
-655,656c672,677
+672,673c689,694
 <         max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
 <         metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
 ---
@@ -110,7 +114,7 @@
 >         )
 >         metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
 >
-679,683d699
+696,700d716
 <
 <
 < def _mp_fn(index):
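For reference, the `torch_dtype` plumbing added to run_mlm.py above can be exercised on its own. The sketch below is illustrative only and not part of the diff: `resolve_torch_dtype` is a hypothetical helper that mirrors the inline expression from the patch, and "bert-base-uncased" is just a placeholder checkpoint.

import torch
from transformers import AutoModelForMaskedLM

def resolve_torch_dtype(torch_dtype_arg):
    # None keeps the checkpoint's default dtype, and "auto" lets transformers
    # derive it from the weights; any other allowed choice (e.g. "bfloat16")
    # is looked up as an attribute of torch, matching the patched script.
    if torch_dtype_arg in ["auto", None]:
        return torch_dtype_arg
    return getattr(torch, torch_dtype_arg)  # e.g. torch.bfloat16

model = AutoModelForMaskedLM.from_pretrained(
    "bert-base-uncased",
    torch_dtype=resolve_torch_dtype("bfloat16"),
)

The equivalent invocation of the patched script passes the flag directly, e.g. `python run_mlm.py --model_name_or_path bert-base-uncased --torch_dtype bfloat16 ...` (other required arguments elided). Note that, per the updated run_mlm.txt fixture, this version accepts only "auto", "bfloat16", and "float32", dropping upstream's "float16" choice.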