Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions examples/language-modeling/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -535,12 +535,12 @@ LOWER_LIST=ops_bf16.txt python3 ../gaudi_spawn.py \
--validation_split_percentage 6
```

- Multi-card finetuning of Llama2-70B with DeepSpeed ZeRO-3 optimization and LoRA:
- Multi-card finetuning of Llama2-70B with DeepSpeed ZeRO-3 optimization, LoRA and FP8 precision:
Comment thread
vivekgoe marked this conversation as resolved.

> The following command requires Habana DeepSpeed 1.13.0 or later.

```bash
PT_HPU_MAX_COMPOUND_OP_SIZE=10 DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED=1 \
PT_HPU_MAX_COMPOUND_OP_SIZE=10 \
Comment thread
libinta marked this conversation as resolved.
python3 ../gaudi_spawn.py --use_deepspeed --world_size 8 run_lora_clm.py \
--model_name_or_path meta-llama/Llama-2-70b-hf \
--deepspeed llama2_ds_zero3_config.json \
Expand All @@ -550,7 +550,7 @@ python3 ../gaudi_spawn.py --use_deepspeed --world_size 8 run_lora_clm.py \
--num_train_epochs 2 \
--max_seq_len 2048 \
--per_device_train_batch_size 10 \
--per_device_eval_batch_size 10 \
--per_device_eval_batch_size 1 \
--gradient_checkpointing \
--evaluation_strategy epoch \
--eval_delay 2 \
Expand All @@ -571,7 +571,8 @@ python3 ../gaudi_spawn.py --use_deepspeed --world_size 8 run_lora_clm.py \
--lora_target_modules "q_proj" "v_proj" "k_proj" "o_proj" \
--validation_split_percentage 4 \
--use_flash_attention True \
--flash_attention_causal_mask True
--flash_attention_causal_mask True \
--fp8 True
```

- Multi-card finetuning of Llama2-70B with FSDP and LoRA:
Expand Down
55 changes: 16 additions & 39 deletions optimum/habana/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
DeepSpeedPlugin,
DistributedDataParallelKwargs,
DistributedType,
FP8RecipeKwargs,
GradientAccumulationPlugin,
GradScalerKwargs,
InitProcessGroupKwargs,
Expand Down Expand Up @@ -73,12 +72,11 @@
from .utils import (
GaudiDistributedType,
GaudiDynamoBackend,
GaudiFP8RecipeKwargs,
GaudiFullyShardedDataParallelPlugin,
GaudiTorchDynamoPlugin,
te_forward_convert,
te_setup_fp8_recipe_handler,
te_wrap_fp8,
te_wrap_fp8_forward_convert,
convert_model,
get_fp8_recipe,
)


Expand Down Expand Up @@ -113,7 +111,6 @@ def __init__(
dynamo_backend: GaudiDynamoBackend | str | None = None,
distribution_strategy: str = None,
force_autocast: bool = False,
fp8_recipe_format: str = None,
):
self.trackers = []
if project_config is not None:
Expand Down Expand Up @@ -181,7 +178,6 @@ def __init__(
self.scaler_handler = None
self.init_handler = None
self.fp8_recipe_handler = None
self.fp8_recipe_format = None
self.autocast_handler = None
if kwargs_handlers is not None:
for handler in kwargs_handlers:
Expand All @@ -203,9 +199,9 @@ def __init__(
raise ValueError("You can only pass one `InitProcessGroupKwargs` in `kwargs_handler`.")
else:
self.init_handler = handler
elif isinstance(handler, FP8RecipeKwargs):
elif isinstance(handler, GaudiFP8RecipeKwargs):
if self.fp8_recipe_handler is not None:
raise ValueError("You can only pass one `FP8RecipeKwargs` in `kwargs_handler`.")
raise ValueError("You can only pass one `GaudiFP8RecipeKwargs` in `kwargs_handler`.")
else:
self.fp8_recipe_handler = handler
elif isinstance(handler, AutocastKwargs):
Expand All @@ -225,8 +221,14 @@ def __init__(
_from_accelerator=True,
**kwargs,
)
if self.fp8_recipe_handler is None and self.state.is_fp8_enabled:
self.fp8_recipe_handler = te_setup_fp8_recipe_handler(self.fp8_recipe_format)

if self.state.is_fp8_enabled:
if self.fp8_recipe_handler is None:
self.fp8_recipe_handler = GaudiFP8RecipeKwargs()
# Handling FP8 recipe creation in init since both `prepare_model` and `_prepare_deepspeed` require it.
# (Base accelerator handles this in `prepare_model` function)
self.fp8_recipe_handler = get_fp8_recipe(self.fp8_recipe_handler)

trackers = filter_trackers(log_with, self.logging_dir)
if len(trackers) < 1 and log_with is not None:
warnings.warn(f"`log_with={log_with}` was passed but no supported trackers are currently installed.")
Expand Down Expand Up @@ -349,31 +351,8 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model)
else:
model.forward = convert_outputs_to_fp32(new_forward)
elif self.state.is_fp8_enabled:
model = te_wrap_fp8_forward_convert(model, self.fp8_recipe_handler)
# FP8 is not supported on Gaudi2 yet
# elif self.mixed_precision == "fp8":
# if not has_transformer_engine_layers(model):
# with torch.no_grad():
# convert_model(model)
# model._converted_to_transformer_engine = True
# model._original_forward = model.forward

# kwargs = self.fp8_recipe_handler.to_kwargs() if self.fp8_recipe_handler is not None else {}
# if "fp8_format" in kwargs:
# kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"])
# fp8_recipe = te_recipe.DelayedScaling(**kwargs)
# cuda_device_capacity = torch.cuda.get_device_capability()
# fp8_enabled = cuda_device_capacity[0] >= 9 or (
# cuda_device_capacity[0] == 8 and cuda_device_capacity[1] >= 9
# )
# if not fp8_enabled:
# logger.warn(
# f"The current device has compute capability of {cuda_device_capacity} which is "
# "insufficient for FP8 mixed precision training (requires a GPU Hopper/Ada Lovelace "
# "or higher, compute capability of 8.9 or higher). Will use FP16 instead."
# )
# model.forward = fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe)(model.forward)
if self.state.is_fp8_enabled:
model = convert_model(model)

if (getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)) and getattr(
model, "hf_device_map", False
Expand Down Expand Up @@ -469,7 +448,7 @@ def _prepare_deepspeed(self, *args):
result = [
self._prepare_one(obj, first_pass=True)
if isinstance(obj, torch.utils.data.DataLoader)
else te_wrap_fp8(obj)
else convert_model(obj)
if isinstance(obj, torch.nn.Module) and self.state.is_fp8_enabled
else obj
for obj in args
Expand Down Expand Up @@ -685,8 +664,6 @@ def _prepare_deepspeed(self, *args):
result[i] = scheduler
# pointing for deepspeed_engine_wrapped.backward()
self.deepspeed_engine_wrapped = DeepSpeedEngineWrapper(engine)
if self.state.is_fp8_enabled:
model = te_forward_convert(engine, self.fp8_recipe_handler)
self._models.append(engine)
if optimizer is not None:
self._optimizers.append(optimizer)
Expand Down
8 changes: 4 additions & 4 deletions optimum/habana/accelerate/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from .dataclasses import (
GaudiDistributedType,
GaudiDynamoBackend,
GaudiFP8RecipeKwargs,
GaudiFullyShardedDataParallelPlugin,
GaudiTorchDynamoPlugin,
)
from .transformer_engine import (
te_forward_convert,
te_setup_fp8_recipe_handler,
te_wrap_fp8,
te_wrap_fp8_forward_convert,
FP8ContextWrapper,
convert_model,
get_fp8_recipe,
)
46 changes: 45 additions & 1 deletion optimum/habana/accelerate/utils/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import torch
from accelerate.utils import FullyShardedDataParallelPlugin
from accelerate.utils.constants import FSDP_BACKWARD_PREFETCH
from accelerate.utils.dataclasses import BaseEnum, TorchDynamoPlugin
from accelerate.utils.dataclasses import BaseEnum, KwargsHandler, TorchDynamoPlugin
from accelerate.utils.environment import str_to_bool


Expand Down Expand Up @@ -144,3 +144,47 @@ def __post_init__(self):
if self.sync_module_states:
device = torch.device("hpu")
self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)


@dataclass
class GaudiFP8RecipeKwargs(KwargsHandler):
    """
    Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision
    training with `transformer-engine`.

    Adapted from: https://github.com/huggingface/accelerate/blob/v0.27.2/src/accelerate/utils/dataclasses.py#L180

    Args:
        margin (`int`, *optional*, defaults to 0):
            The margin to use for the scaling factor computation.
        interval (`int`, *optional*, defaults to 16):
            The interval to use for how often the scaling factor is recomputed.
        fp8_format (`str`, *optional*, defaults to "HYBRID"):
            The format to use for the FP8 recipe. Must be one of `E5M2` or `HYBRID` (case-insensitive; the value is
            upper-cased on initialization).
        amax_compute_algo (`str`, *optional*, defaults to "most_recent"):
            The algorithm to use for the scaling factor computation. Must be one of `max` or `most_recent`.
        amax_history_len (`int`, *optional*, defaults to 1):
            The length of the history to use for the scaling factor computation.
        reduce_amax (`bool`, *optional*, defaults to `False`):
            If set to `True` and `torch.distributed` is initialized, the `amax` value for FP8 tensors is reduced
            across the `fp8_group` (specified in the `fp8_autocast` call). This keeps the amaxes and scaling factors
            synced across the given distributed group. If set to `False` (the default), this reduction is skipped
            and every HPU maintains local amaxes and scaling factors. To ensure results are numerically identical
            across checkpointing boundaries in this case, all ranks must checkpoint in order to store the local
            tensors.

    Raises:
        ValueError: If `fp8_format` or `amax_compute_algo` is not one of the supported values.
    """

    margin: int = 0
    interval: int = 16
    fp8_format: str = "HYBRID"
    amax_compute_algo: str = "most_recent"
    amax_history_len: int = 1
    reduce_amax: bool = False

    def __post_init__(self):
        # Normalize the case first so that e.g. "hybrid" is accepted and stored as "HYBRID".
        self.fp8_format = self.fp8_format.upper()
        # Explicit raises instead of `assert`: assertions are stripped when Python runs with `-O`,
        # which would let an unsupported configuration through silently.
        if self.fp8_format not in ("E5M2", "HYBRID"):
            raise ValueError("Only E5M2 and HYBRID FP8 formats are currently supported.")
        if self.amax_compute_algo not in ("max", "most_recent"):
            raise ValueError("Only max and most_recent `amax_compute_algo` modes are currently supported.")
Loading