Skip to content

[model] gemma4#10346

Merged
hiyouga merged 4 commits into
hiyouga:mainfrom
Kuangdd01:upd/gemma4
Apr 5, 2026
Merged

[model] gemma4#10346
hiyouga merged 4 commits into
hiyouga:mainfrom
Kuangdd01:upd/gemma4

Conversation

@Kuangdd01
Copy link
Copy Markdown
Collaborator

@Kuangdd01 Kuangdd01 commented Apr 3, 2026

What does this PR do?

fix #10343

Before submitting

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for the Gemma4 multimodal model. Key changes include the implementation of the Gemma4Plugin in mm_plugin.py for processing image and video inputs, the registration of the gemma4 prompt template (which includes support for reasoning/thought tokens), and the addition of Gemma4 model groups to the project constants. Furthermore, a CLAUDE.md guide has been added to provide repository-specific instructions for AI coding assistants, and unit tests have been updated to validate the new plugin. I have no feedback to provide.

@Kuangdd01 Kuangdd01 marked this pull request as ready for review April 4, 2026 16:21
marksverdhei added a commit to NationalLibraryOfNorway/LLaMA-Factory that referenced this pull request Apr 4, 2026
Port upstream PR hiyouga#10346 to add Gemma 4 model support:
- Gemma4 multimodal plugin (image processing, pan-and-scan)
- Gemma4n multimodal plugin with audio support
- Chat templates for gemma4 and gemma4n
- Tool calling utilities for Gemma 4
- Model constants and visual module support
- HyperParallel workflow compatibility
marksverdhei added a commit to NationalLibraryOfNorway/LLaMA-Factory that referenced this pull request Apr 4, 2026
Port upstream PR hiyouga#10346 to add Gemma 4 model support:
- Gemma4 multimodal plugin (image processing, pan-and-scan)
- Gemma4n multimodal plugin with audio support
- Chat templates for gemma4 and gemma4n
- Tool calling utilities for Gemma 4
- Model constants and visual module support
- HyperParallel workflow compatibility
@hiyouga hiyouga merged commit eae6f0b into hiyouga:main Apr 5, 2026
15 of 16 checks passed
@hiyouga hiyouga added the solved This problem has been already solved label Apr 5, 2026
@lwang10-atlassian
Copy link
Copy Markdown

Are E2B and E4B training supported in this version? I saw the warning said it is not ready. Please let us know. Thanks.

@Kuangdd01
Copy link
Copy Markdown
Collaborator Author

Kuangdd01 commented Apr 6, 2026

Currently, with workaround mentioned here and #10359 merged, it is ready to finetune E2B and E4B (DO NOT ENABLE DEEPSPEED_ZERO3).

@hucorz
Copy link
Copy Markdown

hucorz commented Apr 6, 2026

@Kuangdd01 When will support zero3? I want to sft gemma4-31B-it. This would be very helpful

@Kuangdd01
Copy link
Copy Markdown
Collaborator Author

Kuangdd01 commented Apr 6, 2026

@Kuangdd01 When will support zero3? I want to sft gemma4-31B-it. This would be very helpful

You might want to try launching your fine-tuning task with FSDP sharding using Accelerate. It worked well for me.

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16  # or fp16
num_machines: 1  # the number of nodes
num_processes: 2  # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

@lwang10-atlassian
Copy link
Copy Markdown

Will multiple GPUs for E2B version be supported in the future? Thanks.

@hucorz
Copy link
Copy Markdown

hucorz commented Apr 6, 2026

@Kuangdd01 When will support zero3? I want to sft gemma4-31B-it. This would be very helpful

You might want to try launching your fine-tuning task with FSDP sharding using Accelerate. It worked well for me.

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16  # or fp16
num_machines: 1  # the number of nodes
num_processes: 2  # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Thank you for your reply. But I meet a error: Must flatten tensors with uniform dtype but got torch.float32 and torch.bfloat16. Details are as follows:

fsdp_config:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
  activation_checkpointing: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16  # or fp16
num_machines: 1  # the number of nodes
num_processes: 8  # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

train_config.yaml:

### model
model_name_or_path:models/lm/gemma-4-31B-it
image_max_pixels: 1048576
video_max_pixels: 16384
trust_remote_code: true

### method
stage: sft
do_train: true
finetuning_type: full

### dataset
dataset: t2i-pf-judger-fix
template: gemma4
cutoff_len: 10240
# max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 8

### output
output_dir: saves/t2i/pf-judger-fix-gemma4-31B-it
logging_steps: 10
save_steps: 50
plot_loss: true
overwrite_output_dir: true
save_only_model: true
report_to: swanlab  # choices: [none, wandb, tensorboard, swanlab, mlflow]
use_swanlab: true
swanlab_project: t2i-pf-judger
swanlab_run_name: t2i-pf-judger-fix-gemma4-31B-it

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 16
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.05
bf16: true
ddp_timeout: 180000000

resume_from_checkpoint: null

### eval
val_size: 0.05
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 50

command:

accelerate launch --config_file examples/t2i/gemma_fsdp.yaml src/train.py examples/t2i/pf-judger-gemma.yaml

error:

[rank6]: Traceback (most recent call last):                                                                                                                                            
[rank6]:   File "/mnt/tidal-alsh01/usr/huichao/LLaMa-Factory/LlamaFactory/src/train.py", line 28, in <module>                                                                          
[rank6]:     main()                                                                        
[rank6]:   File "/mnt/tidal-alsh01/usr/huichao/LLaMa-Factory/LlamaFactory/src/train.py", line 19, in main
[rank6]:     run_exp()                                                                     
[rank6]:   File "/mnt/tidal-alsh01/usr/huichao/LLaMa-Factory/LlamaFactory/src/llamafactory/train/tuner.py", line 139, in run_exp
[rank6]:     _training_function(config={"args": args, "callbacks": callbacks})
[rank6]:   File "/mnt/tidal-alsh01/usr/huichao/LLaMa-Factory/LlamaFactory/src/llamafactory/train/tuner.py", line 107, in _training_function                                            
[rank6]:     run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
[rank6]:   File "/mnt/tidal-alsh01/usr/huichao/LLaMa-Factory/LlamaFactory/src/llamafactory/train/sft/workflow.py", line 140, in run_sft                                                
[rank6]:     train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
[rank6]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                 
[rank6]:   File "/root/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 1424, in train
[rank6]:     return inner_training_loop(                                                                                                                                               
[rank6]:            ^^^^^^^^^^^^^^^^^^^^                                                   
[rank6]:   File "/root/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 1463, in _inner_training_loop                                                            
[rank6]:     model, train_dataloader = self._prepare_for_training(max_steps, train_dataloader, resume_from_checkpoint)
[rank6]:                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/root/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 1590, in _prepare_for_training                                                           
[rank6]:     model = self.accelerator.prepare(self.model)       
[rank6]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^          
[rank6]:   File "/root/miniconda3/lib/python3.11/site-packages/accelerate/accelerator.py", line 1555, in prepare
[rank6]:     result = tuple(                                                               
[rank6]:              ^^^^^^                                                               
[rank6]:   File "/root/miniconda3/lib/python3.11/site-packages/accelerate/accelerator.py", line 1556, in <genexpr>
[rank6]:     self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
[rank6]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/root/miniconda3/lib/python3.11/site-packages/accelerate/accelerator.py", line 1398, in _prepare_one
[rank6]:     return self.prepare_model(obj, device_placement=device_placement)
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/root/miniconda3/lib/python3.11/site-packages/accelerate/accelerator.py", line 1889, in prepare_model                                                                 
[rank6]:     model = FSDP(model, **kwargs)                                                 
[rank6]:             ^^^^^^^^^^^^^^^^^^^^^                                                 
[rank6]:   File "/root/miniconda3/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 497, in __init__
[rank6]:     _init_param_handle_from_module(                                               
[rank6]:   File "/root/miniconda3/lib/python3.11/site-packages/torch/distributed/fsdp/_init_utils.py", line 622, in _init_param_handle_from_module
[rank6]:     _init_param_handle_from_params(state, managed_params, fully_sharded_module)                                                                                               
[rank6]:   File "/root/miniconda3/lib/python3.11/site-packages/torch/distributed/fsdp/_init_utils.py", line 634, in _init_param_handle_from_params
[rank6]:     handle = FlatParamHandle(
[rank6]:              ^^^^^^^^^^^^^^^^
[rank6]:   File "/root/miniconda3/lib/python3.11/site-packages/torch/distributed/fsdp/_flat_param.py", line 588, in __init__
[rank6]:     self._init_flat_param_and_metadata(
[rank6]:   File "/root/miniconda3/lib/python3.11/site-packages/torch/distributed/fsdp/_flat_param.py", line 644, in _init_flat_param_and_metadata
[rank6]:     ) = self._validate_tensors_to_flatten(params)
[rank6]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/root/miniconda3/lib/python3.11/site-packages/torch/distributed/fsdp/_flat_param.py", line 788, in _validate_tensors_to_flatten
[rank6]:     raise ValueError(
[rank6]: ValueError: Must flatten tensors with uniform dtype but got torch.float32 and torch.bfloat16

@Kuangdd01
Copy link
Copy Markdown
Collaborator Author

Can you try with --pure_bf16?

@Kuangdd01
Copy link
Copy Markdown
Collaborator Author

Will multiple GPUs for E2B version be supported in the future? Thanks.

Is there something wrong with use_cache=True & disable_gradient_checkpointing=True when ddp training?

@xianghuang2
Copy link
Copy Markdown

Hi, I'd like to ask about the environment setup for fine-tuning Gemma-4. When loading the processor, Gemma-4 requires Transformers version 5.5.0 or above, but LlamaFactory has a strict upper bound of below 5.2.0. I'm wondering how you resolved this version conflict?

@Kuangdd01
Copy link
Copy Markdown
Collaborator Author

DISABLE_VERSION_CHECK=1

@lwang10-atlassian
Copy link
Copy Markdown

Will multiple GPUs for E2B version be supported in the future? Thanks.

Is there something wrong with use_cache=True & disable_gradient_checkpointing=True when ddp training?

Just one GPU is too slow for my training. I have one node with 8 GPUs, and i want to use all 8 to do DDP training and speed up the process.

marksverdhei added a commit to NationalLibraryOfNorway/LLaMA-Factory that referenced this pull request Apr 9, 2026
* [model] gemma4 (hiyouga#10346)

* fix: gemma4 mm_token_type_ids padding (hiyouga#10359)

* [fix] Add token_type_ids for Gemma 4 text-only SFT training

Gemma 4 models require token_type_ids and mm_token_type_ids tensors in the
forward pass, but the data collator only creates these through the multimodal
plugin when images/videos/audio are present. For text-only SFT, the forward
pass crashes with: TypeError: roll(): argument input (position 1) must be
Tensor, not bool

Inject zero tensors for both fields when the model type is gemma4 or gemma4n
and the fields are not already present from the mm plugin.

* [fix] Gemma 4 text-only SFT training patches from Olivia benchmarks

Fixes discovered during Gemma 4 31B SFT benchmarking on GH200:

- mm_plugin: Skip media count validation when no media is provided, prevents
  crash when text data contains literal HTML <video>/<audio> tags
- mm_plugin: Add default for full_image_sequence attribute to prevent
  AttributeError with newer transformers
- visual: Wrap get_projector in try/except for text-only CausalLM mode
- visual: Disable gemma4 composite model registration (not needed for
  text-only training, causes errors when vision/audio towers absent)
- template: Remove thought_words and ReasoningTemplate from gemma4/gemma4n
  templates (causes StopIteration on non-thinking training data)
- template: Switch gemma4 mm_plugin from gemma4-specific to base plugin
  (avoids multimodal processor dependency in text-only mode)
- sft/trainer: Fix create_optimizer signature for transformers 5.5+ compat

* [fp8] Add FP8 Gemma 4 test configs and fix mm_token_type_ids packing

- Add FP8 smoke test configs for Gemma 4: TorchAO full FT, TorchAO LoRA,
  Transformer Engine, and BF16 baseline (using pruned 2-layer test model)
- Fix mm_token_type_ids not being unpacked in neat_packing + FA2 path,
  which would cause shape mismatches for Gemma 4 text-only training

* [fp8] Add backward compatibility for non-Ada GPUs

Auto-detect GPU compute capability and use TorchAO emulation mode
(emulate=True) on GPUs with compute < 8.9. This allows fp8: true
to work on any CUDA GPU without crashing:

- Native FP8: Ada Lovelace (RTX 40xx, L4) and Hopper (H100, GH200)
  use hardware FP8 kernels with rowwise scaling for full speedup
- Emulated FP8: Ampere and older (RTX 30xx, A100) use software
  FP8 quantization/dequantization — no throughput gain but the
  code path is exercised and weights are stored in FP8
- TE backend: falls back to TorchAO emulation on non-native hardware
  instead of crashing

Clear warnings are logged indicating native vs emulated mode.

* [fp8] Add true FP8 weight and gradient storage for memory-efficient training

Implements FP8StorageLinear with two complementary memory optimizations:

Weight storage (float8_e4m3fn): weights stored in 1 byte/param between
steps, decompressed to bf16 for forward/backward. Halves weight memory.

Gradient compression (float8_e5m2): gradients compressed to fp8 via
post-accumulate hooks as produced during backward. Supports gradient
accumulation via dequant-add-requant cycle. Halves gradient memory.

Combined: 2B/param (fp8) vs 4B/param (bf16) = 50% reduction.

Works on any CUDA GPU. No Ada/Hopper required. Adds fp8_mode training
arg with storage/pure/accelerate/auto modes. Integrates into SFT
workflow with FP8StorageCallback managing the compress/materialize
lifecycle around optimizer steps.

* [model] Support EXPERTS_IMPLEMENTATION env var for MoE model loading

Allows overriding the experts implementation (eager, batched_mm,
grouped_mm) via environment variable when loading MoE models like
Gemma 4 MoE or Mixtral. Prevents crashes on GPUs that lack support
for specific expert computation backends.

* [fp8] Add MoE expert module support for FP8 weight storage

MoE models (Mixtral, Qwen MoE, GLM MoE, etc.) store expert weights as
3D nn.Parameter tensors, not nn.Linear. FP8StorageExperts wraps these
modules with the same compress/materialize lifecycle as FP8StorageLinear.

Expert detection is automatic via module naming convention (experts/expert)
plus verification of 3D parameter tensors. Gradient compression hooks
also cover expert parameters via unified _iter_fp8_params iterator.

* [fp8] Add FP8 storage mode to PT and DPO workflows

Same integration pattern as SFT: detect fp8+storage mode after model
load, convert eligible modules, and attach the lifecycle callback.

* [fp8] Add fused FP8Adafactor optimizer for memory-efficient training

FP8Adafactor reads fp8 weights and gradients directly, performs the
Adafactor update in fp32, and writes fp8 weights back -- all per-parameter.
Only one parameter is in fp32 at any time during the optimizer step.

Memory during optimizer step (31B model):
  Standard:  ~124 GB (all weights+grads in bf16)
  FP8 fused: ~66 GB (fp8 storage + one fp32 param at a time)

Also tags fp8-managed parameters with _fp8_ref for optimizer access,
and adds fused_optimizer mode to FP8StorageCallback to skip redundant
materialization/compression when the optimizer handles it directly.

Automatically activated when optim=adafactor + fp8=true + fp8_mode=storage.
Integrated into SFT, PT, and DPO workflows.

* [fp8] Add pure FP8 training mode with native scaled_mm (Ada/Hopper)

Pure mode uses torch._scaled_mm for hardware-accelerated fp8 matmul:
- Forward: input (e4m3) x weight (e4m3) via scaled_mm
- Backward dL/dX: grad_output (e5m2) x weight (e4m3)
- Backward dL/dW: input (e4m3) x grad_output (e5m2)

All three matmul operations use fp8 tensor cores for ~2x throughput
on Ada Lovelace (CC 8.9) and Hopper (CC 9.0) GPUs.

Gracefully falls back to storage mode on older GPUs (CC < 8.9).

Usage: fp8: true, fp8_mode: pure

* [fp8] Auto-detect GPU capability for pure vs storage mode selection

fp8_mode: auto now selects pure mode on Ada/Hopper (CC >= 8.9) and
storage mode on older GPUs. Also adds FP8PureLinear support to the
shared compress/materialize utilities.

* [fp8] Fix pure mode for 3D inputs and improve scaled_mm efficiency

Handle arbitrary batch dimensions (batch, seq, hidden) by flattening
to 2D for torch._scaled_mm which only supports 2D inputs. Pre-transpose
weight to contiguous layout to avoid repeated non-contiguous operations.

* [fp8] Fix ZeRO-3 compatibility: use module attributes for size checks

ZeRO-3 partitions weight tensors into 1D shards, so weight.numel()
returns the shard size instead of the full parameter count. This caused
all 412 linear layers to be skipped during FP8 conversion ("converted 0,
skipped 412"). Fix uses in_features * out_features (module attributes
that survive partitioning) for the numel and alignment checks.

Also adds on-the-fly fp8 quantization mode for ZeRO-3/FSDP: weights
stay as bf16 Parameters (partitioned by the distributed backend), and
are quantized to fp8 during forward after ZeRO-3 gathers them. This
gives ~2x matmul speedup from native fp8 tensor cores while ZeRO-3
handles memory sharding.

SFT workflow updated to detect ZeRO-3 and route to appropriate mode.

* [fp8] Fix ZeRO-3 compat for MoE experts and sync PT/DPO workflows

- MoE expert detection: use ds_shape (ZeRO-3 flattens params to 1D
  shards, so param.dim() returns 1 instead of 3). Falls back to
  param.shape when ds_shape is not present.

- Expert buffer creation: use ds_shape for buffer shape instead of
  the partitioned param.shape.

- PT and DPO workflows: sync with SFT's ZeRO-3-aware routing (skip
  storage mode with ZeRO-3, auto-detect via _detect_zero3).

* [fix] Revert create_optimizer signature for transformers 5.2 compat

* [fp8] Fix memory leak, scale dtype, cuBLASLt layout, and gradient flow

- Storage mode: re-compress weights after forward to free bf16 memory,
  preventing all layers from staying decompressed simultaneously (OOM fix)
- Pure mode: force fp32 scales in quantize functions (scaled_mm requirement)
- Pure mode: fix cuBLASLt layout - B operand must be column-major
- Pure mode: add weight_proxy parameter to _FP8MatmulFunction for gradient
  routing via STE (straight-through estimator), fixing broken grad flow

Bugs found during GH200 testing on Olivia.

* [fp8] Fix STE crash and missing gradient hooks in pure mode

- Remove buffered mode from FP8PureLinear entirely: compress() zeroed
  self.weight.data to empty(0), making STE gradient routing crash with
  shape mismatch. On-the-fly mode keeps bf16 weights intact.
- Add FP8PureLinear to _iter_fp8_params and _is_fp8_managed so gradient
  compression hooks cover pure mode params (prevents 52GB bf16 grad OOM).
- Install fp8 gradient hooks after convert_model_to_fp8_pure in all
  workflows (SFT, PT, DPO).

Bugs found during GH200 pure mode testing on Olivia.

* [fp8] Unify storage + native matmul: fp8 weights with scaled_mm on Ada/Hopper

FP8StorageLinear now auto-detects native fp8 matmul support (sm_89+) and
uses torch._scaled_mm directly on compressed fp8 buffers — no materialize
needed during forward. This gives both fp8 weight storage (1 byte/param)
AND native fp8 compute speedup in a single code path.

- Add use_native_fp8 flag to FP8StorageLinear, auto-detected in
  convert_model_to_fp8_storage
- Forward skips materialize when compressed + native fp8: runs _FP8MatmulFunction
  directly on fp8 buffers, gradients route through weight via STE
- Callback materializes weights before optimizer step, re-compresses after
- Remove fp8_mode training arg — single unified path auto-detects hardware
- Simplify all workflows (SFT/PT/DPO) to single convert_model_to_fp8_storage call

Memory for 26B MoE on GH200: ~26GB fp8 weights + 26GB fp8 grads = ~52GB
+ optimizer states. Previous pure mode needed 52GB bf16 weights.

* [fp8] Native fp8 checkpoint save/load for ecosystem compatibility

Override _save_to_state_dict and _load_from_state_dict on FP8StorageLinear
so checkpoints use standard keys with native fp8 dtype:

Save: _weight_fp8 → 'weight' (float8_e4m3fn), _weight_scale → 'weight_scale'
Load: detect fp8 dtype in 'weight' key, load directly with scale if present,
      fall back to bf16→fp8 quantization if scale missing

Handles three checkpoint formats transparently:
1. FP8-native: weight is float8 + weight_scale present → direct load
2. Compressed: weight is empty(0) + _weight_fp8 buffer → direct load
3. Standard bf16: weight is bf16 → load then compress

on_save is now a no-op since _save_to_state_dict handles everything.
Checkpoints are compatible with vLLM, transformers, and other engines.

* [fp8] Fix grad checkpointing crash: use backward hook for compress

compress() in forward() zeros self.weight.data before backward can use it
for grad_input = grad_output @ weight, causing size mismatch with gradient
checkpointing (which re-runs forward then backward).

Fix: register_full_backward_hook that compresses AFTER gradient computation.
Forward only materializes, backward uses bf16 weight, hook re-compresses.

Lifecycle with gradient checkpointing:
1. Forward: materialize fp8→bf16 (per layer)
2. Grad ckpt re-runs forward: materialize again
3. Backward: uses bf16 weight for grad computation
4. Backward hook: compress bf16→fp8, frees memory

Applied to both FP8StorageLinear and FP8StorageExperts.

* [fp8] Disable native fp8 matmul in storage mode (incompatible lifecycles)

STE in _FP8MatmulFunction needs non-empty self.weight for gradient
routing, but storage mode compress() zeros it to empty(0). These are
fundamentally incompatible. Storage mode uses F.linear with materialized
bf16 weights + backward hook to re-compress.

* [fp8] Fix grad checkpointing crash: use backward hook for compress

With gradient checkpointing, compress during initial forward (no_grad
context) to free bf16 memory per-layer. During backward re-run
(enable_grad), keep weight materialized for gradient computation;
backward hook re-compresses after grads are done.

Applies to both FP8StorageLinear and FP8StorageExperts.

Peak memory: ~60GB for 26B MoE (was 78GB+ OOM before).

* [fp8] Clean up dead native fp8 path and fix missing logger imports

- Remove use_native_fp8 flag from FP8StorageLinear (was always False,
  incompatible with storage mode's compress lifecycle)
- Remove dead _FP8MatmulFunction forward path that was never reached
- Remove FP8PureLinear from isinstance checks (not used in workflows)
- Add missing logger imports in pt/workflow.py and dpo/workflow.py

* [fp8] Native fp8 matmul for storage mode + unit tests

Enable native torch._scaled_mm on Ada/Hopper GPUs within FP8StorageLinear,
keeping the 1-byte/param compressed weight lifecycle. Key changes:

- FP8StorageLinear: add use_native_fp8 flag, forward branches to
  _FP8MatmulFunction when compressed + native supported
- _FP8MatmulFunction: accept module arg, manually route gradients to
  parameter (bypasses broken autograd fp8 accumulation), support fp8
  gradient compression hooks
- convert_model_to_fp8_storage: auto-detect GPU capability and ZeRO-3
- Add unit tests for quantize/dequantize, storage linear, pure linear,
  and native storage linear paths

---------

Co-authored-by: Kingsley <kingsleydodonow@gmail.com>
marksverdhei added a commit to NationalLibraryOfNorway/LLaMA-Factory that referenced this pull request Apr 9, 2026
* [model] gemma4 (hiyouga#10346)

* fix: gemma4 mm_token_type_ids padding (hiyouga#10359)

* [fix] Add token_type_ids for Gemma 4 text-only SFT training

Gemma 4 models require token_type_ids and mm_token_type_ids tensors in the
forward pass, but the data collator only creates these through the multimodal
plugin when images/videos/audio are present. For text-only SFT, the forward
pass crashes with: TypeError: roll(): argument input (position 1) must be
Tensor, not bool

Inject zero tensors for both fields when the model type is gemma4 or gemma4n
and the fields are not already present from the mm plugin.

* [fix] Gemma 4 text-only SFT training patches from Olivia benchmarks

Fixes discovered during Gemma 4 31B SFT benchmarking on GH200:

- mm_plugin: Skip media count validation when no media is provided, prevents
  crash when text data contains literal HTML <video>/<audio> tags
- mm_plugin: Add default for full_image_sequence attribute to prevent
  AttributeError with newer transformers
- visual: Wrap get_projector in try/except for text-only CausalLM mode
- visual: Disable gemma4 composite model registration (not needed for
  text-only training, causes errors when vision/audio towers absent)
- template: Remove thought_words and ReasoningTemplate from gemma4/gemma4n
  templates (causes StopIteration on non-thinking training data)
- template: Switch gemma4 mm_plugin from gemma4-specific to base plugin
  (avoids multimodal processor dependency in text-only mode)
- sft/trainer: Fix create_optimizer signature for transformers 5.5+ compat

* [fp8] Add FP8 Gemma 4 test configs and fix mm_token_type_ids packing

- Add FP8 smoke test configs for Gemma 4: TorchAO full FT, TorchAO LoRA,
  Transformer Engine, and BF16 baseline (using pruned 2-layer test model)
- Fix mm_token_type_ids not being unpacked in neat_packing + FA2 path,
  which would cause shape mismatches for Gemma 4 text-only training

* [fp8] Add backward compatibility for non-Ada GPUs

Auto-detect GPU compute capability and use TorchAO emulation mode
(emulate=True) on GPUs with compute < 8.9. This allows fp8: true
to work on any CUDA GPU without crashing:

- Native FP8: Ada Lovelace (RTX 40xx, L4) and Hopper (H100, GH200)
  use hardware FP8 kernels with rowwise scaling for full speedup
- Emulated FP8: Ampere and older (RTX 30xx, A100) use software
  FP8 quantization/dequantization — no throughput gain but the
  code path is exercised and weights are stored in FP8
- TE backend: falls back to TorchAO emulation on non-native hardware
  instead of crashing

Clear warnings are logged indicating native vs emulated mode.

* [fp8] Add true FP8 weight and gradient storage for memory-efficient training

Implements FP8StorageLinear with two complementary memory optimizations:

Weight storage (float8_e4m3fn): weights stored in 1 byte/param between
steps, decompressed to bf16 for forward/backward. Halves weight memory.

Gradient compression (float8_e5m2): gradients compressed to fp8 via
post-accumulate hooks as produced during backward. Supports gradient
accumulation via dequant-add-requant cycle. Halves gradient memory.

Combined: 2B/param (fp8) vs 4B/param (bf16) = 50% reduction.

Works on any CUDA GPU. No Ada/Hopper required. Adds fp8_mode training
arg with storage/pure/accelerate/auto modes. Integrates into SFT
workflow with FP8StorageCallback managing the compress/materialize
lifecycle around optimizer steps.

* [model] Support EXPERTS_IMPLEMENTATION env var for MoE model loading

Allows overriding the experts implementation (eager, batched_mm,
grouped_mm) via environment variable when loading MoE models like
Gemma 4 MoE or Mixtral. Prevents crashes on GPUs that lack support
for specific expert computation backends.

* [fp8] Add MoE expert module support for FP8 weight storage

MoE models (Mixtral, Qwen MoE, GLM MoE, etc.) store expert weights as
3D nn.Parameter tensors, not nn.Linear. FP8StorageExperts wraps these
modules with the same compress/materialize lifecycle as FP8StorageLinear.

Expert detection is automatic via module naming convention (experts/expert)
plus verification of 3D parameter tensors. Gradient compression hooks
also cover expert parameters via unified _iter_fp8_params iterator.

* [fp8] Add FP8 storage mode to PT and DPO workflows

Same integration pattern as SFT: detect fp8+storage mode after model
load, convert eligible modules, and attach the lifecycle callback.

* [fp8] Add fused FP8Adafactor optimizer for memory-efficient training

FP8Adafactor reads fp8 weights and gradients directly, performs the
Adafactor update in fp32, and writes fp8 weights back -- all per-parameter.
Only one parameter is in fp32 at any time during the optimizer step.

Memory during optimizer step (31B model):
  Standard:  ~124 GB (all weights+grads in bf16)
  FP8 fused: ~66 GB (fp8 storage + one fp32 param at a time)

Also tags fp8-managed parameters with _fp8_ref for optimizer access,
and adds fused_optimizer mode to FP8StorageCallback to skip redundant
materialization/compression when the optimizer handles it directly.

Automatically activated when optim=adafactor + fp8=true + fp8_mode=storage.
Integrated into SFT, PT, and DPO workflows.

* [fp8] Add pure FP8 training mode with native scaled_mm (Ada/Hopper)

Pure mode uses torch._scaled_mm for hardware-accelerated fp8 matmul:
- Forward: input (e4m3) x weight (e4m3) via scaled_mm
- Backward dL/dX: grad_output (e5m2) x weight (e4m3)
- Backward dL/dW: input (e4m3) x grad_output (e5m2)

All three matmul operations use fp8 tensor cores for ~2x throughput
on Ada Lovelace (CC 8.9) and Hopper (CC 9.0) GPUs.

Gracefully falls back to storage mode on older GPUs (CC < 8.9).

Usage: fp8: true, fp8_mode: pure

* [fp8] Auto-detect GPU capability for pure vs storage mode selection

fp8_mode: auto now selects pure mode on Ada/Hopper (CC >= 8.9) and
storage mode on older GPUs. Also adds FP8PureLinear support to the
shared compress/materialize utilities.

* [fp8] Fix pure mode for 3D inputs and improve scaled_mm efficiency

Handle arbitrary batch dimensions (batch, seq, hidden) by flattening
to 2D for torch._scaled_mm which only supports 2D inputs. Pre-transpose
weight to contiguous layout to avoid repeated non-contiguous operations.

* [fp8] Fix ZeRO-3 compatibility: use module attributes for size checks

ZeRO-3 partitions weight tensors into 1D shards, so weight.numel()
returns the shard size instead of the full parameter count. This caused
all 412 linear layers to be skipped during FP8 conversion ("converted 0,
skipped 412"). Fix uses in_features * out_features (module attributes
that survive partitioning) for the numel and alignment checks.

Also adds on-the-fly fp8 quantization mode for ZeRO-3/FSDP: weights
stay as bf16 Parameters (partitioned by the distributed backend), and
are quantized to fp8 during forward after ZeRO-3 gathers them. This
gives ~2x matmul speedup from native fp8 tensor cores while ZeRO-3
handles memory sharding.

SFT workflow updated to detect ZeRO-3 and route to appropriate mode.

* [fp8] Fix ZeRO-3 compat for MoE experts and sync PT/DPO workflows

- MoE expert detection: use ds_shape (ZeRO-3 flattens params to 1D
  shards, so param.dim() returns 1 instead of 3). Falls back to
  param.shape when ds_shape is not present.

- Expert buffer creation: use ds_shape for buffer shape instead of
  the partitioned param.shape.

- PT and DPO workflows: sync with SFT's ZeRO-3-aware routing (skip
  storage mode with ZeRO-3, auto-detect via _detect_zero3).

* [fix] Revert create_optimizer signature for transformers 5.2 compat

* [fp8] Fix memory leak, scale dtype, cuBLASLt layout, and gradient flow

- Storage mode: re-compress weights after forward to free bf16 memory,
  preventing all layers from staying decompressed simultaneously (OOM fix)
- Pure mode: force fp32 scales in quantize functions (scaled_mm requirement)
- Pure mode: fix cuBLASLt layout - B operand must be column-major
- Pure mode: add weight_proxy parameter to _FP8MatmulFunction for gradient
  routing via STE (straight-through estimator), fixing broken grad flow

Bugs found during GH200 testing on Olivia.

* [fp8] Fix STE crash and missing gradient hooks in pure mode

- Remove buffered mode from FP8PureLinear entirely: compress() zeroed
  self.weight.data to empty(0), making STE gradient routing crash with
  shape mismatch. On-the-fly mode keeps bf16 weights intact.
- Add FP8PureLinear to _iter_fp8_params and _is_fp8_managed so gradient
  compression hooks cover pure mode params (prevents 52GB bf16 grad OOM).
- Install fp8 gradient hooks after convert_model_to_fp8_pure in all
  workflows (SFT, PT, DPO).

Bugs found during GH200 pure mode testing on Olivia.

* [fp8] Unify storage + native matmul: fp8 weights with scaled_mm on Ada/Hopper

FP8StorageLinear now auto-detects native fp8 matmul support (sm_89+) and
uses torch._scaled_mm directly on compressed fp8 buffers — no materialize
needed during forward. This gives both fp8 weight storage (1 byte/param)
AND native fp8 compute speedup in a single code path.

- Add use_native_fp8 flag to FP8StorageLinear, auto-detected in
  convert_model_to_fp8_storage
- Forward skips materialize when compressed + native fp8: runs _FP8MatmulFunction
  directly on fp8 buffers, gradients route through weight via STE
- Callback materializes weights before optimizer step, re-compresses after
- Remove fp8_mode training arg — single unified path auto-detects hardware
- Simplify all workflows (SFT/PT/DPO) to single convert_model_to_fp8_storage call

Memory for 26B MoE on GH200: ~26GB fp8 weights + 26GB fp8 grads = ~52GB
+ optimizer states. Previous pure mode needed 52GB bf16 weights.

* [fp8] Native fp8 checkpoint save/load for ecosystem compatibility

Override _save_to_state_dict and _load_from_state_dict on FP8StorageLinear
so checkpoints use standard keys with native fp8 dtype:

Save: _weight_fp8 → 'weight' (float8_e4m3fn), _weight_scale → 'weight_scale'
Load: detect fp8 dtype in 'weight' key, load directly with scale if present,
      fall back to bf16→fp8 quantization if scale missing

Handles three checkpoint formats transparently:
1. FP8-native: weight is float8 + weight_scale present → direct load
2. Compressed: weight is empty(0) + _weight_fp8 buffer → direct load
3. Standard bf16: weight is bf16 → load then compress

on_save is now a no-op since _save_to_state_dict handles everything.
Checkpoints are compatible with vLLM, transformers, and other engines.

* [fp8] Fix grad checkpointing crash: use backward hook for compress

compress() in forward() zeros self.weight.data before backward can use it
for grad_input = grad_output @ weight, causing size mismatch with gradient
checkpointing (which re-runs forward then backward).

Fix: register_full_backward_hook that compresses AFTER gradient computation.
Forward only materializes, backward uses bf16 weight, hook re-compresses.

Lifecycle with gradient checkpointing:
1. Forward: materialize fp8→bf16 (per layer)
2. Grad ckpt re-runs forward: materialize again
3. Backward: uses bf16 weight for grad computation
4. Backward hook: compress bf16→fp8, frees memory

Applied to both FP8StorageLinear and FP8StorageExperts.

* [fp8] Disable native fp8 matmul in storage mode (incompatible lifecycles)

STE in _FP8MatmulFunction needs non-empty self.weight for gradient
routing, but storage mode compress() zeros it to empty(0). These are
fundamentally incompatible. Storage mode uses F.linear with materialized
bf16 weights + backward hook to re-compress.

* [fp8] Fix grad checkpointing crash: use backward hook for compress

With gradient checkpointing, compress during initial forward (no_grad
context) to free bf16 memory per-layer. During backward re-run
(enable_grad), keep weight materialized for gradient computation;
backward hook re-compresses after grads are done.

Applies to both FP8StorageLinear and FP8StorageExperts.

Peak memory: ~60GB for 26B MoE (was 78GB+ OOM before).

* [fp8] Clean up dead native fp8 path and fix missing logger imports

- Remove use_native_fp8 flag from FP8StorageLinear (was always False,
  incompatible with storage mode's compress lifecycle)
- Remove dead _FP8MatmulFunction forward path that was never reached
- Remove FP8PureLinear from isinstance checks (not used in workflows)
- Add missing logger imports in pt/workflow.py and dpo/workflow.py

* [fp8] Native fp8 matmul for storage mode + unit tests

Enable native torch._scaled_mm on Ada/Hopper GPUs within FP8StorageLinear,
keeping the 1-byte/param compressed weight lifecycle. Key changes:

- FP8StorageLinear: add use_native_fp8 flag, forward branches to
  _FP8MatmulFunction when compressed + native supported
- _FP8MatmulFunction: accept module arg, manually route gradients to
  parameter (bypasses broken autograd fp8 accumulation), support fp8
  gradient compression hooks
- convert_model_to_fp8_storage: auto-detect GPU capability and ZeRO-3
- Add unit tests for quantize/dequantize, storage linear, pure linear,
  and native storage linear paths

---------

Co-authored-by: Kingsley <kingsleydodonow@gmail.com>
michael-beebe pushed a commit to michael-beebe/LlamaFactory that referenced this pull request Apr 28, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

solved This problem has been already solved

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Will gemma4 be supported soon?

5 participants