Gemma4 fixes and profiler #3591

Merged: winglian merged 4 commits into main from misc-fixes-20260409 on Apr 10, 2026

@winglian
Collaborator

@winglian winglian commented Apr 10, 2026

Summary by CodeRabbit

Release Notes

  • New Features

    • Added alternative installation method using uv package manager
    • Enabled Vision-Language Model (VLM) training support with configuration guidance
    • Added training profiling and performance analysis capabilities
    • Introduced example configurations for Gemma-4 and Qwen3.5 multimodal fine-tuning
    • Added flash-linear attention as optional dependency
  • Documentation

    • Expanded installation documentation with uv-based setup
    • Added VLM usage guide with decision trees for model selection
    • Documented profiling workflow and expected outputs
    • Included Gemma-4 multimodal training configuration examples
  • Performance

    • Optimized Gemma-4 training computation
    • Added Liger kernel optimizations for Gemma-4

@winglian winglian requested a review from NanoCode012 April 10, 2026 06:08
@coderabbitai
Contributor

coderabbitai Bot commented Apr 10, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 794c9ace-8dda-42cd-820e-2234672b5f38

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Walkthrough

This PR introduces Gemma-4 Vision Language Model training support with unified profiling capabilities. Changes include new uv-based installation instructions, VLM training documentation and configuration examples, a comprehensive profiling analysis tool, trainer modifications for Gemma-4 loss computation, and Liger kernel optimizations for Gemma-4 components including LoRA support.

Changes

| Cohort / File(s) | Summary |
|---|---|
| **Installation & Setup**: `README.md`, `setup.py` | Added uv-based installation subsection with torch/torchvision/axolotl/flash-attn setup; added flash-linear extra for dependencies. |
| **Documentation**: `docs/agents/model_architectures.md`, `docs/agents/sft.md`, `docs/multimodal.qmd` | Added VLM baseline configuration guidance, decision tree, plugins/optimizations (Cut Cross Entropy, ScatterMoE), Gemma-4 VLM training sections; introduced profiling config fields and trace analysis guidance. |
| **Configuration Examples**: `examples/gemma4/e2b-vision-lora.yaml`, `examples/qwen3.5/35b-a3b-moe-vision-lora.yaml` | New Gemma-4 and Qwen 3.5 vision LoRA training configurations with multimodal setup, MoE-specific kernel flags, and expert-LoRA targeting. |
| **Profiling Infrastructure**: `scripts/analyze_profile.py` | New 1483-line profiling analysis script supporting trace and memory snapshot parsing, CPU overhead detection, fragmentation diagnosis, A/B comparison, and OOM risk assessment. |
| **Training Infrastructure**: `src/axolotl/core/trainers/base.py` | Added Gemma-4-specific loss computation in `AxolotlTrainer.compute_loss` that extracts `vocab_size` and calls `unwrapped.loss_function`. |
| **Liger Integration**: `src/axolotl/integrations/liger/plugin.py` | Added Gemma-4 monkey-patching in `pre_model_load` for RMSNorm, MLP, LayerNorm, and CrossEntropyLoss with Liger optimized kernels. |
| **LoRA Kernels**: `src/axolotl/monkeypatch/lora_kernels.py` | Added Gemma-4 QKV patch variant supporting new `shared_kv_states` parameter and updated KV-sharing logic for transformers >= 5.6. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested labels

under review

Suggested reviewers

  • NanoCode012
  • djsaunde
  • SalmanMohammadi
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 40.63% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title captures the two primary focuses of this pull request: Gemma4-specific fixes and the addition of profiling capabilities. |

@github-actions
Contributor

github-actions Bot commented Apr 10, 2026

📖 Documentation Preview: https://69d92b487955924ee94d243c--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit a2a064b

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 8

🧹 Nitpick comments (2)
examples/gemma4/e2b-vision-lora.yaml (1)

3-4: Clarify the top comment to match actual training scope.

Current text says “Vision fine-tuning,” but this config freezes multimodal modules and tunes LM LoRA adapters. A clearer description will prevent misinterpretation.

✏️ Proposed wording
-# Vision fine-tuning of the multimodal Gemma4 model.
-# Uses the base ProcessingStrategy (auto-detects image_token from processor).
+# Multimodal instruction tuning on image-text data with Gemma4.
+# Vision/audio modules remain frozen; LoRA is applied to language-model layers.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemma4/e2b-vision-lora.yaml` around lines 3 - 4, Update the top
comment to accurately describe the training scope: replace "Vision fine-tuning
of the multimodal Gemma4 model." with a line stating that multimodal/vision
encoders are frozen and only LM LoRA adapters are being trained (e.g.,
"Fine-tuning LM LoRA adapters on multimodal Gemma4 with vision/multimodal
modules frozen; ProcessingStrategy auto-detects image_token from processor.").
Ensure the second comment about ProcessingStrategy remains unchanged.
docs/agents/model_architectures.md (1)

142-143: Soften the loss-convergence claim to avoid over-promising.

The current wording is very deterministic; suggest framing it as typical/observed behavior to reduce user confusion when runs differ.

✏️ Proposed wording
-Starting VLM loss of ~8-15 is **expected** (not a bug). The model converges to <1.0 within 30-50 steps.
+Starting VLM loss of ~8-15 is **expected** (not a bug). In typical runs, loss often drops below 1.0 within ~30-50 steps.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/agents/model_architectures.md` around lines 142 - 143, Update the
deterministic sentence "Starting VLM loss of ~8-15 is **expected** (not a bug).
The model converges to <1.0 within 30-50 steps." to a softer, empirical
phrasing: say that a starting VLM loss of ~8-15 is typical/commonly observed and
that many runs converge below 1.0 within ~30–50 steps, but results can vary
across runs and setups; replace the exact wording in model_architectures.md to
reflect this less absolute framing.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@docs/agents/model_architectures.md`:
- Around line 16-27: The fenced decision-tree block in
docs/agents/model_architectures.md lacks a language specifier causing
markdownlint MD040; update the opening fence to include a language identifier
(e.g., use "text") so the block becomes ```text and keep the existing contents
unchanged to satisfy the linter and preserve formatting.

In `@docs/agents/sft.md`:
- Around line 104-110: The example command only points at the trace file, which
prevents the analyzer from also loading the memory snapshot; update the docs so
the example calls the analyzer with the run directory (e.g., pass output_dir/
instead of output_dir/profiler_trace.json) so scripts/analyze_profile.py can
locate both profiler_trace.json and snapshot.pickle when inspecting a run.

In `@docs/multimodal.qmd`:
- Line 173: Update the sentence that currently states Gemma4 auto-sets
use_reentrant=False and ddp_find_unused_parameters=True to include the exception
for activation_offloading; explicitly note that when activation_offloading:
true, the ddp_find_unused_parameters setting is skipped (so the auto-detected
ddp_find_unused_parameters=True does not apply), and keep the rest of the
guidance about FSDP2 and fsdp_transformer_layer_cls_to_wrap:
Gemma4TextDecoderLayer unchanged; ensure you reference Gemma4, use_reentrant,
ddp_find_unused_parameters, and activation_offloading in the updated wording so
the caveat is clear.

In `@README.md`:
- Around line 98-124: The README has inconsistent Python versions: the
prerequisites state Python 3.11 while the "Using uv (recommended)" section
hardcodes 3.12 in the commands (uv venv --python 3.12, uv run --python 3.12);
update the text to clarify whether 3.12 is required or only preferred and make
the commands match the quick-start requirement—either change the prerequisites
to 3.12 or replace both uv commands to use 3.11 and add a parenthetical note
like "or 3.12 preferred" if appropriate so all references (uv venv --python, uv
run --python) are consistent with the declared required version.

In `@scripts/analyze_profile.py`:
- Around line 174-188: The code hardcodes n_steps_profiled=3 and only filters
cuda_events for the warmup cutoff, causing ms/step, GPU utilization and
recommendations to be wrong; update the logic to derive n_steps_profiled from
the actual profiler_steps (or compute it from the timestamps/count of steps) and
whenever you apply the warmup cutoff to cuda_events also apply the same cutoff
to wall_clock_us and any other CPU/time aggregates by recomputing min_ts/max_ts
and total_span from the filtered events (or by excluding events with ts <=
cutoff_ts), and decrement/adjust n_steps_profiled consistently (same change
needed in the later block around the other cutoff at lines 223-237); ensure all
reported per-step metrics use the filtered timestamps/span and the correct
n_steps_profiled variable (referencing n_steps_profiled, cuda_events,
wall_clock_us, cutoff_ts).
- Around line 736-742: The load_snapshot function currently calls pickle.load()
unsafely; add a clear docstring to load_snapshot(path) that states snapshots
must only be loaded from trusted sources and that deserializing untrusted
snapshot.pickle files can execute arbitrary code, and update the CLI help text
for the --path and --compare options to include a short security warning
pointing users to only use trusted snapshot files or to use alternative safe
formats; ensure references mention the load_snapshot function and the --path and
--compare CLI flags so reviewers can locate the changes.

In `@setup.py`:
- Around line 170-173: The "flash-linear" extra in setup.py currently always
re-adds "flash-linear-attention" and "causal-conv1d", undoing the platform
filtering done by parse_requirements(); update the extras_require entry for the
"flash-linear" key so it conditionally includes those packages only on supported
platforms (same check/masking logic used by parse_requirements()), i.e., ensure
the "flash-linear" list is built/filtered using the same platform guard or
helper that removes "flash-linear-attention" on ARM64/macOS so
axolotl[flash-linear] cannot reintroduce unsupported deps.

In `@src/axolotl/integrations/liger/plugin.py`:
- Around line 225-270: The Gemma4 branch currently ignores
cfg.liger_fused_linear_cross_entropy; add explicit handling inside the gemma4
branch (where cfg.liger_rms_norm, cfg.liger_glu_activation, cfg.liger_rope,
cfg.liger_layer_norm, cfg.liger_cross_entropy are processed) to either apply the
fused linear cross-entropy replacement on modeling_gemma4 (e.g., set
modeling_gemma4.nn.FusedLinearCrossEntropy = <appropriate Liger FLCE class> or
modeling_gemma4.nn.CrossEntropyLoss = LigerFusedLinearCrossEntropy if that’s the
intended symbol) or, if FLCE is unsupported for Gemma4, emit a clear LOG.warning
like the RoPE case stating "Liger fused linear cross entropy is not compatible
with Gemma4; skipping." Reference cfg.liger_fused_linear_cross_entropy,
modeling_gemma4, LigerCrossEntropyLoss (or LigerFusedLinearCrossEntropy), and
LOG.warning to locate where to add this check.
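
The warning branch of this suggestion can be sketched as follows. This is a minimal illustration and not the plugin's code: it assumes fused linear cross entropy turns out to be unsupported for Gemma4 (the review leaves that open), and the helper name `check_gemma4_liger_flags` is hypothetical. Only the config field names and the RoPE incompatibility come from the review text.

```python
import logging
from types import SimpleNamespace
from typing import List

LOG = logging.getLogger(__name__)

GEMMA4_TYPES = ("gemma4", "gemma4_text")


def check_gemma4_liger_flags(cfg) -> List[str]:
    """Return the Liger flags that would be skipped for Gemma4, warning on each.

    Hypothetical helper: it assumes FLCE is unsupported for Gemma4, which the
    review comment does not confirm.
    """
    skipped: List[str] = []
    if getattr(cfg, "model_config_type", None) not in GEMMA4_TYPES:
        return skipped
    if getattr(cfg, "liger_rope", False):
        LOG.warning("Liger RoPE is not compatible with Gemma4; skipping.")
        skipped.append("liger_rope")
    if getattr(cfg, "liger_fused_linear_cross_entropy", False):
        LOG.warning(
            "Liger fused linear cross entropy is not compatible with Gemma4; skipping."
        )
        skipped.append("liger_fused_linear_cross_entropy")
    return skipped


# Example: a Gemma4 config that requests FLCE gets an explicit warning
# instead of the flag being silently ignored.
example_cfg = SimpleNamespace(
    model_config_type="gemma4",
    liger_fused_linear_cross_entropy=True,
)
example_skipped = check_gemma4_liger_flags(example_cfg)
```

Returning the skipped flag names keeps the behavior easy to assert in a unit test, which a log-only approach does not.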

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: d98133ef-6a8d-445f-91fe-73173e3c737d

📥 Commits

Reviewing files that changed from the base of the PR and between 4ef608d and c09bb74.

📒 Files selected for processing (11)
  • README.md
  • docs/agents/model_architectures.md
  • docs/agents/sft.md
  • docs/multimodal.qmd
  • examples/gemma4/e2b-vision-lora.yaml
  • examples/qwen3.5/35b-a3b-moe-vision-lora.yaml
  • scripts/analyze_profile.py
  • setup.py
  • src/axolotl/core/trainers/base.py
  • src/axolotl/integrations/liger/plugin.py
  • src/axolotl/monkeypatch/lora_kernels.py

Comment thread docs/agents/model_architectures.md Outdated
Comment on lines +16 to +27
```
Is the model multimodal (has vision/audio encoder)?
  ├─ YES: Add `freeze_mm_modules: true` if training text only
  │       Add `chat_template: <model_template>` (e.g. gemma4, qwen3_5, gemma3)
  │       LoRA: use regex `lora_target_modules` to restrict to language model
  └─ NO: Train as a regular text model

Is the model MoE (e.g. Gemma4 26B-A4B, Qwen3.5 35B-A3B)?
  ├─ YES: Add `lora_target_parameters` for expert LoRA
  │       Consider ScatterMoE kernels (see Plugins section)
  └─ NO: Standard LoRA config
```

⚠️ Potential issue | 🟡 Minor

Add a language identifier to the decision-tree fenced block.

Line 16 uses an unlabeled fenced block, which triggers markdownlint MD040.

✏️ Proposed fix
-```
+```text
 Is the model multimodal (has vision/audio encoder)?
   ├─ YES: Add `freeze_mm_modules: true` if training text only
   │       Add `chat_template: <model_template>` (e.g. gemma4, qwen3_5, gemma3)
   │       LoRA: use regex `lora_target_modules` to restrict to language model
   └─ NO: Train as a regular text model
@@
 Is the model MoE (e.g. Gemma4 26B-A4B, Qwen3.5 35B-A3B)?
   ├─ YES: Add `lora_target_parameters` for expert LoRA
   │       Consider ScatterMoE kernels (see Plugins section)
   └─ NO: Standard LoRA config

🧰 Tools
🪛 markdownlint-cli2 (0.22.0)

[warning] 16-16: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/agents/model_architectures.md` around lines 16 - 27, The fenced
decision-tree block in docs/agents/model_architectures.md lacks a language
specifier causing markdownlint MD040; update the opening fence to include a
language identifier (e.g., use "text") so the block becomes ```text and keep the
existing contents unchanged to satisfy the linter and preserve formatting.

Comment thread docs/agents/sft.md
Comment on lines +104 to +110
This produces `profiler_trace.json` (Chrome trace) and `snapshot.pickle` (memory snapshot) in `output_dir`.
View the Chrome trace at `chrome://tracing`.

To programmatically inspect the trace:
```bash
python scripts/analyze_profile.py output_dir/profiler_trace.json
```

⚠️ Potential issue | 🟡 Minor

Point the example at the run directory if you want memory analysis too.

Passing output_dir/profiler_trace.json only analyzes the trace. Using output_dir/ matches the script's own usage examples and lets it load snapshot.pickle as well.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/agents/sft.md` around lines 104 - 110, The example command only points
at the trace file, which prevents the analyzer from also loading the memory
snapshot; update the docs so the example calls the analyzer with the run
directory (e.g., pass output_dir/ instead of output_dir/profiler_trace.json) so
scripts/analyze_profile.py can locate both profiler_trace.json and
snapshot.pickle when inspecting a run.
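
The directory-versus-file behavior this comment relies on can be sketched as below. It is an illustration only: the helper name `resolve_artifacts` is hypothetical and not part of `scripts/analyze_profile.py`. The two file names are the ones the docs say are written to `output_dir`.

```python
from pathlib import Path
from typing import Optional, Tuple

TRACE_NAME = "profiler_trace.json"
SNAPSHOT_NAME = "snapshot.pickle"


def resolve_artifacts(path: str) -> Tuple[Optional[Path], Optional[Path]]:
    """Return (trace, snapshot) paths that exist for a run directory or a single file."""
    p = Path(path)
    if p.is_dir():
        # A run directory can yield both artifacts.
        trace, snapshot = p / TRACE_NAME, p / SNAPSHOT_NAME
    elif p.name == SNAPSHOT_NAME:
        trace, snapshot = None, p
    else:
        # A single trace file gives the analyzer nothing to say about memory.
        trace, snapshot = p, None
    trace = trace if trace is not None and trace.exists() else None
    snapshot = snapshot if snapshot is not None and snapshot.exists() else None
    return trace, snapshot
```

Called with `output_dir/`, the helper finds both files. Called with `output_dir/profiler_trace.json`, the snapshot slot stays `None`, which is the gap the comment points out.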

Comment thread docs/multimodal.qmd Outdated
:::

::: {.callout-tip}
For DDP training, axolotl auto-detects Gemma4 and sets `use_reentrant=False` and `ddp_find_unused_parameters=True`. For FSDP2, use `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer`.

⚠️ Potential issue | 🟡 Minor

Document the activation_offloading exception for ddp_find_unused_parameters.

Line 173 reads as unconditional, but Gemma4 guidance elsewhere notes this is skipped when activation_offloading: true. Add that caveat here for consistency.

✏️ Proposed fix
-For DDP training, axolotl auto-detects Gemma4 and sets `use_reentrant=False` and `ddp_find_unused_parameters=True`. For FSDP2, use `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer`.
+For DDP training, axolotl auto-detects Gemma4 and sets `use_reentrant=False` and `ddp_find_unused_parameters=True` (except when `activation_offloading: true`, where `find_unused_parameters` is skipped). For FSDP2, use `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer`.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/multimodal.qmd` at line 173, Update the sentence that currently states
Gemma4 auto-sets use_reentrant=False and ddp_find_unused_parameters=True to
include the exception for activation_offloading; explicitly note that when
activation_offloading: true, the ddp_find_unused_parameters setting is skipped
(so the auto-detected ddp_find_unused_parameters=True does not apply), and keep
the rest of the guidance about FSDP2 and fsdp_transformer_layer_cls_to_wrap:
Gemma4TextDecoderLayer unchanged; ensure you reference Gemma4, use_reentrant,
ddp_find_unused_parameters, and activation_offloading in the updated wording so
the caveat is clear.
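
The rule described in the proposed wording can be written down as a small function. This sketch only encodes the documented behavior and is not axolotl's implementation. The function name `gemma4_ddp_overrides` is hypothetical, and the two model type strings are taken from the Liger plugin branch quoted later in this review.

```python
from typing import Any, Dict

GEMMA4_TYPES = ("gemma4", "gemma4_text")


def gemma4_ddp_overrides(
    model_config_type: str, activation_offloading: bool
) -> Dict[str, Any]:
    """Return the DDP-related settings the docs say are auto-applied for Gemma4."""
    overrides: Dict[str, Any] = {}
    if model_config_type not in GEMMA4_TYPES:
        return overrides
    overrides["use_reentrant"] = False
    # Documented exception: find_unused_parameters is skipped when
    # activation offloading is enabled.
    if not activation_offloading:
        overrides["ddp_find_unused_parameters"] = True
    return overrides
```

Written this way, the exception is a single branch, which is why the one-clause caveat in the docs is enough to describe it.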

Comment thread README.md
Comment on lines +98 to +124
#### Using uv (recommended)

```bash
# install uv if you don't already have it installed
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env

# CUDA 12.8.1 tends to have better package compatibility
export UV_TORCH_BACKEND=cu128

# create a new virtual environment
uv venv --python 3.12
source .venv/bin/activate

uv pip install torch==2.10.0 torchvision
uv pip install --no-build-isolation axolotl[deepspeed]

# recommended - install cut-cross-entropy
uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main"

# (optional) - prefetch flash-attn2 and causal-conv1d kernels
uv run --python 3.12 python -c "from kernels import get_kernel; get_kernel('kernels-community/flash-attn2'); get_kernel('kernels-community/causal-conv1d')"

# Download example axolotl configs, deepspeed configs
axolotl fetch examples
axolotl fetch deepspeed_configs # OPTIONAL
```

⚠️ Potential issue | 🟡 Minor

Keep the recommended Python version consistent with the quick-start requirements.

The prerequisites still say Python 3.11, but the new recommended path provisions 3.12 and hardcodes uv run --python 3.12. Please clarify whether 3.12 is required or just preferred so readers don't set up the wrong interpreter.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@README.md` around lines 98 - 124, The README has inconsistent Python
versions: the prerequisites state Python 3.11 while the "Using uv (recommended)"
section hardcodes 3.12 in the commands (uv venv --python 3.12, uv run --python
3.12); update the text to clarify whether 3.12 is required or only preferred and
make the commands match the quick-start requirement—either change the
prerequisites to 3.12 or replace both uv commands to use 3.11 and add a
parenthetical note like "or 3.12 preferred" if appropriate so all references (uv
venv --python, uv run --python) are consistent with the declared required
version.

Comment thread scripts/analyze_profile.py Outdated
Comment on lines +174 to +188
n_steps_profiled = 3 # typical profiler_steps value

if skip_warmup and len(cuda_events) > 1000:
    timestamps = sorted(set(float(ev.get("ts", 0)) for ev in cuda_events))
    min_ts, max_ts = timestamps[0], timestamps[-1]
    total_span = max_ts - min_ts

    # Step 0 is warmup (Triton compilation + autotune). It's typically
    # the slowest step by far. Use 45% of wall-clock as cutoff -- step 0
    # usually takes >50% of total time when it includes compilation.
    cutoff_ts = min_ts + total_span * 0.45
    before = len(cuda_events)
    cuda_events = [ev for ev in cuda_events if float(ev.get("ts", 0)) > cutoff_ts]
    n_steps_profiled -= 1 # now analyzing steps 1+
    print(f" Excluding step 0 (warmup): {before:,} -> {len(cuda_events):,} events")

⚠️ Potential issue | 🟠 Major

Per-step numbers are wrong for anything except the assumed 3-step trace.

n_steps_profiled is hardcoded to 3, and the warmup cutoff is only applied to cuda_events while wall_clock_us still spans all CPU ops. Any run with a different profiler_steps value—or the default warmup exclusion enabled—will print incorrect ms/step, GPU utilization, and summary recommendations.

Also applies to: 223-237

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@scripts/analyze_profile.py` around lines 174 - 188, The code hardcodes
n_steps_profiled=3 and only filters cuda_events for the warmup cutoff, causing
ms/step, GPU utilization and recommendations to be wrong; update the logic to
derive n_steps_profiled from the actual profiler_steps (or compute it from the
timestamps/count of steps) and whenever you apply the warmup cutoff to
cuda_events also apply the same cutoff to wall_clock_us and any other CPU/time
aggregates by recomputing min_ts/max_ts and total_span from the filtered events
(or by excluding events with ts <= cutoff_ts), and decrement/adjust
n_steps_profiled consistently (same change needed in the later block around the
other cutoff at lines 223-237); ensure all reported per-step metrics use the
filtered timestamps/span and the correct n_steps_profiled variable (referencing
n_steps_profiled, cuda_events, wall_clock_us, cutoff_ts).
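
One possible shape for the fix is sketched below, under stated assumptions: the step count is passed in by the caller (for example from the `profiler_steps` config value) instead of being hardcoded, and a single cutoff is applied to both the CUDA and CPU event lists before the span is recomputed. The function name `drop_warmup` and the simplified event dictionaries are hypothetical, and event durations are ignored for brevity. The 45% heuristic is the one from the quoted code.

```python
from typing import Dict, List, Tuple

Event = Dict[str, float]


def drop_warmup(
    cuda_events: List[Event],
    cpu_events: List[Event],
    n_steps: int,
    warmup_fraction: float = 0.45,
) -> Tuple[List[Event], List[Event], float, int]:
    """Apply one timestamp cutoff to both event lists.

    Returns (cuda_kept, cpu_kept, span_us, steps_remaining).
    """
    if n_steps < 2 or not cuda_events:
        # Nothing to exclude: report the span over everything that was recorded.
        all_ts = [ev["ts"] for ev in cuda_events + cpu_events]
        span = max(all_ts) - min(all_ts) if all_ts else 0.0
        return cuda_events, cpu_events, span, n_steps

    cuda_ts = [ev["ts"] for ev in cuda_events]
    cutoff = min(cuda_ts) + (max(cuda_ts) - min(cuda_ts)) * warmup_fraction

    # The same cutoff filters GPU and CPU events, so per-step numbers
    # are computed over the same window.
    cuda_kept = [ev for ev in cuda_events if ev["ts"] > cutoff]
    cpu_kept = [ev for ev in cpu_events if ev["ts"] > cutoff]

    kept_ts = [ev["ts"] for ev in cuda_kept + cpu_kept]
    span = max(kept_ts) - min(kept_ts) if kept_ts else 0.0
    return cuda_kept, cpu_kept, span, n_steps - 1
```

Dividing the returned span by the returned step count then gives a per-step time that refers to the same events as the GPU totals.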

Comment on lines +736 to +742
def load_snapshot(path):
    snap_file = Path(path) / "snapshot.pickle" if Path(path).is_dir() else Path(path)
    if not snap_file.exists():
        return None
    print(f"Loading {snap_file.name} ({snap_file.stat().st_size / 1e6:.0f} MB)...")
    with open(snap_file, "rb") as f:
        return pickle.load(f) # nosec B301

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, check the exact file and lines
head -n 750 scripts/analyze_profile.py | tail -n 20 | cat -n

Repository: axolotl-ai-cloud/axolotl

Length of output: 884


🏁 Script executed:

# Get the full context of the load_snapshot function and surrounding code
sed -n '730,750p' scripts/analyze_profile.py | cat -n

Repository: axolotl-ai-cloud/axolotl

Length of output: 901


🏁 Script executed:

# Search for all calls to load_snapshot to understand how the function is used
rg "load_snapshot" --type python -B 2 -A 2

Repository: axolotl-ai-cloud/axolotl

Length of output: 98


🏁 Script executed:

# Check if there's any documentation about security or trust boundaries
rg -i "trust|security|pickle" --type markdown --type txt

Repository: axolotl-ai-cloud/axolotl

Length of output: 1035


🏁 Script executed:

# Look for CLI argument handling related to snapshot paths
rg "snapshot|path" scripts/ --type python | head -30

Repository: axolotl-ai-cloud/axolotl

Length of output: 98


🏁 Script executed:

# Search for load_snapshot calls without type restriction
rg "load_snapshot" -B 2 -A 2

Repository: axolotl-ai-cloud/axolotl

Length of output: 888


🏁 Script executed:

# Check the documentation file that mentions snapshot.pickle
cat -n docs/agents/sft.md | grep -A 5 -B 5 "snapshot.pickle"

Repository: axolotl-ai-cloud/axolotl

Length of output: 513


🏁 Script executed:

# Look for CLI entry points or main functions
rg "def main|argparse|click" scripts/analyze_profile.py -B 2 -A 5

Repository: axolotl-ai-cloud/axolotl

Length of output: 533


🏁 Script executed:

# Check if there's a CLI interface or how analyze_profile is invoked
head -n 50 scripts/analyze_profile.py | cat -n

Repository: axolotl-ai-cloud/axolotl

Length of output: 2042


🏁 Script executed:

# Search for any command-line argument handling for snapshot path
rg "snapshot|path" scripts/analyze_profile.py | head -20

Repository: axolotl-ai-cloud/axolotl

Length of output: 1101


🏁 Script executed:

# Get the full argparse setup to see argument descriptions
sed -n '790,850p' scripts/analyze_profile.py | cat -n

Repository: axolotl-ai-cloud/axolotl

Length of output: 2868


🏁 Script executed:

# Search for any existing security-related comments or warnings in the file
rg -i "security|trust|unsafe|warning|untrusted" scripts/analyze_profile.py

Repository: axolotl-ai-cloud/axolotl

Length of output: 138


🏁 Script executed:

# Check if there's any validation of the path argument before use
sed -n '800,900p' scripts/analyze_profile.py | cat -n

Repository: axolotl-ai-cloud/axolotl

Length of output: 4513


🏁 Script executed:

# Get the complete argument parser definition
sed -n '810,860p' scripts/analyze_profile.py | cat -n

Repository: axolotl-ai-cloud/axolotl

Length of output: 2231


🏁 Script executed:

# Search for the main() function and how arguments are documented
rg "add_argument|help=" scripts/analyze_profile.py -A 1

Repository: axolotl-ai-cloud/axolotl

Length of output: 866


Add security guidance for loading snapshots from untrusted sources.

snapshot.pickle is deserialized using pickle.load(), which can execute arbitrary code. While this is safe for self-generated profiler artifacts, it's unsafe for snapshots shared from external or untrusted sources. The load_snapshot() function lacks documentation of this trust boundary, and neither the CLI help text nor the docs currently warn users about this risk.

Consider adding a docstring to load_snapshot() clarifying that snapshots should only be loaded from trusted sources, and updating the CLI help text for --path and --compare to note the security implications.

🧰 Tools
🪛 Ruff (0.15.9)

[error] 742-742: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue

(S301)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@scripts/analyze_profile.py` around lines 736 - 742, The load_snapshot
function currently calls pickle.load() unsafely; add a clear docstring to
load_snapshot(path) that states snapshots must only be loaded from trusted
sources and that deserializing untrusted snapshot.pickle files can execute
arbitrary code, and update the CLI help text for the --path and --compare
options to include a short security warning pointing users to only use trusted
snapshot files or to use alternative safe formats; ensure references mention the
load_snapshot function and the --path and --compare CLI flags so reviewers can
locate the changes.
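
A sketch of what the requested documentation could look like follows. The function body mirrors the quoted `load_snapshot` with the progress print omitted. The docstring wording, the `TRUST_NOTE` constant, and the `build_parser` helper are illustrative assumptions. The argument names follow the review comment, and treating `path` as positional is a guess based on the usage example in `docs/agents/sft.md`.

```python
import argparse
import pickle
from pathlib import Path

TRUST_NOTE = "Only load snapshot.pickle files that you generated or otherwise trust."


def load_snapshot(path):
    """Load a memory snapshot from a run directory or a snapshot file.

    Security: the file is read with pickle.load(), which can execute
    arbitrary code during deserialization. Only pass snapshots from
    trusted sources. Returns None when the file does not exist.
    """
    snap_file = Path(path) / "snapshot.pickle" if Path(path).is_dir() else Path(path)
    if not snap_file.exists():
        return None
    with open(snap_file, "rb") as f:
        return pickle.load(f)  # nosec B301 - trusted input only, see docstring


def build_parser() -> argparse.ArgumentParser:
    """Build a parser whose help text repeats the trust boundary (hypothetical layout)."""
    parser = argparse.ArgumentParser(description="Analyze profiler output")
    parser.add_argument("path", help=f"Run directory or artifact file. {TRUST_NOTE}")
    parser.add_argument("--compare", help=f"Second run to compare against. {TRUST_NOTE}")
    return parser
```

Repeating the note in the help text means the warning is visible from `--help` without reading the source.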

Comment thread setup.py Outdated
Comment on lines +170 to +173
    "flash-linear": [
        "flash-linear-attention",
        "causal-conv1d",
    ],

⚠️ Potential issue | 🟠 Major

Gate the flash-linear extra on the same unsupported platforms as the base deps.

parse_requirements() already strips flash-linear-attention from the core install on ARM64/macOS, but this extra re-adds it unconditionally. That means axolotl[flash-linear] can still fail on the exact platforms the base install is trying to protect.

Suggested fix
         if platform.machine() == "aarch64":
             # skip on ARM64
             skip_packages = [
                 "torchao",
                 "fla-core",
                 "flash-linear-attention",
             ]
+            extras_require_map.pop("flash-linear", None)
             _install_requires = [
                 req
                 for req in _install_requires
                 if re.split(r"[>=<]", req)[0].strip() not in skip_packages
             ]
         if "Darwin" in platform.system():
             # skip packages not compatible with OSX
             skip_packages = [
                 "bitsandbytes",
                 "triton",
                 "mamba-ssm",
                 "xformers",
                 "liger-kernel",
             ]
+            extras_require_map.pop("flash-linear", None)
             _install_requires = [
                 req
                 for req in _install_requires
                 if re.split(r"[>=<]", req)[0].strip() not in skip_packages
             ]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@setup.py` around lines 170 - 173, The "flash-linear" extra in setup.py
currently always re-adds "flash-linear-attention" and "causal-conv1d", undoing
the platform filtering done by parse_requirements(); update the extras_require
entry for the "flash-linear" key so it conditionally includes those packages
only on supported platforms (same check/masking logic used by
parse_requirements()), i.e., ensure the "flash-linear" list is built/filtered
using the same platform guard or helper that removes "flash-linear-attention" on
ARM64/macOS so axolotl[flash-linear] cannot reintroduce unsupported deps.
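
The gating idea can be sketched in isolation as below. `gate_extras` is a hypothetical helper, not a function from `setup.py`; it takes the platform values as optional parameters so the behaviour can be checked without running on ARM64 or macOS.

```python
import platform


def gate_extras(extras, machine=None, system=None):
    """Return a copy of `extras` without "flash-linear" on ARM64 or macOS.

    Hypothetical helper: mirrors the platform checks shown in the suggested
    fix (aarch64 machine, Darwin system) rather than the real setup.py code.
    """
    machine = platform.machine() if machine is None else machine
    system = platform.system() if system is None else system
    gated = {name: list(pkgs) for name, pkgs in extras.items()}
    if machine == "aarch64" or "Darwin" in system:
        # Drop the extra so it cannot reintroduce unsupported dependencies.
        gated.pop("flash-linear", None)
    return gated


extras = {"flash-linear": ["flash-linear-attention", "causal-conv1d"]}
on_arm = gate_extras(extras, machine="aarch64", system="Linux")
on_x86 = gate_extras(extras, machine="x86_64", system="Linux")
```

Returning a copy instead of mutating the input keeps the original mapping available for other platform checks.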

Comment on lines +225 to +270
elif cfg.model_config_type in ("gemma4", "gemma4_text"):
    # Gemma4: offset=0 (NOT 1 like Gemma3), in_place=False required for
    # gradient checkpointing compatibility, RoPE incompatible (separate q/k).
    from liger_kernel.transformers.geglu import LigerGEGLUMLP
    from transformers.models.gemma4 import modeling_gemma4

    if cfg.liger_rms_norm:
        _OrigGemma4RMSNorm = modeling_gemma4.Gemma4RMSNorm

        class _LigerGemma4RMSNorm(LigerRMSNorm):
            """LigerRMSNorm for Gemma4 with in_place=False and with_scale support."""

            def __new__(cls, dim, eps=1e-6, with_scale=True):
                if not with_scale:
                    return _OrigGemma4RMSNorm(dim, eps, with_scale=False)
                return super().__new__(cls)

            def __init__(self, dim, eps=1e-6, with_scale=True):
                if not with_scale:
                    return
                # offset=0.0 (standard), in_place=False (gradient checkpointing safe)
                super().__init__(
                    dim, eps, offset=0.0, casting_mode="llama", in_place=False
                )

        modeling_gemma4.Gemma4RMSNorm = _LigerGemma4RMSNorm
    if cfg.liger_glu_activation:

        class _LigerGemma4MLP(LigerGEGLUMLP):
            def __init__(self, config, layer_idx=None):
                super().__init__(config)

        modeling_gemma4.Gemma4TextMLP = _LigerGemma4MLP
    if cfg.liger_rope:
        LOG.warning(
            "Liger RoPE is not compatible with Gemma4 (separate q/k application). Skipping."
        )
    if cfg.liger_layer_norm:
        modeling_gemma4.nn.LayerNorm = LigerLayerNorm
    if cfg.liger_cross_entropy:
        modeling_gemma4.nn.CrossEntropyLoss = LigerCrossEntropyLoss
    LOG.info(
        f"Applied Liger kernels for gemma4: "
        f"rms_norm={cfg.liger_rms_norm}, glu={cfg.liger_glu_activation}, "
        f"rope=False (incompatible), layer_norm={cfg.liger_layer_norm}"
    )

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
rg -n -C3 'gemma4|fused_linear_cross_entropy|lce_forward' src/axolotl/integrations/liger

Repository: axolotl-ai-cloud/axolotl

Length of output: 40617


liger_fused_linear_cross_entropy is silently ignored for Gemma4.

The Gemma4 branch patches RMSNorm, GLU, LayerNorm, and CrossEntropy, but does not handle cfg.liger_fused_linear_cross_entropy. Unlike the RoPE incompatibility, which logs a warning, enabling this option for Gemma4 produces neither a warning nor a patch. The elif fallback at line 271 cannot apply a generic FLCE patch because the Gemma4 condition has already matched. The result is a silent config mismatch: the setting appears valid but has no effect.

Add explicit handling: either implement Gemma4 FLCE or log a warning that it is unsupported, consistent with the RoPE incompatibility message.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/integrations/liger/plugin.py` around lines 225 - 270, The Gemma4
branch currently ignores cfg.liger_fused_linear_cross_entropy; add explicit
handling inside the gemma4 branch (where cfg.liger_rms_norm,
cfg.liger_glu_activation, cfg.liger_rope, cfg.liger_layer_norm,
cfg.liger_cross_entropy are processed) to either apply the fused linear
cross-entropy replacement on modeling_gemma4 (e.g., set
modeling_gemma4.nn.FusedLinearCrossEntropy = <appropriate Liger FLCE class> or
modeling_gemma4.nn.CrossEntropyLoss = LigerFusedLinearCrossEntropy if that’s the
intended symbol) or, if FLCE is unsupported for Gemma4, emit a clear LOG.warning
like the RoPE case stating "Liger fused linear cross entropy is not compatible
with Gemma4; skipping." Reference cfg.liger_fused_linear_cross_entropy,
modeling_gemma4, LigerCrossEntropyLoss (or LigerFusedLinearCrossEntropy), and
LOG.warning to locate where to add this check.
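
As a sketch of the warning-only option, the check could look like the following. `warn_unsupported_gemma4_options` and the logger name are hypothetical, and `cfg` is stood in for by a `SimpleNamespace`. Whether FLCE is actually incompatible with Gemma4 is not established here; the sketch only shows how the option could be surfaced instead of being silently ignored.

```python
import logging
from types import SimpleNamespace

LOG = logging.getLogger("liger_gemma4_sketch")  # hypothetical logger name

# Options the Gemma4 branch skips, mapped to the message to log for each.
# The FLCE message follows the wording proposed in the review comment.
UNSUPPORTED_GEMMA4_OPTIONS = {
    "liger_rope": (
        "Liger RoPE is not compatible with Gemma4 (separate q/k application). Skipping."
    ),
    "liger_fused_linear_cross_entropy": (
        "Liger fused linear cross entropy is not compatible with Gemma4; skipping."
    ),
}


def warn_unsupported_gemma4_options(cfg):
    """Log and return one warning per enabled option that Gemma4 skips."""
    messages = [
        msg
        for key, msg in UNSUPPORTED_GEMMA4_OPTIONS.items()
        if getattr(cfg, key, False)
    ]
    for msg in messages:
        LOG.warning(msg)
    return messages


# Stand-in config with only FLCE enabled.
cfg = SimpleNamespace(liger_rope=False, liger_fused_linear_cross_entropy=True)
warnings_emitted = warn_unsupported_gemma4_options(cfg)
```

Keeping the skipped options in one mapping means a future incompatible kernel only needs a new entry rather than another `if` block.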

@codecov

codecov Bot commented Apr 10, 2026

Codecov Report

❌ Patch coverage is 2.56410% with 38 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/axolotl/integrations/liger/plugin.py | 0.00% | 29 Missing ⚠️ |
| src/axolotl/core/trainers/base.py | 10.00% | 9 Missing ⚠️ |


@winglian winglian added the scheduled_release This PR is slated for the upcoming release label Apr 10, 2026
@winglian winglian merged commit 29fa4de into main Apr 10, 2026
18 checks passed
@winglian winglian deleted the misc-fixes-20260409 branch April 10, 2026 20:46