Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
4e2cc58
Drop training-branching hack for SFT segfault
yafshar Jul 23, 2025
f447155
Add recomp flag for lazy and torch.compile modes
yafshar Jul 23, 2025
8045f32
Merge branch 'main' into mixtral/remove-sft-segfault-hack
yafshar Jul 30, 2025
95f366c
Revert "Add recomp flag for lazy and torch.compile modes"
yafshar Jul 30, 2025
7fb9301
feat: add helper to register ZeRO-3 leaf modules
yafshar Jul 30, 2025
42c4093
feat: adopt apply_zero3_leaf_promotion utility
yafshar Jul 30, 2025
510fa78
feat(config): add DeepSpeed ZeRO-3 config
yafshar Jul 30, 2025
bfee838
Update the README
yafshar Jul 30, 2025
1b37314
Merge branch 'main' into mixtral/remove-sft-segfault-hack
yafshar Jul 30, 2025
d0ac1f5
Minor fix, update the zero config
yafshar Jul 30, 2025
016e200
Merge branch 'main' into mixtral/remove-sft-segfault-hack
yafshar Aug 4, 2025
b430a60
Merge branch 'main' into mixtral/remove-sft-segfault-hack
yafshar Aug 8, 2025
6649ae3
Adding a regression test
yafshar Aug 12, 2025
29cc3dd
Merge branch 'main' into mixtral/remove-sft-segfault-hack
yafshar Aug 12, 2025
96110e8
Fix the env variable for sft-trl-mixtral
yafshar Aug 13, 2025
743e60b
Update reference for G2
yafshar Aug 13, 2025
739853e
Merge branch 'main' into mixtral/remove-sft-segfault-hack
yafshar Aug 14, 2025
e699822
Rename ZeRO-3 availability flag
yafshar Aug 19, 2025
48f9c0c
Replace dynamic import with explicit class imports for clarity
yafshar Aug 19, 2025
4ad5ae6
Avoid adding empty gaudi_config_name to cmd args
yafshar Aug 19, 2025
6b15860
Move ZeRO-3 leaf promotion check to caller
yafshar Aug 19, 2025
1771837
Moved mixtral_ds_zero3_config.json to language-modeling folder
yafshar Aug 20, 2025
3ed77d5
Correct the config path
yafshar Aug 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions examples/language-modeling/mixtral_ds_zero3_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{
"steps_per_print": 64,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"bf16": {
"enabled": true
},
"gradient_clipping": "auto",
"zero_optimization": {
"stage": 3,
"overlap_comm": true,
"contiguous_gradients": true,
"stage3_gather_16bit_weights_on_model_save": true
},
"timers": {
"throughput": {
"enabled": true,
"synchronized": false
}
},
"wall_clock_breakdown": false
}
5 changes: 3 additions & 2 deletions examples/trl/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ PT_HPU_MAX_COMPOUND_OP_SIZE=10 PT_HPU_LAZY_MODE=1 python3 ../gaudi_spawn.py --wo
--dataset_name "philschmid/dolly-15k-oai-style" \
--subset 'data/' \
--streaming False \
--deepspeed ../language-modeling/llama2_ds_zero3_config.json \
--deepspeed ../language-modeling/mixtral_ds_zero3_config.json \
--output_dir="./model_mixtral" \
--do_train \
--max_steps=500 \
Expand All @@ -137,7 +137,8 @@ PT_HPU_MAX_COMPOUND_OP_SIZE=10 PT_HPU_LAZY_MODE=1 python3 ../gaudi_spawn.py --wo
--run_name="sft_mixtral" \
--report_to=none \
--use_habana \
--use_lazy_mode
--use_lazy_mode \
--use_zero3_leaf_promotion
```

## DPO pipeline
Expand Down
6 changes: 6 additions & 0 deletions examples/trl/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)

from optimum.habana import GaudiConfig
from optimum.habana.distributed import apply_zero3_leaf_promotion
from optimum.habana.trl import GaudiSFTConfig, GaudiSFTTrainer
from optimum.habana.utils import set_seed

Expand Down Expand Up @@ -142,11 +143,13 @@ def create_datasets(tokenizer, args, seed=None):
return train_data, valid_data, formating_func

low_cpu_mem_usage = True
is_zero3_enabled = False
if is_deepspeed_available():
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

if is_deepspeed_zero3_enabled():
low_cpu_mem_usage = False
is_zero3_enabled = True

base_model = AutoModelForCausalLM.from_pretrained(
script_args.model_name_or_path,
Expand All @@ -164,6 +167,9 @@ def create_datasets(tokenizer, args, seed=None):
base_model.generation_config.flash_attention_recompute = script_args.flash_attention_recompute
base_model.generation_config.flash_attention_causal_mask = script_args.flash_attention_causal_mask

if is_zero3_enabled and training_args.use_zero3_leaf_promotion:
apply_zero3_leaf_promotion(base_model)

tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training
Expand Down
1 change: 1 addition & 0 deletions optimum/habana/distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .distributed_runner import DistributedRunner
from .fast_ddp import all_reduce_gradients
from .zero3_utils import apply_zero3_leaf_promotion


def rank_and_world(group=None):
Expand Down
58 changes: 58 additions & 0 deletions optimum/habana/distributed/zero3_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# coding=utf-8
# Copyright 2025 the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import torch


logger = logging.getLogger(__name__)


def apply_zero3_leaf_promotion(model: torch.nn.Module) -> None:
    """
    Register the model-specific block class as a DeepSpeed ZeRO-3 leaf module.

    The model type is read from `model.config.model_type`; when `model` has no
    `config` attribute, `model_type` is looked up on `model` itself. For the
    registered types ("llama", "mixtral", "qwen3_moe") the matching Gaudi class
    is handed to `set_z3_leaf_modules`. Any other model type is left untouched
    and only a debug message is logged.

    Parameters
    ----------
    model : torch.nn.Module
        The model (or PEFT wrapper) on which `set_z3_leaf_modules` will be called.
    """
    # Imported lazily: the caller is expected to have verified that DeepSpeed is installed.
    from deepspeed.utils import set_z3_leaf_modules

    model_type = getattr(getattr(model, "config", model), "model_type", None)

    # Pick the class to promote; each branch imports only the modeling file it needs.
    if model_type == "llama":
        from optimum.habana.transformers.models.llama.modeling_llama import GaudiLlamaDecoderLayer as leaf_class
    elif model_type == "mixtral":
        from optimum.habana.transformers.models.mixtral.modeling_mixtral import GaudiMixtralSparseMoeBlock as leaf_class
    elif model_type == "qwen3_moe":
        from optimum.habana.transformers.models.qwen3_moe.modeling_qwen3_moe import (
            GaudiQwen3MoeSparseMoeBlock as leaf_class,
        )
    else:
        leaf_class = None

    # Nothing registered for this model type: report it and leave the model as is.
    if leaf_class is None:
        logger.debug(f"Model type '{model_type}' is not registered for ZeRO-3 leaf promotion.")
        return

    set_z3_leaf_modules(model, [leaf_class])
    logger.debug(f"Model type '{model_type}' is registered for ZeRO-3 leaf promotion.")
77 changes: 12 additions & 65 deletions optimum/habana/transformers/models/mixtral/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,87 +229,34 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens

routing_weights, selected_experts = calculate_routing_tensors(router_logits, self.top_k, hidden_states.dtype)

# TODO
# This is a hack solution to avoid segmentation fault during SFT training.
# Remove this section after the issue is fixed.
if self.training:
final_hidden_states = self.call_sparse_moe_op(
shape=original_shape,
hidden_states=hidden_states,
expert_routing_table=selected_experts,
router_weights=routing_weights,
)
else:
final_hidden_states = self.call_dynamic_moe_op(
hidden_states=hidden_states,
expert_routing_table=selected_experts,
router_weights=routing_weights,
)

if self.ep_size > 1:
final_hidden_states = _all_reduce(final_hidden_states)
elif deepspeed_available and (not self.training):
from deepspeed import comm

if comm.is_initialized():
comm.all_reduce(final_hidden_states)

return final_hidden_states.view(original_shape), router_logits

def call_dynamic_moe_op(
self,
hidden_states,
expert_routing_table,
router_weights,
):
# pre-processing for custom op inputs
w1_list = [self.experts[i].w1.weight for i in self.experts_range]
w2_list = [self.experts[i].w2.weight for i in self.experts_range]
w3_list = [self.experts[i].w3.weight for i in self.experts_range]

return torch.ops.hpu.mixture_of_experts(
final_hidden_states = torch.ops.hpu.mixture_of_experts(
hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
expert_routing_table=selected_experts,
router_weights=routing_weights,
w1=w1_list,
w2=w3_list, # Note that there is a different naming convention of w1, w2, and w3 between optimum habana's mixtral model and dynamic MoE kernel.
w3=w2_list,
w2=w3_list,
permuted_weights=True,
activation="silu",
experts_min=self.experts_min,
experts_max=self.experts_max,
)

def call_sparse_moe_op(
self,
shape,
hidden_states,
expert_routing_table,
router_weights,
):
dtype = hidden_states.dtype
device = hidden_states.device

padded_weights = torch.zeros((hidden_states.shape[0], self.num_experts), dtype=dtype, device=device)
padded_weights.scatter_(-1, expert_routing_table, router_weights)
padded_weights = padded_weights.view(shape[0], shape[1], self.num_experts).permute(2, 0, 1).unsqueeze(-1)

current_state_static = hidden_states

final_hidden_states = torch.zeros(shape, dtype=dtype, device=device)

# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
padded_weight = padded_weights[expert_idx]
current_hidden_states_static = expert_layer(current_state_static).view(shape) * padded_weight
final_hidden_states += current_hidden_states_static
if not self.training:
if self.ep_size > 1:
final_hidden_states = _all_reduce(final_hidden_states)
elif deepspeed_available:
from deepspeed import comm

# Support long sequences exceeding 8192
if not self.training and shape[1] > 8192:
htcore.mark_step()
if comm.is_initialized():
comm.all_reduce(final_hidden_states)

return final_hidden_states
return final_hidden_states.view(original_shape), router_logits


class GaudiMixtralAttentionLongSequence:
Expand Down
10 changes: 10 additions & 0 deletions tests/baselines/fixture/tests/test_examples.json
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,16 @@
"train_samples_per_second": 13.065
}
},
"tests/test_examples.py::DeepspeedSFTMixtralExampleTester::test_sft_Mixtral-8x7B-Instruct-v0.1_deepspeed": {
"gaudi2": {
"train_runtime": 286.0,
"train_samples_per_second": 12.0
},
"gaudi3": {
"train_runtime": 270.0,
"train_samples_per_second": 13.1
}
},
"tests/test_examples.py::DeepspeedSummarizationExampleTester::test_run_summarization_flan-t5-xxl_deepspeed": {
"gaudi2": {
"eval_rougeLsum": 27.9095,
Expand Down
70 changes: 70 additions & 0 deletions tests/configs/examples/Mixtral_8x7B_Instruct_v0_1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
{
"gaudi2": {
"trl-sft-mixtral": {
"num_train_epochs": 1,
"eval_batch_size": 1,
"distribution": {
"deepspeed": {
"learning_rate": 0.0001,
"train_batch_size": 2,
"metrics": [
"train_runtime",
"train_samples_per_second"
],
"extra_arguments": [
"--bf16 True",
"--subset data/",
"--streaming False",
"--gradient_accumulation_steps 2",
"--warmup_steps 100",
"--lr_scheduler_type cosine",
"--logging_steps 10",
"--lora_target_modules q_proj v_proj",
"--max_seq_length 512",
"--weight_decay 0.05",
"--report_to none",
"--max_steps 100",
"--optim paged_adamw_32bit",
"--remove_unused_columns False",
"--use_zero3_leaf_promotion True",
"--deepspeed examples/language-modeling/mixtral_ds_zero3_config.json"
]
}
}
}
},
"gaudi3": {
"trl-sft-mixtral": {
"num_train_epochs": 1,
"eval_batch_size": 1,
"distribution": {
"deepspeed": {
"learning_rate": 0.0001,
"train_batch_size": 2,
"metrics": [
"train_runtime",
"train_samples_per_second"
],
"extra_arguments": [
"--bf16 True",
"--subset data/",
"--streaming False",
"--gradient_accumulation_steps 2",
"--warmup_steps 100",
"--lr_scheduler_type cosine",
"--logging_steps 10",
"--lora_target_modules q_proj v_proj",
"--max_seq_length 512",
"--weight_decay 0.05",
"--report_to none",
"--max_steps 100",
"--optim paged_adamw_32bit",
"--remove_unused_columns False",
"--use_zero3_leaf_promotion True",
"--deepspeed examples/language-modeling/mixtral_ds_zero3_config.json"
]
}
}
}
}
}
Loading
Loading