Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions examples/language-modeling/run_mlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

import datasets
import evaluate
import torch
import transformers
from datasets import load_dataset
from transformers import (
Expand Down Expand Up @@ -141,6 +142,16 @@ class ModelArguments:
)
},
)
torch_dtype: Optional[str] = field(
default=None,
metadata={
"help": (
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
"dtype will be automatically derived from the model's weights."
),
"choices": ["auto", "bfloat16", "float32"],
},
)
low_cpu_mem_usage: bool = field(
default=False,
metadata={
Expand Down Expand Up @@ -441,6 +452,11 @@ def main():
)

if model_args.model_name_or_path:
torch_dtype = (
model_args.torch_dtype
if model_args.torch_dtype in ["auto", None]
else getattr(torch, model_args.torch_dtype)
)
model = AutoModelForMaskedLM.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
Expand All @@ -449,6 +465,7 @@ def main():
revision=model_args.model_revision,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
torch_dtype=torch_dtype,
low_cpu_mem_usage=model_args.low_cpu_mem_usage,
)
else:
Expand Down
46 changes: 25 additions & 21 deletions tests/example_diff/run_mlm.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,23 @@
---
> Training the library models for masked language modeling (BERT, ALBERT, RoBERTa...) on a text file or a dataset.
> Here is the full list of checkpoints on the hub that can be trained by this script:
35,36d33
36,37d34
< from datasets import load_dataset
<
37a35
38a36
> from datasets import load_dataset
46,49d43
47,50d44
< Trainer,
< TrainingArguments,
< is_torch_xla_available,
< set_seed,
54a49,50
55a50,51
> from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments
> from optimum.habana.utils import set_seed
56,57d51
57,58d52
< # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
< check_min_version("4.40.0.dev0")
59c53,59
60c54,60
< require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
---
> try:
Expand All @@ -31,7 +31,7 @@
> def check_optimum_habana_min_version(*a, **b):
> return ()
>
61a62,69
62a63,70
>
> # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks.
> check_min_version("4.38.0")
Expand All @@ -40,52 +40,56 @@
> require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
>
>
141c149
144c152
< "choices": ["auto", "bfloat16", "float16", "float32"],
---
> "choices": ["auto", "bfloat16", "float32"],
152c160
< "set True will benefit LLM loading time and RAM consumption."
---
> "Setting it to True will benefit LLM loading time and RAM consumption."
226c234
237c245
< streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
---
> streaming: bool = field(default=False, metadata={"help": "Enable streaming mode."})
250c258
261c269
< parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
---
> parser = HfArgumentParser((ModelArguments, DataTrainingArguments, GaudiTrainingArguments))
288a297,303
299a308,314
> gaudi_config = GaudiConfig.from_pretrained(
> training_args.gaudi_config_name,
> cache_dir=model_args.cache_dir,
> revision=model_args.model_revision,
> use_auth_token=True if model_args.use_auth_token else None,
> )
>
289a305
300a316
> mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast
291,292c307,309
302,303c318,320
< f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
< + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
---
> f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
> + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
> + f"mixed-precision training: {mixed_precision}"
294d310
305d321
< # Set the verbosity to info of the Transformers logger (on main process only):
616c632
633c649
< trainer = Trainer(
---
> trainer = GaudiTrainer(
617a634
634a651
> gaudi_config=gaudi_config,
623,626c640,641
640,643c657,658
< compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
< preprocess_logits_for_metrics=preprocess_logits_for_metrics
< if training_args.do_eval and not is_torch_xla_available()
< else None,
---
> compute_metrics=compute_metrics if training_args.do_eval else None,
> preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None,
640,643c655,661
657,660c672,678
< max_train_samples = (
< data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
< )
Expand All @@ -98,9 +102,9 @@
> data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
> )
> metrics["train_samples"] = min(max_train_samples, len(train_dataset))
652d669
669d686
<
655,656c672,677
672,673c689,694
< max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
< metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
---
Expand All @@ -110,7 +114,7 @@
> )
> metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
>
679,683d699
696,700d716
<
<
< def _mp_fn(index):
Expand Down