[model] gemma4#10346
Conversation
There was a problem hiding this comment.
Code Review
This pull request adds support for the Gemma4 multimodal model. Key changes include the implementation of the Gemma4Plugin in mm_plugin.py for processing image and video inputs, the registration of the gemma4 prompt template (which includes support for reasoning/thought tokens), and the addition of Gemma4 model groups to the project constants. Furthermore, a CLAUDE.md guide has been added to provide repository-specific instructions for AI coding assistants, and unit tests have been updated to validate the new plugin. I have no feedback to provide.
Port upstream PR hiyouga#10346 to add Gemma 4 model support: - Gemma4 multimodal plugin (image processing, pan-and-scan) - Gemma4n multimodal plugin with audio support - Chat templates for gemma4 and gemma4n - Tool calling utilities for Gemma 4 - Model constants and visual module support - HyperParallel workflow compatibility
Port upstream PR hiyouga#10346 to add Gemma 4 model support: - Gemma4 multimodal plugin (image processing, pan-and-scan) - Gemma4n multimodal plugin with audio support - Chat templates for gemma4 and gemma4n - Tool calling utilities for Gemma 4 - Model constants and visual module support - HyperParallel workflow compatibility
|
Are E2B and E4B training supported in this version? I saw the warning said it is not ready. Please let us know. Thanks. |
|
@Kuangdd01 When will support zero3? I want to sft gemma4-31B-it. This would be very helpful |
You might want to try launching your fine-tuning task with FSDP sharding using Accelerate. It worked well for me. compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_forward_prefetch: false
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: true
fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sync_module_states: true
fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16 # or fp16
num_machines: 1 # the number of nodes
num_processes: 2 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
|
|
Will multiple GPUs for E2B version be supported in the future? Thanks. |
Thank you for your reply. But I meet a error: Must flatten tensors with uniform dtype but got torch.float32 and torch.bfloat16. Details are as follows: fsdp_config: train_config.yaml: command: error: |
|
Can you try with |
Is there something wrong with |
|
Hi, I'd like to ask about the environment setup for fine-tuning Gemma-4. When loading the processor, Gemma-4 requires Transformers version 5.5.0 or above, but LlamaFactory has a strict upper bound of below 5.2.0. I'm wondering how you resolved this version conflict? |
|
DISABLE_VERSION_CHECK=1 |
Just one GPU is too slow for my training. I have one node with 8 GPUs, and i want to use all 8 to do DDP training and speed up the process. |
* [model] gemma4 (hiyouga#10346) * fix: gemma4 mm_token_type_ids padding (hiyouga#10359) * [fix] Add token_type_ids for Gemma 4 text-only SFT training Gemma 4 models require token_type_ids and mm_token_type_ids tensors in the forward pass, but the data collator only creates these through the multimodal plugin when images/videos/audio are present. For text-only SFT, the forward pass crashes with: TypeError: roll(): argument input (position 1) must be Tensor, not bool Inject zero tensors for both fields when the model type is gemma4 or gemma4n and the fields are not already present from the mm plugin. * [fix] Gemma 4 text-only SFT training patches from Olivia benchmarks Fixes discovered during Gemma 4 31B SFT benchmarking on GH200: - mm_plugin: Skip media count validation when no media is provided, prevents crash when text data contains literal HTML <video>/<audio> tags - mm_plugin: Add default for full_image_sequence attribute to prevent AttributeError with newer transformers - visual: Wrap get_projector in try/except for text-only CausalLM mode - visual: Disable gemma4 composite model registration (not needed for text-only training, causes errors when vision/audio towers absent) - template: Remove thought_words and ReasoningTemplate from gemma4/gemma4n templates (causes StopIteration on non-thinking training data) - template: Switch gemma4 mm_plugin from gemma4-specific to base plugin (avoids multimodal processor dependency in text-only mode) - sft/trainer: Fix create_optimizer signature for transformers 5.5+ compat * [fp8] Add FP8 Gemma 4 test configs and fix mm_token_type_ids packing - Add FP8 smoke test configs for Gemma 4: TorchAO full FT, TorchAO LoRA, Transformer Engine, and BF16 baseline (using pruned 2-layer test model) - Fix mm_token_type_ids not being unpacked in neat_packing + FA2 path, which would cause shape mismatches for Gemma 4 text-only training * [fp8] Add backward compatibility for non-Ada GPUs Auto-detect GPU compute capability and use TorchAO emulation mode (emulate=True) on GPUs with compute < 8.9. This allows fp8: true to work on any CUDA GPU without crashing: - Native FP8: Ada Lovelace (RTX 40xx, L4) and Hopper (H100, GH200) use hardware FP8 kernels with rowwise scaling for full speedup - Emulated FP8: Ampere and older (RTX 30xx, A100) use software FP8 quantization/dequantization — no throughput gain but the code path is exercised and weights are stored in FP8 - TE backend: falls back to TorchAO emulation on non-native hardware instead of crashing Clear warnings are logged indicating native vs emulated mode. * [fp8] Add true FP8 weight and gradient storage for memory-efficient training Implements FP8StorageLinear with two complementary memory optimizations: Weight storage (float8_e4m3fn): weights stored in 1 byte/param between steps, decompressed to bf16 for forward/backward. Halves weight memory. Gradient compression (float8_e5m2): gradients compressed to fp8 via post-accumulate hooks as produced during backward. Supports gradient accumulation via dequant-add-requant cycle. Halves gradient memory. Combined: 2B/param (fp8) vs 4B/param (bf16) = 50% reduction. Works on any CUDA GPU. No Ada/Hopper required. Adds fp8_mode training arg with storage/pure/accelerate/auto modes. Integrates into SFT workflow with FP8StorageCallback managing the compress/materialize lifecycle around optimizer steps. * [model] Support EXPERTS_IMPLEMENTATION env var for MoE model loading Allows overriding the experts implementation (eager, batched_mm, grouped_mm) via environment variable when loading MoE models like Gemma 4 MoE or Mixtral. Prevents crashes on GPUs that lack support for specific expert computation backends. * [fp8] Add MoE expert module support for FP8 weight storage MoE models (Mixtral, Qwen MoE, GLM MoE, etc.) store expert weights as 3D nn.Parameter tensors, not nn.Linear. FP8StorageExperts wraps these modules with the same compress/materialize lifecycle as FP8StorageLinear. Expert detection is automatic via module naming convention (experts/expert) plus verification of 3D parameter tensors. Gradient compression hooks also cover expert parameters via unified _iter_fp8_params iterator. * [fp8] Add FP8 storage mode to PT and DPO workflows Same integration pattern as SFT: detect fp8+storage mode after model load, convert eligible modules, and attach the lifecycle callback. * [fp8] Add fused FP8Adafactor optimizer for memory-efficient training FP8Adafactor reads fp8 weights and gradients directly, performs the Adafactor update in fp32, and writes fp8 weights back -- all per-parameter. Only one parameter is in fp32 at any time during the optimizer step. Memory during optimizer step (31B model): Standard: ~124 GB (all weights+grads in bf16) FP8 fused: ~66 GB (fp8 storage + one fp32 param at a time) Also tags fp8-managed parameters with _fp8_ref for optimizer access, and adds fused_optimizer mode to FP8StorageCallback to skip redundant materialization/compression when the optimizer handles it directly. Automatically activated when optim=adafactor + fp8=true + fp8_mode=storage. Integrated into SFT, PT, and DPO workflows. * [fp8] Add pure FP8 training mode with native scaled_mm (Ada/Hopper) Pure mode uses torch._scaled_mm for hardware-accelerated fp8 matmul: - Forward: input (e4m3) x weight (e4m3) via scaled_mm - Backward dL/dX: grad_output (e5m2) x weight (e4m3) - Backward dL/dW: input (e4m3) x grad_output (e5m2) All three matmul operations use fp8 tensor cores for ~2x throughput on Ada Lovelace (CC 8.9) and Hopper (CC 9.0) GPUs. Gracefully falls back to storage mode on older GPUs (CC < 8.9). Usage: fp8: true, fp8_mode: pure * [fp8] Auto-detect GPU capability for pure vs storage mode selection fp8_mode: auto now selects pure mode on Ada/Hopper (CC >= 8.9) and storage mode on older GPUs. Also adds FP8PureLinear support to the shared compress/materialize utilities. * [fp8] Fix pure mode for 3D inputs and improve scaled_mm efficiency Handle arbitrary batch dimensions (batch, seq, hidden) by flattening to 2D for torch._scaled_mm which only supports 2D inputs. Pre-transpose weight to contiguous layout to avoid repeated non-contiguous operations. * [fp8] Fix ZeRO-3 compatibility: use module attributes for size checks ZeRO-3 partitions weight tensors into 1D shards, so weight.numel() returns the shard size instead of the full parameter count. This caused all 412 linear layers to be skipped during FP8 conversion ("converted 0, skipped 412"). Fix uses in_features * out_features (module attributes that survive partitioning) for the numel and alignment checks. Also adds on-the-fly fp8 quantization mode for ZeRO-3/FSDP: weights stay as bf16 Parameters (partitioned by the distributed backend), and are quantized to fp8 during forward after ZeRO-3 gathers them. This gives ~2x matmul speedup from native fp8 tensor cores while ZeRO-3 handles memory sharding. SFT workflow updated to detect ZeRO-3 and route to appropriate mode. * [fp8] Fix ZeRO-3 compat for MoE experts and sync PT/DPO workflows - MoE expert detection: use ds_shape (ZeRO-3 flattens params to 1D shards, so param.dim() returns 1 instead of 3). Falls back to param.shape when ds_shape is not present. - Expert buffer creation: use ds_shape for buffer shape instead of the partitioned param.shape. - PT and DPO workflows: sync with SFT's ZeRO-3-aware routing (skip storage mode with ZeRO-3, auto-detect via _detect_zero3). * [fix] Revert create_optimizer signature for transformers 5.2 compat * [fp8] Fix memory leak, scale dtype, cuBLASLt layout, and gradient flow - Storage mode: re-compress weights after forward to free bf16 memory, preventing all layers from staying decompressed simultaneously (OOM fix) - Pure mode: force fp32 scales in quantize functions (scaled_mm requirement) - Pure mode: fix cuBLASLt layout - B operand must be column-major - Pure mode: add weight_proxy parameter to _FP8MatmulFunction for gradient routing via STE (straight-through estimator), fixing broken grad flow Bugs found during GH200 testing on Olivia. * [fp8] Fix STE crash and missing gradient hooks in pure mode - Remove buffered mode from FP8PureLinear entirely: compress() zeroed self.weight.data to empty(0), making STE gradient routing crash with shape mismatch. On-the-fly mode keeps bf16 weights intact. - Add FP8PureLinear to _iter_fp8_params and _is_fp8_managed so gradient compression hooks cover pure mode params (prevents 52GB bf16 grad OOM). - Install fp8 gradient hooks after convert_model_to_fp8_pure in all workflows (SFT, PT, DPO). Bugs found during GH200 pure mode testing on Olivia. * [fp8] Unify storage + native matmul: fp8 weights with scaled_mm on Ada/Hopper FP8StorageLinear now auto-detects native fp8 matmul support (sm_89+) and uses torch._scaled_mm directly on compressed fp8 buffers — no materialize needed during forward. This gives both fp8 weight storage (1 byte/param) AND native fp8 compute speedup in a single code path. - Add use_native_fp8 flag to FP8StorageLinear, auto-detected in convert_model_to_fp8_storage - Forward skips materialize when compressed + native fp8: runs _FP8MatmulFunction directly on fp8 buffers, gradients route through weight via STE - Callback materializes weights before optimizer step, re-compresses after - Remove fp8_mode training arg — single unified path auto-detects hardware - Simplify all workflows (SFT/PT/DPO) to single convert_model_to_fp8_storage call Memory for 26B MoE on GH200: ~26GB fp8 weights + 26GB fp8 grads = ~52GB + optimizer states. Previous pure mode needed 52GB bf16 weights. * [fp8] Native fp8 checkpoint save/load for ecosystem compatibility Override _save_to_state_dict and _load_from_state_dict on FP8StorageLinear so checkpoints use standard keys with native fp8 dtype: Save: _weight_fp8 → 'weight' (float8_e4m3fn), _weight_scale → 'weight_scale' Load: detect fp8 dtype in 'weight' key, load directly with scale if present, fall back to bf16→fp8 quantization if scale missing Handles three checkpoint formats transparently: 1. FP8-native: weight is float8 + weight_scale present → direct load 2. Compressed: weight is empty(0) + _weight_fp8 buffer → direct load 3. Standard bf16: weight is bf16 → load then compress on_save is now a no-op since _save_to_state_dict handles everything. Checkpoints are compatible with vLLM, transformers, and other engines. * [fp8] Fix grad checkpointing crash: use backward hook for compress compress() in forward() zeros self.weight.data before backward can use it for grad_input = grad_output @ weight, causing size mismatch with gradient checkpointing (which re-runs forward then backward). Fix: register_full_backward_hook that compresses AFTER gradient computation. Forward only materializes, backward uses bf16 weight, hook re-compresses. Lifecycle with gradient checkpointing: 1. Forward: materialize fp8→bf16 (per layer) 2. Grad ckpt re-runs forward: materialize again 3. Backward: uses bf16 weight for grad computation 4. Backward hook: compress bf16→fp8, frees memory Applied to both FP8StorageLinear and FP8StorageExperts. * [fp8] Disable native fp8 matmul in storage mode (incompatible lifecycles) STE in _FP8MatmulFunction needs non-empty self.weight for gradient routing, but storage mode compress() zeros it to empty(0). These are fundamentally incompatible. Storage mode uses F.linear with materialized bf16 weights + backward hook to re-compress. * [fp8] Fix grad checkpointing crash: use backward hook for compress With gradient checkpointing, compress during initial forward (no_grad context) to free bf16 memory per-layer. During backward re-run (enable_grad), keep weight materialized for gradient computation; backward hook re-compresses after grads are done. Applies to both FP8StorageLinear and FP8StorageExperts. Peak memory: ~60GB for 26B MoE (was 78GB+ OOM before). * [fp8] Clean up dead native fp8 path and fix missing logger imports - Remove use_native_fp8 flag from FP8StorageLinear (was always False, incompatible with storage mode's compress lifecycle) - Remove dead _FP8MatmulFunction forward path that was never reached - Remove FP8PureLinear from isinstance checks (not used in workflows) - Add missing logger imports in pt/workflow.py and dpo/workflow.py * [fp8] Native fp8 matmul for storage mode + unit tests Enable native torch._scaled_mm on Ada/Hopper GPUs within FP8StorageLinear, keeping the 1-byte/param compressed weight lifecycle. Key changes: - FP8StorageLinear: add use_native_fp8 flag, forward branches to _FP8MatmulFunction when compressed + native supported - _FP8MatmulFunction: accept module arg, manually route gradients to parameter (bypasses broken autograd fp8 accumulation), support fp8 gradient compression hooks - convert_model_to_fp8_storage: auto-detect GPU capability and ZeRO-3 - Add unit tests for quantize/dequantize, storage linear, pure linear, and native storage linear paths --------- Co-authored-by: Kingsley <kingsleydodonow@gmail.com>
* [model] gemma4 (hiyouga#10346) * fix: gemma4 mm_token_type_ids padding (hiyouga#10359) * [fix] Add token_type_ids for Gemma 4 text-only SFT training Gemma 4 models require token_type_ids and mm_token_type_ids tensors in the forward pass, but the data collator only creates these through the multimodal plugin when images/videos/audio are present. For text-only SFT, the forward pass crashes with: TypeError: roll(): argument input (position 1) must be Tensor, not bool Inject zero tensors for both fields when the model type is gemma4 or gemma4n and the fields are not already present from the mm plugin. * [fix] Gemma 4 text-only SFT training patches from Olivia benchmarks Fixes discovered during Gemma 4 31B SFT benchmarking on GH200: - mm_plugin: Skip media count validation when no media is provided, prevents crash when text data contains literal HTML <video>/<audio> tags - mm_plugin: Add default for full_image_sequence attribute to prevent AttributeError with newer transformers - visual: Wrap get_projector in try/except for text-only CausalLM mode - visual: Disable gemma4 composite model registration (not needed for text-only training, causes errors when vision/audio towers absent) - template: Remove thought_words and ReasoningTemplate from gemma4/gemma4n templates (causes StopIteration on non-thinking training data) - template: Switch gemma4 mm_plugin from gemma4-specific to base plugin (avoids multimodal processor dependency in text-only mode) - sft/trainer: Fix create_optimizer signature for transformers 5.5+ compat * [fp8] Add FP8 Gemma 4 test configs and fix mm_token_type_ids packing - Add FP8 smoke test configs for Gemma 4: TorchAO full FT, TorchAO LoRA, Transformer Engine, and BF16 baseline (using pruned 2-layer test model) - Fix mm_token_type_ids not being unpacked in neat_packing + FA2 path, which would cause shape mismatches for Gemma 4 text-only training * [fp8] Add backward compatibility for non-Ada GPUs Auto-detect GPU compute capability and use TorchAO emulation mode (emulate=True) on GPUs with compute < 8.9. This allows fp8: true to work on any CUDA GPU without crashing: - Native FP8: Ada Lovelace (RTX 40xx, L4) and Hopper (H100, GH200) use hardware FP8 kernels with rowwise scaling for full speedup - Emulated FP8: Ampere and older (RTX 30xx, A100) use software FP8 quantization/dequantization — no throughput gain but the code path is exercised and weights are stored in FP8 - TE backend: falls back to TorchAO emulation on non-native hardware instead of crashing Clear warnings are logged indicating native vs emulated mode. * [fp8] Add true FP8 weight and gradient storage for memory-efficient training Implements FP8StorageLinear with two complementary memory optimizations: Weight storage (float8_e4m3fn): weights stored in 1 byte/param between steps, decompressed to bf16 for forward/backward. Halves weight memory. Gradient compression (float8_e5m2): gradients compressed to fp8 via post-accumulate hooks as produced during backward. Supports gradient accumulation via dequant-add-requant cycle. Halves gradient memory. Combined: 2B/param (fp8) vs 4B/param (bf16) = 50% reduction. Works on any CUDA GPU. No Ada/Hopper required. Adds fp8_mode training arg with storage/pure/accelerate/auto modes. Integrates into SFT workflow with FP8StorageCallback managing the compress/materialize lifecycle around optimizer steps. * [model] Support EXPERTS_IMPLEMENTATION env var for MoE model loading Allows overriding the experts implementation (eager, batched_mm, grouped_mm) via environment variable when loading MoE models like Gemma 4 MoE or Mixtral. Prevents crashes on GPUs that lack support for specific expert computation backends. * [fp8] Add MoE expert module support for FP8 weight storage MoE models (Mixtral, Qwen MoE, GLM MoE, etc.) store expert weights as 3D nn.Parameter tensors, not nn.Linear. FP8StorageExperts wraps these modules with the same compress/materialize lifecycle as FP8StorageLinear. Expert detection is automatic via module naming convention (experts/expert) plus verification of 3D parameter tensors. Gradient compression hooks also cover expert parameters via unified _iter_fp8_params iterator. * [fp8] Add FP8 storage mode to PT and DPO workflows Same integration pattern as SFT: detect fp8+storage mode after model load, convert eligible modules, and attach the lifecycle callback. * [fp8] Add fused FP8Adafactor optimizer for memory-efficient training FP8Adafactor reads fp8 weights and gradients directly, performs the Adafactor update in fp32, and writes fp8 weights back -- all per-parameter. Only one parameter is in fp32 at any time during the optimizer step. Memory during optimizer step (31B model): Standard: ~124 GB (all weights+grads in bf16) FP8 fused: ~66 GB (fp8 storage + one fp32 param at a time) Also tags fp8-managed parameters with _fp8_ref for optimizer access, and adds fused_optimizer mode to FP8StorageCallback to skip redundant materialization/compression when the optimizer handles it directly. Automatically activated when optim=adafactor + fp8=true + fp8_mode=storage. Integrated into SFT, PT, and DPO workflows. * [fp8] Add pure FP8 training mode with native scaled_mm (Ada/Hopper) Pure mode uses torch._scaled_mm for hardware-accelerated fp8 matmul: - Forward: input (e4m3) x weight (e4m3) via scaled_mm - Backward dL/dX: grad_output (e5m2) x weight (e4m3) - Backward dL/dW: input (e4m3) x grad_output (e5m2) All three matmul operations use fp8 tensor cores for ~2x throughput on Ada Lovelace (CC 8.9) and Hopper (CC 9.0) GPUs. Gracefully falls back to storage mode on older GPUs (CC < 8.9). Usage: fp8: true, fp8_mode: pure * [fp8] Auto-detect GPU capability for pure vs storage mode selection fp8_mode: auto now selects pure mode on Ada/Hopper (CC >= 8.9) and storage mode on older GPUs. Also adds FP8PureLinear support to the shared compress/materialize utilities. * [fp8] Fix pure mode for 3D inputs and improve scaled_mm efficiency Handle arbitrary batch dimensions (batch, seq, hidden) by flattening to 2D for torch._scaled_mm which only supports 2D inputs. Pre-transpose weight to contiguous layout to avoid repeated non-contiguous operations. * [fp8] Fix ZeRO-3 compatibility: use module attributes for size checks ZeRO-3 partitions weight tensors into 1D shards, so weight.numel() returns the shard size instead of the full parameter count. This caused all 412 linear layers to be skipped during FP8 conversion ("converted 0, skipped 412"). Fix uses in_features * out_features (module attributes that survive partitioning) for the numel and alignment checks. Also adds on-the-fly fp8 quantization mode for ZeRO-3/FSDP: weights stay as bf16 Parameters (partitioned by the distributed backend), and are quantized to fp8 during forward after ZeRO-3 gathers them. This gives ~2x matmul speedup from native fp8 tensor cores while ZeRO-3 handles memory sharding. SFT workflow updated to detect ZeRO-3 and route to appropriate mode. * [fp8] Fix ZeRO-3 compat for MoE experts and sync PT/DPO workflows - MoE expert detection: use ds_shape (ZeRO-3 flattens params to 1D shards, so param.dim() returns 1 instead of 3). Falls back to param.shape when ds_shape is not present. - Expert buffer creation: use ds_shape for buffer shape instead of the partitioned param.shape. - PT and DPO workflows: sync with SFT's ZeRO-3-aware routing (skip storage mode with ZeRO-3, auto-detect via _detect_zero3). * [fix] Revert create_optimizer signature for transformers 5.2 compat * [fp8] Fix memory leak, scale dtype, cuBLASLt layout, and gradient flow - Storage mode: re-compress weights after forward to free bf16 memory, preventing all layers from staying decompressed simultaneously (OOM fix) - Pure mode: force fp32 scales in quantize functions (scaled_mm requirement) - Pure mode: fix cuBLASLt layout - B operand must be column-major - Pure mode: add weight_proxy parameter to _FP8MatmulFunction for gradient routing via STE (straight-through estimator), fixing broken grad flow Bugs found during GH200 testing on Olivia. * [fp8] Fix STE crash and missing gradient hooks in pure mode - Remove buffered mode from FP8PureLinear entirely: compress() zeroed self.weight.data to empty(0), making STE gradient routing crash with shape mismatch. On-the-fly mode keeps bf16 weights intact. - Add FP8PureLinear to _iter_fp8_params and _is_fp8_managed so gradient compression hooks cover pure mode params (prevents 52GB bf16 grad OOM). - Install fp8 gradient hooks after convert_model_to_fp8_pure in all workflows (SFT, PT, DPO). Bugs found during GH200 pure mode testing on Olivia. * [fp8] Unify storage + native matmul: fp8 weights with scaled_mm on Ada/Hopper FP8StorageLinear now auto-detects native fp8 matmul support (sm_89+) and uses torch._scaled_mm directly on compressed fp8 buffers — no materialize needed during forward. This gives both fp8 weight storage (1 byte/param) AND native fp8 compute speedup in a single code path. - Add use_native_fp8 flag to FP8StorageLinear, auto-detected in convert_model_to_fp8_storage - Forward skips materialize when compressed + native fp8: runs _FP8MatmulFunction directly on fp8 buffers, gradients route through weight via STE - Callback materializes weights before optimizer step, re-compresses after - Remove fp8_mode training arg — single unified path auto-detects hardware - Simplify all workflows (SFT/PT/DPO) to single convert_model_to_fp8_storage call Memory for 26B MoE on GH200: ~26GB fp8 weights + 26GB fp8 grads = ~52GB + optimizer states. Previous pure mode needed 52GB bf16 weights. * [fp8] Native fp8 checkpoint save/load for ecosystem compatibility Override _save_to_state_dict and _load_from_state_dict on FP8StorageLinear so checkpoints use standard keys with native fp8 dtype: Save: _weight_fp8 → 'weight' (float8_e4m3fn), _weight_scale → 'weight_scale' Load: detect fp8 dtype in 'weight' key, load directly with scale if present, fall back to bf16→fp8 quantization if scale missing Handles three checkpoint formats transparently: 1. FP8-native: weight is float8 + weight_scale present → direct load 2. Compressed: weight is empty(0) + _weight_fp8 buffer → direct load 3. Standard bf16: weight is bf16 → load then compress on_save is now a no-op since _save_to_state_dict handles everything. Checkpoints are compatible with vLLM, transformers, and other engines. * [fp8] Fix grad checkpointing crash: use backward hook for compress compress() in forward() zeros self.weight.data before backward can use it for grad_input = grad_output @ weight, causing size mismatch with gradient checkpointing (which re-runs forward then backward). Fix: register_full_backward_hook that compresses AFTER gradient computation. Forward only materializes, backward uses bf16 weight, hook re-compresses. Lifecycle with gradient checkpointing: 1. Forward: materialize fp8→bf16 (per layer) 2. Grad ckpt re-runs forward: materialize again 3. Backward: uses bf16 weight for grad computation 4. Backward hook: compress bf16→fp8, frees memory Applied to both FP8StorageLinear and FP8StorageExperts. * [fp8] Disable native fp8 matmul in storage mode (incompatible lifecycles) STE in _FP8MatmulFunction needs non-empty self.weight for gradient routing, but storage mode compress() zeros it to empty(0). These are fundamentally incompatible. Storage mode uses F.linear with materialized bf16 weights + backward hook to re-compress. * [fp8] Fix grad checkpointing crash: use backward hook for compress With gradient checkpointing, compress during initial forward (no_grad context) to free bf16 memory per-layer. During backward re-run (enable_grad), keep weight materialized for gradient computation; backward hook re-compresses after grads are done. Applies to both FP8StorageLinear and FP8StorageExperts. Peak memory: ~60GB for 26B MoE (was 78GB+ OOM before). * [fp8] Clean up dead native fp8 path and fix missing logger imports - Remove use_native_fp8 flag from FP8StorageLinear (was always False, incompatible with storage mode's compress lifecycle) - Remove dead _FP8MatmulFunction forward path that was never reached - Remove FP8PureLinear from isinstance checks (not used in workflows) - Add missing logger imports in pt/workflow.py and dpo/workflow.py * [fp8] Native fp8 matmul for storage mode + unit tests Enable native torch._scaled_mm on Ada/Hopper GPUs within FP8StorageLinear, keeping the 1-byte/param compressed weight lifecycle. Key changes: - FP8StorageLinear: add use_native_fp8 flag, forward branches to _FP8MatmulFunction when compressed + native supported - _FP8MatmulFunction: accept module arg, manually route gradients to parameter (bypasses broken autograd fp8 accumulation), support fp8 gradient compression hooks - convert_model_to_fp8_storage: auto-detect GPU capability and ZeRO-3 - Add unit tests for quantize/dequantize, storage linear, pure linear, and native storage linear paths --------- Co-authored-by: Kingsley <kingsleydodonow@gmail.com>
What does this PR do?
fix #10343
Before submitting