diff --git a/examples/language-modeling/mixtral_ds_zero3_config.json b/examples/language-modeling/mixtral_ds_zero3_config.json new file mode 100755 index 0000000000..277dff686d --- /dev/null +++ b/examples/language-modeling/mixtral_ds_zero3_config.json @@ -0,0 +1,23 @@ +{ + "steps_per_print": 64, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "gradient_accumulation_steps": "auto", + "bf16": { + "enabled": true + }, + "gradient_clipping": "auto", + "zero_optimization": { + "stage": 3, + "overlap_comm": true, + "contiguous_gradients": true, + "stage3_gather_16bit_weights_on_model_save": true + }, + "timers": { + "throughput": { + "enabled": true, + "synchronized": false + } + }, + "wall_clock_breakdown": false +} diff --git a/examples/trl/README.md b/examples/trl/README.md index 679c88fcde..b3cd08b70b 100644 --- a/examples/trl/README.md +++ b/examples/trl/README.md @@ -116,7 +116,7 @@ PT_HPU_MAX_COMPOUND_OP_SIZE=10 PT_HPU_LAZY_MODE=1 python3 ../gaudi_spawn.py --wo --dataset_name "philschmid/dolly-15k-oai-style" \ --subset 'data/' \ --streaming False \ - --deepspeed ../language-modeling/llama2_ds_zero3_config.json \ + --deepspeed ../language-modeling/mixtral_ds_zero3_config.json \ --output_dir="./model_mixtral" \ --do_train \ --max_steps=500 \ @@ -137,7 +137,8 @@ PT_HPU_MAX_COMPOUND_OP_SIZE=10 PT_HPU_LAZY_MODE=1 python3 ../gaudi_spawn.py --wo --run_name="sft_mixtral" \ --report_to=none \ --use_habana \ - --use_lazy_mode + --use_lazy_mode \ + --use_zero3_leaf_promotion ``` ## DPO pipeline diff --git a/examples/trl/sft.py b/examples/trl/sft.py index b0b91e73a3..6ab2864e07 100644 --- a/examples/trl/sft.py +++ b/examples/trl/sft.py @@ -16,6 +16,7 @@ ) from optimum.habana import GaudiConfig +from optimum.habana.distributed import apply_zero3_leaf_promotion from optimum.habana.trl import GaudiSFTConfig, GaudiSFTTrainer from optimum.habana.utils import set_seed @@ -142,11 +143,13 @@ def create_datasets(tokenizer, args, seed=None): return 
train_data, valid_data, formating_func low_cpu_mem_usage = True + is_zero3_enabled = False if is_deepspeed_available(): from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled if is_deepspeed_zero3_enabled(): low_cpu_mem_usage = False + is_zero3_enabled = True base_model = AutoModelForCausalLM.from_pretrained( script_args.model_name_or_path, @@ -164,6 +167,9 @@ def create_datasets(tokenizer, args, seed=None): base_model.generation_config.flash_attention_recompute = script_args.flash_attention_recompute base_model.generation_config.flash_attention_causal_mask = script_args.flash_attention_causal_mask + if is_zero3_enabled and training_args.use_zero3_leaf_promotion: + apply_zero3_leaf_promotion(base_model) + tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path, trust_remote_code=True) tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training diff --git a/optimum/habana/distributed/__init__.py b/optimum/habana/distributed/__init__.py index af269ee68c..fa0b0b8d3d 100644 --- a/optimum/habana/distributed/__init__.py +++ b/optimum/habana/distributed/__init__.py @@ -4,6 +4,7 @@ from .distributed_runner import DistributedRunner from .fast_ddp import all_reduce_gradients +from .zero3_utils import apply_zero3_leaf_promotion def rank_and_world(group=None): diff --git a/optimum/habana/distributed/zero3_utils.py b/optimum/habana/distributed/zero3_utils.py new file mode 100644 index 0000000000..e05e8b5a91 --- /dev/null +++ b/optimum/habana/distributed/zero3_utils.py @@ -0,0 +1,58 @@ +# coding=utf-8 +# Copyright 2025 the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import logging

import torch


logger = logging.getLogger(__name__)

# Registry of model types that support ZeRO-3 leaf promotion.
# Maps `config.model_type` to the (module path, class name) of the Gaudi module
# class that is handed to DeepSpeed's `set_z3_leaf_modules`. The classes are
# referenced by name so that the modeling files are only imported on demand,
# exactly when the corresponding model type is encountered.
_Z3_LEAF_MODULES = {
    "llama": (
        "optimum.habana.transformers.models.llama.modeling_llama",
        "GaudiLlamaDecoderLayer",
    ),
    "mixtral": (
        "optimum.habana.transformers.models.mixtral.modeling_mixtral",
        "GaudiMixtralSparseMoeBlock",
    ),
    "qwen3_moe": (
        "optimum.habana.transformers.models.qwen3_moe.modeling_qwen3_moe",
        "GaudiQwen3MoeSparseMoeBlock",
    ),
}


def apply_zero3_leaf_promotion(model: torch.nn.Module) -> None:
    """
    Mark the registered module class of `model` as a ZeRO-3 leaf module.

    The model type is read from `model.config.model_type` (falling back to
    `model.model_type` when the model has no `config` attribute). If that type
    is registered in `_Z3_LEAF_MODULES`, the matching Gaudi module class is
    imported and passed to DeepSpeed's `set_z3_leaf_modules`. For any other
    model type the function only emits a debug message and returns.

    Per the DeepSpeed documentation, a leaf module is handled by ZeRO-3 as a
    single unit instead of being hooked submodule by submodule, which is the
    recommended setup for Mixture-of-Experts blocks.

    Parameters
    ----------
    model : torch.nn.Module
        The model (or PEFT wrapper) on which `set_z3_leaf_modules` will be called.

    Raises
    ------
    ImportError
        If the model type is registered but DeepSpeed (or the Gaudi modeling
        module) cannot be imported.
    """
    config = getattr(model, "config", model)
    model_type = getattr(config, "model_type", None)

    # Only string model types can be registered; anything else is a no-op.
    target = _Z3_LEAF_MODULES.get(model_type) if isinstance(model_type, str) else None
    if target is None:
        logger.debug("Model type '%s' is not registered for ZeRO-3 leaf promotion.", model_type)
        return

    # DeepSpeed is imported only once there is actual work to do, so calling
    # this helper on an unregistered model never requires DeepSpeed.
    from deepspeed.utils import set_z3_leaf_modules

    module_path, class_name = target
    leaf_class = getattr(importlib.import_module(module_path), class_name)
    set_z3_leaf_modules(model, [leaf_class])

    logger.debug("Model type '%s' is registered for ZeRO-3 leaf promotion.", model_type)
b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -229,87 +229,34 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens routing_weights, selected_experts = calculate_routing_tensors(router_logits, self.top_k, hidden_states.dtype) - # TODO - # This is a hack solution to avoid segmentation fault during SFT training. - # Remove this section after the issue is fixed. - if self.training: - final_hidden_states = self.call_sparse_moe_op( - shape=original_shape, - hidden_states=hidden_states, - expert_routing_table=selected_experts, - router_weights=routing_weights, - ) - else: - final_hidden_states = self.call_dynamic_moe_op( - hidden_states=hidden_states, - expert_routing_table=selected_experts, - router_weights=routing_weights, - ) - - if self.ep_size > 1: - final_hidden_states = _all_reduce(final_hidden_states) - elif deepspeed_available and (not self.training): - from deepspeed import comm - - if comm.is_initialized(): - comm.all_reduce(final_hidden_states) - - return final_hidden_states.view(original_shape), router_logits - - def call_dynamic_moe_op( - self, - hidden_states, - expert_routing_table, - router_weights, - ): # pre-processing for custom op inputs w1_list = [self.experts[i].w1.weight for i in self.experts_range] w2_list = [self.experts[i].w2.weight for i in self.experts_range] w3_list = [self.experts[i].w3.weight for i in self.experts_range] - return torch.ops.hpu.mixture_of_experts( + final_hidden_states = torch.ops.hpu.mixture_of_experts( hidden_states=hidden_states, - expert_routing_table=expert_routing_table, - router_weights=router_weights, + expert_routing_table=selected_experts, + router_weights=routing_weights, w1=w1_list, + w2=w3_list, # Note that there is a different naming convention of w1, w2, and w3 between optimum habana's mixtral model and dynamic MoE kernel. 
w3=w2_list, - w2=w3_list, permuted_weights=True, activation="silu", experts_min=self.experts_min, experts_max=self.experts_max, ) - def call_sparse_moe_op( - self, - shape, - hidden_states, - expert_routing_table, - router_weights, - ): - dtype = hidden_states.dtype - device = hidden_states.device - - padded_weights = torch.zeros((hidden_states.shape[0], self.num_experts), dtype=dtype, device=device) - padded_weights.scatter_(-1, expert_routing_table, router_weights) - padded_weights = padded_weights.view(shape[0], shape[1], self.num_experts).permute(2, 0, 1).unsqueeze(-1) - - current_state_static = hidden_states - - final_hidden_states = torch.zeros(shape, dtype=dtype, device=device) - - # Loop over all available experts in the model and perform the computation on each expert - for expert_idx in range(self.num_experts): - expert_layer = self.experts[expert_idx] - padded_weight = padded_weights[expert_idx] - current_hidden_states_static = expert_layer(current_state_static).view(shape) * padded_weight - final_hidden_states += current_hidden_states_static + if not self.training: + if self.ep_size > 1: + final_hidden_states = _all_reduce(final_hidden_states) + elif deepspeed_available: + from deepspeed import comm - # Support long sequences exceeding 8192 - if not self.training and shape[1] > 8192: - htcore.mark_step() + if comm.is_initialized(): + comm.all_reduce(final_hidden_states) - return final_hidden_states + return final_hidden_states.view(original_shape), router_logits class GaudiMixtralAttentionLongSequence: diff --git a/tests/baselines/fixture/tests/test_examples.json b/tests/baselines/fixture/tests/test_examples.json index b9ead5eab3..10ef0b1ea6 100644 --- a/tests/baselines/fixture/tests/test_examples.json +++ b/tests/baselines/fixture/tests/test_examples.json @@ -147,6 +147,16 @@ "train_samples_per_second": 13.065 } }, + "tests/test_examples.py::DeepspeedSFTMixtralExampleTester::test_sft_Mixtral-8x7B-Instruct-v0.1_deepspeed": { + "gaudi2": { + 
"train_runtime": 286.0, + "train_samples_per_second": 12.0 + }, + "gaudi3": { + "train_runtime": 270.0, + "train_samples_per_second": 13.1 + } + }, "tests/test_examples.py::DeepspeedSummarizationExampleTester::test_run_summarization_flan-t5-xxl_deepspeed": { "gaudi2": { "eval_rougeLsum": 27.9095, diff --git a/tests/configs/examples/Mixtral_8x7B_Instruct_v0_1.json b/tests/configs/examples/Mixtral_8x7B_Instruct_v0_1.json new file mode 100644 index 0000000000..61e0c32424 --- /dev/null +++ b/tests/configs/examples/Mixtral_8x7B_Instruct_v0_1.json @@ -0,0 +1,70 @@ + { + "gaudi2": { + "trl-sft-mixtral": { + "num_train_epochs": 1, + "eval_batch_size": 1, + "distribution": { + "deepspeed": { + "learning_rate": 0.0001, + "train_batch_size": 2, + "metrics": [ + "train_runtime", + "train_samples_per_second" + ], + "extra_arguments": [ + "--bf16 True", + "--subset data/", + "--streaming False", + "--gradient_accumulation_steps 2", + "--warmup_steps 100", + "--lr_scheduler_type cosine", + "--logging_steps 10", + "--lora_target_modules q_proj v_proj", + "--max_seq_length 512", + "--weight_decay 0.05", + "--report_to none", + "--max_steps 100", + "--optim paged_adamw_32bit", + "--remove_unused_columns False", + "--use_zero3_leaf_promotion True", + "--deepspeed examples/language-modeling/mixtral_ds_zero3_config.json" + ] + } + } + } + }, + "gaudi3": { + "trl-sft-mixtral": { + "num_train_epochs": 1, + "eval_batch_size": 1, + "distribution": { + "deepspeed": { + "learning_rate": 0.0001, + "train_batch_size": 2, + "metrics": [ + "train_runtime", + "train_samples_per_second" + ], + "extra_arguments": [ + "--bf16 True", + "--subset data/", + "--streaming False", + "--gradient_accumulation_steps 2", + "--warmup_steps 100", + "--lr_scheduler_type cosine", + "--logging_steps 10", + "--lora_target_modules q_proj v_proj", + "--max_seq_length 512", + "--weight_decay 0.05", + "--report_to none", + "--max_steps 100", + "--optim paged_adamw_32bit", + "--remove_unused_columns False", + 
"--use_zero3_leaf_promotion True", + "--deepspeed examples/language-modeling/mixtral_ds_zero3_config.json" + ] + } + } + } + } +} \ No newline at end of file diff --git a/tests/test_examples.py b/tests/test_examples.py index 8ac152d0bf..c619574093 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -169,7 +169,7 @@ def is_valid_model_type(model_type: str) -> bool: "sft": _get_supported_models_for_script( MODELS_TO_TEST_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING, - ["llama", "qwen2"], + ["llama", "qwen2", "mixtral"], ), "dpo": _get_supported_models_for_script( MODELS_TO_TEST_MAPPING, @@ -270,12 +270,16 @@ def to_test( return False elif "Qwen2-72B" in model_name and task_name != "trl-sft-qwen": return False + elif "Mixtral-8x7B" in model_name and task_name != "trl-sft-mixtral": + return False elif "llama" in model_name and "trl-sft-chat" in task_name: return False elif ("qwen2" in model_name or "Qwen2" in model_name) and task_name == "trl-sft": return False elif "llama" in model_name and "trl-sft-qwen" in task_name: return False + elif "llama" in model_name and "trl-sft-mixtral" in task_name: + return False elif "Llama-3.1-8B" in model_name: if multi_card: return False @@ -321,6 +325,8 @@ def to_test( return True elif "Qwen2-72B" in model_name and not IS_GAUDI1 and deepspeed: return True + elif "Mixtral-8x7B" in model_name and not IS_GAUDI1 and deepspeed: + return True elif model_name == "albert-xxlarge-v1": if (("RUN_ALBERT_XXL_1X" in os.environ) and strtobool(os.environ["RUN_ALBERT_XXL_1X"])) or multi_card: # ALBERT XXL 1X is tested only if the required flag is present because it takes long @@ -540,6 +546,10 @@ def test(self): env_variables["PT_HPU_LAZY_MODE"] = "0" env_variables["PT_ENABLE_INT64_SUPPORT"] = "1" + if model_name == "mistralai/Mixtral-8x7B-Instruct-v0.1": + env_variables["PT_HPU_LAZY_MODE"] = "1" + env_variables["PT_ENABLE_INT64_SUPPORT"] = "1" + if self.EXAMPLE_NAME == "run_audio_classification": 
extra_command_line_arguments.append("--sdp_on_bf16") if "wav2vec2" in model_name: @@ -701,7 +711,6 @@ def _create_command_line( cmd_line += [ f"{script}", f"--model_name_or_path {model_name}", - f"--gaudi_config_name {gaudi_config_name}", f"{task_option}", "--do_train", f"--output_dir {output_dir}", @@ -709,11 +718,15 @@ def _create_command_line( f"--learning_rate {lr}", f"--per_device_train_batch_size {train_batch_size}", f"--per_device_eval_batch_size {eval_batch_size}", - f" --num_train_epochs {num_epochs}", + f"--num_train_epochs {num_epochs}", "--use_habana", "--throughput_warmup_steps 3", "--save_strategy no", ] + if gaudi_config_name: + cmd_line += [ + f"--gaudi_config_name {gaudi_config_name}", + ] if "compile" in task or "--torch_compile" in extra_command_line_arguments: cmd_line += ["--use_lazy_mode False"] @@ -973,6 +986,13 @@ class DeepspeedSFTExampleTester(ExampleTesterBase, metaclass=ExampleTestMeta, ex DATASET_NAME = "philschmid/dolly-15k-oai-style" +class DeepspeedSFTMixtralExampleTester( + ExampleTesterBase, metaclass=ExampleTestMeta, example_name="sft", deepspeed=True +): + TASK_NAME = "trl-sft-mixtral" + DATASET_NAME = "philschmid/dolly-15k-oai-style" + + class MultiCardSFTChatExampleTester(ExampleTesterBase, metaclass=ExampleTestMeta, example_name="sft", multi_card=True): TASK_NAME = "trl-sft-chat" DATASET_NAME = "philschmid/dolly-15k-oai-style" diff --git a/tests/utils.py b/tests/utils.py index 36e131e5b5..bebd416bc4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -67,6 +67,7 @@ "gemma": [("google/gemma-2b-it", "Habana/gpt2")], "chatglm": [("THUDM/chatglm3-6b", "Habana/gpt2")], "llava": [("llava-hf/llava-1.5-7b-hf", "Habana/gpt2")], + "mixtral": [("mistralai/Mixtral-8x7B-Instruct-v0.1", "")], } MODELS_TO_TEST_FOR_QUESTION_ANSWERING = [