mixtral: drop training-branching hack for SFT segfault & add ZeRO-3 leaf utility #2185

Merged: regisss merged 23 commits into huggingface:main from yafshar:mixtral/remove-sft-segfault-hack on Aug 21, 2025

@yafshar yafshar commented Jul 30, 2025

What does this PR do?

1. Removes the temporary MoE-kernel workaround

  • Deletes the TODO block and its `if self.training … else …` branch, which selected different HPU MoE kernels for training and inference.
  • Replaces both `call_sparse_moe_op` and `call_dynamic_moe_op` with the single HPU-optimized
    `torch.ops.hpu.mixture_of_experts` op, now that Synapse 1.21.0 has fixed the segfault reported in PR #1798 ("[1.20.0] Temporary workaround to avoid segmentation fault").
  • Keeps the post-kernel all-reduce for inference unchanged.

2. Adds a reusable ZeRO-3 leaf-promotion utility

3. Wires the utility into the sft.py script

  • Updates the README

4. Provides a new ZeRO-3 config template for Mixtral

  • Adds examples/language-modeling/mixtral_zero3_config.json
  • This config enables ZeRO Stage 3 with overlap communication to support torch.ops.hpu.mixture_of_experts
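
For orientation, a ZeRO Stage 3 configuration with communication overlap generally has the shape sketched below. This is not the content of the file added by this PR: the keys are standard DeepSpeed configuration fields, but every value here is an assumption chosen for illustration.

```python
import json

# Illustrative ZeRO-3 config. All values are assumptions, NOT the PR's actual file.
zero3_config = {
    "steps_per_print": 64,
    # "auto" lets the Hugging Face Trainer integration fill in these values.
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "bf16": {"enabled": True},
    "gradient_clipping": 1.0,
    "zero_optimization": {
        "stage": 3,            # partition parameters, gradients, and optimizer state
        "overlap_comm": True,  # overlap communication with computation
        "contiguous_gradients": False,
        "stage3_gather_16bit_weights_on_model_save": True,
    },
}

# Serialize to the JSON text that would be passed via --deepspeed <file>.
config_text = json.dumps(zero3_config, indent=2)
print(config_text)
```

Writing `config_text` to a `.json` file gives something that can be passed to the `--deepspeed` argument; the real template in the repository should be preferred over this sketch.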

Tests:

main

>>> PT_HPU_LAZY_MODE=1 PT_ENABLE_INT64_SUPPORT=1 python ../gaudi_spawn.py --world_size 4 --use_deepspeed sft.py \
--model_name_or_path mistralai/Mixtral-8x7B-Instruct-v0.1 \
--dataset_name "philschmid/dolly-15k-oai-style" \
--subset 'data/' \
--streaming False \
--deepspeed ../language-modeling/llama2_ds_zero3_config.json \
--output_dir="./model_mixtral" \
--do_train \
--max_steps=500 \
--logging_steps=10 \
--save_steps=100 \
--per_device_train_batch_size=2 \
--per_device_eval_batch_size=1 \
--gradient_accumulation_steps=2 \
--learning_rate=1e-4 \
--lr_scheduler_type="cosine" \
--warmup_steps=100 \
--weight_decay=0.05 \
--optim="paged_adamw_32bit" \
--lora_target_modules "q_proj" "v_proj" \
--bf16 \
--remove_unused_columns=False \
--max_seq_length 512 \
--run_name="sft_mixtral" \
--report_to=none \
--use_habana \
--use_lazy_mode


***** train metrics *****
  epoch                       =     1.3972
  max_memory_allocated (GB)   =      80.43
  memory_allocated (GB)       =      23.33
  total_flos                  =   107879GF
  total_memory_available (GB) =     126.54
  train_loss                  =     5.1121
  train_runtime               = 0:33:59.01
  train_samples_per_second    =      3.923
  train_steps_per_second      =      0.245

this PR

>>> PT_HPU_LAZY_MODE=1 PT_ENABLE_INT64_SUPPORT=1 python ../gaudi_spawn.py --world_size 4 --use_deepspeed sft.py \
--model_name_or_path mistralai/Mixtral-8x7B-Instruct-v0.1 \
--dataset_name "philschmid/dolly-15k-oai-style" \
--subset 'data/' \
--streaming False \
--deepspeed ../language-modeling/mixtral_ds_zero3_config.json \
--output_dir="./model_mixtral" \
--do_train \
--max_steps=500 \
--logging_steps=10 \
--save_steps=100 \
--per_device_train_batch_size=2 \
--per_device_eval_batch_size=1 \
--gradient_accumulation_steps=2 \
--learning_rate=1e-4 \
--lr_scheduler_type="cosine" \
--warmup_steps=100 \
--weight_decay=0.05 \
--optim="paged_adamw_32bit" \
--lora_target_modules "q_proj" "v_proj" \
--bf16 \
--remove_unused_columns=False \
--max_seq_length 512 \
--run_name="sft_mixtral" \
--report_to=none \
--use_habana \
--use_lazy_mode \
--use_zero3_leaf_promotion

***** train metrics *****
  epoch                       =     1.3972
  max_memory_allocated (GB)   =      46.23
  memory_allocated (GB)       =      23.49
  total_flos                  =   107879GF
  total_memory_available (GB) =     126.54
  train_loss                  =     5.1126
  train_runtime               = 0:21:16.90
  train_samples_per_second    =      6.265
  train_steps_per_second      =      0.392

📊 Training Performance Comparison

| Metric | Main | This PR | Improvement / Change |
|---|---|---|---|
| Train Runtime | 33 min 59 sec | 21 min 16 sec | ⬇️ ~37.4% shorter |
| Train Samples/sec | 3.923 | 6.265 | ⬆️ ~59.7% increase |
| Train Steps/sec | 0.245 | 0.392 | ⬆️ ~60% increase |
| Max Memory Allocated (GB) | 80.43 | 46.23 | ⬇️ ~42.5% less memory usage |
| Train Loss | 5.1121 | 5.1126 | ⬆️ Negligible increase |
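
The percentages in the comparison can be re-derived from the two `train metrics` blocks. A minimal sketch, with the values copied from the logs above; the helper names and dictionary keys are illustrative only:

```python
# Re-derive the comparison percentages from the reported train metrics.
# Numbers are copied from the logs above; helper names are illustrative.

def to_seconds(hms: str) -> float:
    """Convert an 'H:MM:SS.ss' runtime string to seconds."""
    hours, minutes, seconds = hms.split(":")
    return int(hours) * 3600 + int(minutes) * 60 + float(seconds)

def pct_change(before: float, after: float) -> float:
    """Signed percentage change from `before` to `after`."""
    return (after - before) / before * 100.0

main_run = {
    "runtime_s": to_seconds("0:33:59.01"),
    "samples_per_s": 3.923,
    "steps_per_s": 0.245,
    "max_mem_gb": 80.43,
}
pr_run = {
    "runtime_s": to_seconds("0:21:16.90"),
    "samples_per_s": 6.265,
    "steps_per_s": 0.392,
    "max_mem_gb": 46.23,
}

# Negative values mean the metric went down in the PR run.
changes = {key: round(pct_change(main_run[key], pr_run[key]), 1) for key in main_run}
for key, value in changes.items():
    print(f"{key:>14}: {value:+.1f}%")
```

This gives about -37.4% for runtime, +59.7% for samples per second, +60.0% for steps per second, and -42.5% for peak memory.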

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

yafshar added 8 commits July 23, 2025 10:25

The workaround that chose between `call_sparse_moe_op` (training) and
`call_dynamic_moe_op` (inference) was introduced to avoid a segmentation
fault during SFT training on earlier Synapse releases. (See PR huggingface#1798)
The underlying bug is fixed in Synapse 1.21.0, so the hack is no longer
needed.

Replace the branching logic with the unified
`torch.ops.hpu.mixture_of_experts` call for both training and
inference, and remove the TODO comment.

- Introduced `apply_zero3_leaf_promotion` to mark model submodules as ZeRO-3 leaf modules
- The function is a no-op unless both:
  - is_deepspeed_zero3_enabled=True (caller asserts ZeRO-3 active)
  - use_zero3_leaf_promotion=True   (user opt-in flag)
- Uses a registry-based approach for model-type-specific leaf class mapping

Replace inline DeepSpeed leaf-module patching with the new
`optimum.habana.distributed.apply_zero3_leaf_promotion` utility.
Activation is controlled by the existing script_args flags
`use_zero3_leaf_promotion` and the runtime ZeRO-3 status check.

- Enables ZeRO Stage 3 with overlap communication to support
  `torch.ops.hpu.mixture_of_experts`
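
The gating and registry behaviour described in the commit messages above can be sketched roughly as follows. This is a hypothetical reconstruction, not the code in `optimum/habana/distributed/zero3_utils.py`: the registry contents, the function name `promote_leaves_sketch`, its signature, and the injected `mark_leaves` callback are all assumptions. In a real DeepSpeed setup the callback's role would be played by something like `deepspeed.utils.set_z3_leaf_modules`, which is deliberately not imported here so the sketch runs without DeepSpeed.

```python
from typing import Callable, Dict, List, Sequence

# Hypothetical registry: model type -> class names to treat as ZeRO-3 leaves.
# The single entry below is an assumption for illustration.
LEAF_CLASS_REGISTRY: Dict[str, List[str]] = {
    "mixtral": ["MixtralSparseMoeBlock"],
}

def promote_leaves_sketch(
    model_type: str,
    module_class_names: Sequence[str],
    mark_leaves: Callable[[List[str]], None],
    is_deepspeed_zero3_enabled: bool,
    use_zero3_leaf_promotion: bool,
) -> List[str]:
    """Return the class names promoted to leaves; no-op unless both flags are set."""
    if not (is_deepspeed_zero3_enabled and use_zero3_leaf_promotion):
        return []
    # Look up the leaf classes registered for this model type, if any.
    wanted = LEAF_CLASS_REGISTRY.get(model_type, [])
    # Only promote classes that actually occur in the model.
    present = [name for name in wanted if name in module_class_names]
    if present:
        mark_leaves(present)
    return present

# Usage: `marked.extend` stands in for the real leaf-marking call.
marked: List[str] = []
classes = ["MixtralAttention", "MixtralSparseMoeBlock"]
print(promote_leaves_sketch("mixtral", classes, marked.extend, True, False))
print(promote_leaves_sketch("mixtral", classes, marked.extend, True, True))
```

The first call returns an empty list because the opt-in flag is off; the second promotes the registered class. Keeping the mapping in a registry means support for another model type is a data change rather than a code change.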
@yafshar yafshar requested a review from regisss as a code owner July 30, 2025 21:02

@regisss regisss left a comment

Nice PR! I think it's worth adding a regression test in https://github.com/huggingface/optimum-habana/blob/main/tests/test_examples.py. You can use the same command you provided in this PR.

yafshar commented Aug 11, 2025

> Nice PR! I think it's worth adding a regression test in https://github.com/huggingface/optimum-habana/blob/main/tests/test_examples.py. You can use the same command you provided in this PR.

Thanks! It takes about 30 minutes to run, which is why I initially left it out. Please let me know if you'd like me to include it.

regisss commented Aug 11, 2025

> > Nice PR! I think it's worth adding a regression test in https://github.com/huggingface/optimum-habana/blob/main/tests/test_examples.py. You can use the same command you provided in this PR.
>
> Thanks! It takes about 30 minutes to run, which is why I initially left it out. Please let me know if you'd like me to include it.

I think it's okay to include it. Worst case, I'll make it run fewer training steps later.

yafshar commented Aug 12, 2025

@regisss I added the test; I just need to double-check the reference numbers, and then I will ping you. The G3 numbers look OK; I only need to fix G2. I also reduced max_steps and switched the test to 8 cards rather than 4 so that it takes less time.

yafshar commented Aug 13, 2025

@regisss The PR is ready for your review. The test commands are updated for 8 HPUs; both they and the reference numbers can be tuned further in the future. For now, I followed the steps outlined in the README to mimic the test setup. I also excluded perplexity because of the long runtime, but it can be added later if needed.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yafshar yafshar force-pushed the mixtral/remove-sft-segfault-hack branch from 12cc695 to 491d626 Compare August 19, 2025 12:24
@yafshar yafshar force-pushed the mixtral/remove-sft-segfault-hack branch from 491d626 to 48f9c0c Compare August 19, 2025 12:25

@IlyasMoutawwakil IlyasMoutawwakil left a comment

LGTM! Thanks for iterating on this! I left one last nit, but it's not important.

@regisss regisss left a comment

LGTM 🚀

@regisss regisss merged commit d186356 into huggingface:main Aug 21, 2025
2 of 4 checks passed
@yafshar yafshar deleted the mixtral/remove-sft-segfault-hack branch August 21, 2025 12:10
gplutop7 pushed a commit to HabanaAI/optimum-habana-fork that referenced this pull request Oct 15, 2025
…eaf utility (huggingface#2185) (huggingface#607)

Co-authored-by: Yaser Afshar <yaser.afshar@intel.com>