Gemma4 fixes and profiler#3591
Conversation
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
📝 WalkthroughWalkthroughThis PR introduces Gemma-4 Vision Language Model training support with unified profiling capabilities. Changes include new uv-based installation instructions, VLM training documentation and configuration examples, a comprehensive profiling analysis tool, trainer modifications for Gemma-4 loss computation, and Liger kernel optimizations for Gemma-4 components including LoRA support. Changes
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
|
📖 Documentation Preview: https://69d92b487955924ee94d243c--resonant-treacle-0fd729.netlify.app Deployed on Netlify from commit a2a064b |
There was a problem hiding this comment.
Actionable comments posted: 8
🧹 Nitpick comments (2)
examples/gemma4/e2b-vision-lora.yaml (1)
3-4: Clarify the top comment to match actual training scope.Current text says “Vision fine-tuning,” but this config freezes multimodal modules and tunes LM LoRA adapters. A clearer description will prevent misinterpretation.
✏️ Proposed wording
-# Vision fine-tuning of the multimodal Gemma4 model. -# Uses the base ProcessingStrategy (auto-detects image_token from processor). +# Multimodal instruction tuning on image-text data with Gemma4. +# Vision/audio modules remain frozen; LoRA is applied to language-model layers.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/gemma4/e2b-vision-lora.yaml` around lines 3 - 4, Update the top comment to accurately describe the training scope: replace "Vision fine-tuning of the multimodal Gemma4 model." with a line stating that multimodal/vision encoders are frozen and only LM LoRA adapters are being trained (e.g., "Fine-tuning LM LoRA adapters on multimodal Gemma4 with vision/multimodal modules frozen; ProcessingStrategy auto-detects image_token from processor."). Ensure the second comment about ProcessingStrategy remains unchanged.docs/agents/model_architectures.md (1)
142-143: Soften the loss-convergence claim to avoid over-promising.The current wording is very deterministic; suggest framing it as typical/observed behavior to reduce user confusion when runs differ.
✏️ Proposed wording
-Starting VLM loss of ~8-15 is **expected** (not a bug). The model converges to <1.0 within 30-50 steps. +Starting VLM loss of ~8-15 is **expected** (not a bug). In typical runs, loss often drops below 1.0 within ~30-50 steps.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@docs/agents/model_architectures.md` around lines 142 - 143, Update the deterministic sentence "Starting VLM loss of ~8-15 is **expected** (not a bug). The model converges to <1.0 within 30-50 steps." to a softer, empirical phrasing: say that a starting VLM loss of ~8-15 is typical/commonly observed and that many runs converge below 1.0 within ~30–50 steps, but results can vary across runs and setups; replace the exact wording in model_architectures.md to reflect this less absolute framing.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@docs/agents/model_architectures.md`:
- Around line 16-27: The fenced decision-tree block in
docs/agents/model_architectures.md lacks a language specifier causing
markdownlint MD040; update the opening fence to include a language identifier
(e.g., use "text") so the block becomes ```text and keep the existing contents
unchanged to satisfy the linter and preserve formatting.
In `@docs/agents/sft.md`:
- Around line 104-110: The example command only points at the trace file, which
prevents the analyzer from also loading the memory snapshot; update the docs so
the example calls the analyzer with the run directory (e.g., pass output_dir/
instead of output_dir/profiler_trace.json) so scripts/analyze_profile.py can
locate both profiler_trace.json and snapshot.pickle when inspecting a run.
In `@docs/multimodal.qmd`:
- Line 173: Update the sentence that currently states Gemma4 auto-sets
use_reentrant=False and ddp_find_unused_parameters=True to include the exception
for activation_offloading; explicitly note that when activation_offloading:
true, the ddp_find_unused_parameters setting is skipped (so the auto-detected
ddp_find_unused_parameters=True does not apply), and keep the rest of the
guidance about FSDP2 and fsdp_transformer_layer_cls_to_wrap:
Gemma4TextDecoderLayer unchanged; ensure you reference Gemma4, use_reentrant,
ddp_find_unused_parameters, and activation_offloading in the updated wording so
the caveat is clear.
In `@README.md`:
- Around line 98-124: The README has inconsistent Python versions: the
prerequisites state Python 3.11 while the "Using uv (recommended)" section
hardcodes 3.12 in the commands (uv venv --python 3.12, uv run --python 3.12);
update the text to clarify whether 3.12 is required or only preferred and make
the commands match the quick-start requirement—either change the prerequisites
to 3.12 or replace both uv commands to use 3.11 and add a parenthetical note
like "or 3.12 preferred" if appropriate so all references (uv venv --python, uv
run --python) are consistent with the declared required version.
In `@scripts/analyze_profile.py`:
- Around line 174-188: The code hardcodes n_steps_profiled=3 and only filters
cuda_events for the warmup cutoff, causing ms/step, GPU utilization and
recommendations to be wrong; update the logic to derive n_steps_profiled from
the actual profiler_steps (or compute it from the timestamps/count of steps) and
whenever you apply the warmup cutoff to cuda_events also apply the same cutoff
to wall_clock_us and any other CPU/time aggregates by recomputing min_ts/max_ts
and total_span from the filtered events (or by excluding events with ts <=
cutoff_ts), and decrement/adjust n_steps_profiled consistently (same change
needed in the later block around the other cutoff at lines 223-237); ensure all
reported per-step metrics use the filtered timestamps/span and the correct
n_steps_profiled variable (referencing n_steps_profiled, cuda_events,
wall_clock_us, cutoff_ts).
- Around line 736-742: The load_snapshot function currently calls pickle.load()
unsafely; add a clear docstring to load_snapshot(path) that states snapshots
must only be loaded from trusted sources and that deserializing untrusted
snapshot.pickle files can execute arbitrary code, and update the CLI help text
for the --path and --compare options to include a short security warning
pointing users to only use trusted snapshot files or to use alternative safe
formats; ensure references mention the load_snapshot function and the --path and
--compare CLI flags so reviewers can locate the changes.
In `@setup.py`:
- Around line 170-173: The "flash-linear" extra in setup.py currently always
re-adds "flash-linear-attention" and "causal-conv1d", undoing the platform
filtering done by parse_requirements(); update the extras_require entry for the
"flash-linear" key so it conditionally includes those packages only on supported
platforms (same check/masking logic used by parse_requirements()), i.e., ensure
the "flash-linear" list is built/filtered using the same platform guard or
helper that removes "flash-linear-attention" on ARM64/macOS so
axolotl[flash-linear] cannot reintroduce unsupported deps.
In `@src/axolotl/integrations/liger/plugin.py`:
- Around line 225-270: The Gemma4 branch currently ignores
cfg.liger_fused_linear_cross_entropy; add explicit handling inside the gemma4
branch (where cfg.liger_rms_norm, cfg.liger_glu_activation, cfg.liger_rope,
cfg.liger_layer_norm, cfg.liger_cross_entropy are processed) to either apply the
fused linear cross-entropy replacement on modeling_gemma4 (e.g., set
modeling_gemma4.nn.FusedLinearCrossEntropy = <appropriate Liger FLCE class> or
modeling_gemma4.nn.CrossEntropyLoss = LigerFusedLinearCrossEntropy if that’s the
intended symbol) or, if FLCE is unsupported for Gemma4, emit a clear LOG.warning
like the RoPE case stating "Liger fused linear cross entropy is not compatible
with Gemma4; skipping." Reference cfg.liger_fused_linear_cross_entropy,
modeling_gemma4, LigerCrossEntropyLoss (or LigerFusedLinearCrossEntropy), and
LOG.warning to locate where to add this check.
---
Nitpick comments:
In `@docs/agents/model_architectures.md`:
- Around line 142-143: Update the deterministic sentence "Starting VLM loss of
~8-15 is **expected** (not a bug). The model converges to <1.0 within 30-50
steps." to a softer, empirical phrasing: say that a starting VLM loss of ~8-15
is typical/commonly observed and that many runs converge below 1.0 within ~30–50
steps, but results can vary across runs and setups; replace the exact wording in
model_architectures.md to reflect this less absolute framing.
In `@examples/gemma4/e2b-vision-lora.yaml`:
- Around line 3-4: Update the top comment to accurately describe the training
scope: replace "Vision fine-tuning of the multimodal Gemma4 model." with a line
stating that multimodal/vision encoders are frozen and only LM LoRA adapters are
being trained (e.g., "Fine-tuning LM LoRA adapters on multimodal Gemma4 with
vision/multimodal modules frozen; ProcessingStrategy auto-detects image_token
from processor."). Ensure the second comment about ProcessingStrategy remains
unchanged.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: d98133ef-6a8d-445f-91fe-73173e3c737d
📒 Files selected for processing (11)
README.mddocs/agents/model_architectures.mddocs/agents/sft.mddocs/multimodal.qmdexamples/gemma4/e2b-vision-lora.yamlexamples/qwen3.5/35b-a3b-moe-vision-lora.yamlscripts/analyze_profile.pysetup.pysrc/axolotl/core/trainers/base.pysrc/axolotl/integrations/liger/plugin.pysrc/axolotl/monkeypatch/lora_kernels.py
| ``` | ||
| Is the model multimodal (has vision/audio encoder)? | ||
| ├─ YES: Add `freeze_mm_modules: true` if training text only | ||
| │ Add `chat_template: <model_template>` (e.g. gemma4, qwen3_5, gemma3) | ||
| │ LoRA: use regex `lora_target_modules` to restrict to language model | ||
| └─ NO: Train as a regular text model | ||
|
|
||
| Is the model MoE (e.g. Gemma4 26B-A4B, Qwen3.5 35B-A3B)? | ||
| ├─ YES: Add `lora_target_parameters` for expert LoRA | ||
| │ Consider ScatterMoE kernels (see Plugins section) | ||
| └─ NO: Standard LoRA config | ||
| ``` |
There was a problem hiding this comment.
Add a language identifier to the decision-tree fenced block.
Line 16 uses an unlabeled fenced block, which triggers markdownlint MD040.
✏️ Proposed fix
-```
+```text
Is the model multimodal (has vision/audio encoder)?
├─ YES: Add `freeze_mm_modules: true` if training text only
│ Add `chat_template: <model_template>` (e.g. gemma4, qwen3_5, gemma3)
│ LoRA: use regex `lora_target_modules` to restrict to language model
└─ NO: Train as a regular text model
@@
Is the model MoE (e.g. Gemma4 26B-A4B, Qwen3.5 35B-A3B)?
├─ YES: Add `lora_target_parameters` for expert LoRA
│ Consider ScatterMoE kernels (see Plugins section)
└─ NO: Standard LoRA config</details>
<!-- suggestion_start -->
<details>
<summary>📝 Committable suggestion</summary>
> ‼️ **IMPORTANT**
> Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```suggestion
🧰 Tools
🪛 markdownlint-cli2 (0.22.0)
[warning] 16-16: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/agents/model_architectures.md` around lines 16 - 27, The fenced
decision-tree block in docs/agents/model_architectures.md lacks a language
specifier causing markdownlint MD040; update the opening fence to include a
language identifier (e.g., use "text") so the block becomes ```text and keep the
existing contents unchanged to satisfy the linter and preserve formatting.
| This produces `profiler_trace.json` (Chrome trace) and `snapshot.pickle` (memory snapshot) in `output_dir`. | ||
| View the Chrome trace at `chrome://tracing`. | ||
|
|
||
| To programmatically inspect the trace: | ||
| ```bash | ||
| python scripts/analyze_profile.py output_dir/profiler_trace.json | ||
| ``` |
There was a problem hiding this comment.
Point the example at the run directory if you want memory analysis too.
Passing output_dir/profiler_trace.json only analyzes the trace. Using output_dir/ matches the script's own usage examples and lets it load snapshot.pickle as well.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/agents/sft.md` around lines 104 - 110, The example command only points
at the trace file, which prevents the analyzer from also loading the memory
snapshot; update the docs so the example calls the analyzer with the run
directory (e.g., pass output_dir/ instead of output_dir/profiler_trace.json) so
scripts/analyze_profile.py can locate both profiler_trace.json and
snapshot.pickle when inspecting a run.
| ::: | ||
|
|
||
| ::: {.callout-tip} | ||
| For DDP training, axolotl auto-detects Gemma4 and sets `use_reentrant=False` and `ddp_find_unused_parameters=True`. For FSDP2, use `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer`. |
There was a problem hiding this comment.
Document the activation_offloading exception for ddp_find_unused_parameters.
Line 173 reads as unconditional, but Gemma4 guidance elsewhere notes this is skipped when activation_offloading: true. Add that caveat here for consistency.
✏️ Proposed fix
-For DDP training, axolotl auto-detects Gemma4 and sets `use_reentrant=False` and `ddp_find_unused_parameters=True`. For FSDP2, use `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer`.
+For DDP training, axolotl auto-detects Gemma4 and sets `use_reentrant=False` and `ddp_find_unused_parameters=True` (except when `activation_offloading: true`, where `find_unused_parameters` is skipped). For FSDP2, use `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer`.📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| For DDP training, axolotl auto-detects Gemma4 and sets `use_reentrant=False` and `ddp_find_unused_parameters=True`. For FSDP2, use `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer`. | |
| For DDP training, axolotl auto-detects Gemma4 and sets `use_reentrant=False` and `ddp_find_unused_parameters=True` (except when `activation_offloading: true`, where `find_unused_parameters` is skipped). For FSDP2, use `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer`. |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/multimodal.qmd` at line 173, Update the sentence that currently states
Gemma4 auto-sets use_reentrant=False and ddp_find_unused_parameters=True to
include the exception for activation_offloading; explicitly note that when
activation_offloading: true, the ddp_find_unused_parameters setting is skipped
(so the auto-detected ddp_find_unused_parameters=True does not apply), and keep
the rest of the guidance about FSDP2 and fsdp_transformer_layer_cls_to_wrap:
Gemma4TextDecoderLayer unchanged; ensure you reference Gemma4, use_reentrant,
ddp_find_unused_parameters, and activation_offloading in the updated wording so
the caveat is clear.
| #### Using uv (recommended) | ||
|
|
||
| ```bash | ||
| # install uv if you don't already have it installed | ||
| curl -LsSf https://astral.sh/uv/install.sh | sh | ||
| source $HOME/.local/bin/env | ||
|
|
||
| # CUDA 12.8.1 tends to have better package compatibility | ||
| export UV_TORCH_BACKEND=cu128 | ||
|
|
||
| # create a new virtual environment | ||
| uv venv --python 3.12 | ||
| source .venv/bin/activate | ||
|
|
||
| uv pip install torch==2.10.0 torchvision | ||
| uv pip install --no-build-isolation axolotl[deepspeed] | ||
|
|
||
| # recommended - install cut-cross-entropy | ||
| uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main" | ||
|
|
||
| # (optional) - prefetch flash-attn2 and causal-conv1d kernels | ||
| uv run --python 3.12 python -c "from kernels import get_kernel; get_kernel('kernels-community/flash-attn2'); get_kernel('kernels-community/causal-conv1d')" | ||
|
|
||
| # Download example axolotl configs, deepspeed configs | ||
| axolotl fetch examples | ||
| axolotl fetch deepspeed_configs # OPTIONAL | ||
| ``` |
There was a problem hiding this comment.
Keep the recommended Python version consistent with the quick-start requirements.
The prerequisites still say Python 3.11, but the new recommended path provisions 3.12 and hardcodes uv run --python 3.12. Please clarify whether 3.12 is required or just preferred so readers don't set up the wrong interpreter.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@README.md` around lines 98 - 124, The README has inconsistent Python
versions: the prerequisites state Python 3.11 while the "Using uv (recommended)"
section hardcodes 3.12 in the commands (uv venv --python 3.12, uv run --python
3.12); update the text to clarify whether 3.12 is required or only preferred and
make the commands match the quick-start requirement—either change the
prerequisites to 3.12 or replace both uv commands to use 3.11 and add a
parenthetical note like "or 3.12 preferred" if appropriate so all references (uv
venv --python, uv run --python) are consistent with the declared required
version.
| n_steps_profiled = 3 # typical profiler_steps value | ||
|
|
||
| if skip_warmup and len(cuda_events) > 1000: | ||
| timestamps = sorted(set(float(ev.get("ts", 0)) for ev in cuda_events)) | ||
| min_ts, max_ts = timestamps[0], timestamps[-1] | ||
| total_span = max_ts - min_ts | ||
|
|
||
| # Step 0 is warmup (Triton compilation + autotune). It's typically | ||
| # the slowest step by far. Use 45% of wall-clock as cutoff -- step 0 | ||
| # usually takes >50% of total time when it includes compilation. | ||
| cutoff_ts = min_ts + total_span * 0.45 | ||
| before = len(cuda_events) | ||
| cuda_events = [ev for ev in cuda_events if float(ev.get("ts", 0)) > cutoff_ts] | ||
| n_steps_profiled -= 1 # now analyzing steps 1+ | ||
| print(f" Excluding step 0 (warmup): {before:,} -> {len(cuda_events):,} events") |
There was a problem hiding this comment.
Per-step numbers are wrong for anything except the assumed 3-step trace.
n_steps_profiled is hardcoded to 3, and the warmup cutoff is only applied to cuda_events while wall_clock_us still spans all CPU ops. Any run with a different profiler_steps value—or the default warmup exclusion enabled—will print incorrect ms/step, GPU utilization, and summary recommendations.
Also applies to: 223-237
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@scripts/analyze_profile.py` around lines 174 - 188, The code hardcodes
n_steps_profiled=3 and only filters cuda_events for the warmup cutoff, causing
ms/step, GPU utilization and recommendations to be wrong; update the logic to
derive n_steps_profiled from the actual profiler_steps (or compute it from the
timestamps/count of steps) and whenever you apply the warmup cutoff to
cuda_events also apply the same cutoff to wall_clock_us and any other CPU/time
aggregates by recomputing min_ts/max_ts and total_span from the filtered events
(or by excluding events with ts <= cutoff_ts), and decrement/adjust
n_steps_profiled consistently (same change needed in the later block around the
other cutoff at lines 223-237); ensure all reported per-step metrics use the
filtered timestamps/span and the correct n_steps_profiled variable (referencing
n_steps_profiled, cuda_events, wall_clock_us, cutoff_ts).
| def load_snapshot(path): | ||
| snap_file = Path(path) / "snapshot.pickle" if Path(path).is_dir() else Path(path) | ||
| if not snap_file.exists(): | ||
| return None | ||
| print(f"Loading {snap_file.name} ({snap_file.stat().st_size / 1e6:.0f} MB)...") | ||
| with open(snap_file, "rb") as f: | ||
| return pickle.load(f) # nosec B301 |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
# First, check the exact file and lines
head -n 750 scripts/analyze_profile.py | tail -n 20 | cat -nRepository: axolotl-ai-cloud/axolotl
Length of output: 884
🏁 Script executed:
# Get the full context of the load_snapshot function and surrounding code
sed -n '730,750p' scripts/analyze_profile.py | cat -nRepository: axolotl-ai-cloud/axolotl
Length of output: 901
🏁 Script executed:
# Search for all calls to load_snapshot to understand how the function is used
rg "load_snapshot" --type python -B 2 -A 2Repository: axolotl-ai-cloud/axolotl
Length of output: 98
🏁 Script executed:
# Check if there's any documentation about security or trust boundaries
rg -i "trust|security|pickle" --type markdown --type txtRepository: axolotl-ai-cloud/axolotl
Length of output: 1035
🏁 Script executed:
# Look for CLI argument handling related to snapshot paths
rg "snapshot|path" scripts/ --type python | head -30Repository: axolotl-ai-cloud/axolotl
Length of output: 98
🏁 Script executed:
# Search for load_snapshot calls without type restriction
rg "load_snapshot" -B 2 -A 2Repository: axolotl-ai-cloud/axolotl
Length of output: 888
🏁 Script executed:
# Check the documentation file that mentions snapshot.pickle
cat -n docs/agents/sft.md | grep -A 5 -B 5 "snapshot.pickle"Repository: axolotl-ai-cloud/axolotl
Length of output: 513
🏁 Script executed:
# Look for CLI entry points or main functions
rg "def main|argparse|click" scripts/analyze_profile.py -B 2 -A 5Repository: axolotl-ai-cloud/axolotl
Length of output: 533
🏁 Script executed:
# Check if there's a CLI interface or how analyze_profile is invoked
head -n 50 scripts/analyze_profile.py | cat -nRepository: axolotl-ai-cloud/axolotl
Length of output: 2042
🏁 Script executed:
# Search for any command-line argument handling for snapshot path
rg "snapshot|path" scripts/analyze_profile.py | head -20Repository: axolotl-ai-cloud/axolotl
Length of output: 1101
🏁 Script executed:
# Get the full argparse setup to see argument descriptions
sed -n '790,850p' scripts/analyze_profile.py | cat -nRepository: axolotl-ai-cloud/axolotl
Length of output: 2868
🏁 Script executed:
# Search for any existing security-related comments or warnings in the file
rg -i "security|trust|unsafe|warning|untrusted" scripts/analyze_profile.pyRepository: axolotl-ai-cloud/axolotl
Length of output: 138
🏁 Script executed:
# Check if there's any validation of the path argument before use
sed -n '800,900p' scripts/analyze_profile.py | cat -nRepository: axolotl-ai-cloud/axolotl
Length of output: 4513
🏁 Script executed:
# Get the complete argument parser definition
sed -n '810,860p' scripts/analyze_profile.py | cat -nRepository: axolotl-ai-cloud/axolotl
Length of output: 2231
🏁 Script executed:
# Search for the main() function and how arguments are documented
rg "add_argument|help=" scripts/analyze_profile.py -A 1Repository: axolotl-ai-cloud/axolotl
Length of output: 866
Add security guidance for loading snapshots from untrusted sources.
snapshot.pickle is deserialized using pickle.load(), which can execute arbitrary code. While this is safe for self-generated profiler artifacts, it's unsafe for snapshots shared from external or untrusted sources. The load_snapshot() function lacks documentation of this trust boundary, and neither the CLI help text nor the docs currently warn users about this risk.
Consider adding a docstring to load_snapshot() clarifying that snapshots should only be loaded from trusted sources, and updating the CLI help text for --path and --compare to note the security implications.
🧰 Tools
🪛 Ruff (0.15.9)
[error] 742-742: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue
(S301)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@scripts/analyze_profile.py` around lines 736 - 742, The load_snapshot
function currently calls pickle.load() unsafely; add a clear docstring to
load_snapshot(path) that states snapshots must only be loaded from trusted
sources and that deserializing untrusted snapshot.pickle files can execute
arbitrary code, and update the CLI help text for the --path and --compare
options to include a short security warning pointing users to only use trusted
snapshot files or to use alternative safe formats; ensure references mention the
load_snapshot function and the --path and --compare CLI flags so reviewers can
locate the changes.
| "flash-linear": [ | ||
| "flash-linear-attention", | ||
| "causal-conv1d", | ||
| ], |
There was a problem hiding this comment.
Gate the flash-linear extra on the same unsupported platforms as the base deps.
parse_requirements() already strips flash-linear-attention from the core install on ARM64/macOS, but this extra re-adds it unconditionally. That means axolotl[flash-linear] can still fail on the exact platforms the base install is trying to protect.
Suggested fix
if platform.machine() == "aarch64":
# skip on ARM64
skip_packages = [
"torchao",
"fla-core",
"flash-linear-attention",
]
+ extras_require_map.pop("flash-linear", None)
_install_requires = [
req
for req in _install_requires
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
]
if "Darwin" in platform.system():
# skip packages not compatible with OSX
skip_packages = [
"bitsandbytes",
"triton",
"mamba-ssm",
"xformers",
"liger-kernel",
]
+ extras_require_map.pop("flash-linear", None)
_install_requires = [
req
for req in _install_requires
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
]🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@setup.py` around lines 170 - 173, The "flash-linear" extra in setup.py
currently always re-adds "flash-linear-attention" and "causal-conv1d", undoing
the platform filtering done by parse_requirements(); update the extras_require
entry for the "flash-linear" key so it conditionally includes those packages
only on supported platforms (same check/masking logic used by
parse_requirements()), i.e., ensure the "flash-linear" list is built/filtered
using the same platform guard or helper that removes "flash-linear-attention" on
ARM64/macOS so axolotl[flash-linear] cannot reintroduce unsupported deps.
| elif cfg.model_config_type in ("gemma4", "gemma4_text"): | ||
| # Gemma4: offset=0 (NOT 1 like Gemma3), in_place=False required for | ||
| # gradient checkpointing compatibility, RoPE incompatible (separate q/k). | ||
| from liger_kernel.transformers.geglu import LigerGEGLUMLP | ||
| from transformers.models.gemma4 import modeling_gemma4 | ||
|
|
||
| if cfg.liger_rms_norm: | ||
| _OrigGemma4RMSNorm = modeling_gemma4.Gemma4RMSNorm | ||
|
|
||
| class _LigerGemma4RMSNorm(LigerRMSNorm): | ||
| """LigerRMSNorm for Gemma4 with in_place=False and with_scale support.""" | ||
|
|
||
| def __new__(cls, dim, eps=1e-6, with_scale=True): | ||
| if not with_scale: | ||
| return _OrigGemma4RMSNorm(dim, eps, with_scale=False) | ||
| return super().__new__(cls) | ||
|
|
||
| def __init__(self, dim, eps=1e-6, with_scale=True): | ||
| if not with_scale: | ||
| return | ||
| # offset=0.0 (standard), in_place=False (gradient checkpointing safe) | ||
| super().__init__( | ||
| dim, eps, offset=0.0, casting_mode="llama", in_place=False | ||
| ) | ||
|
|
||
| modeling_gemma4.Gemma4RMSNorm = _LigerGemma4RMSNorm | ||
| if cfg.liger_glu_activation: | ||
|
|
||
| class _LigerGemma4MLP(LigerGEGLUMLP): | ||
| def __init__(self, config, layer_idx=None): | ||
| super().__init__(config) | ||
|
|
||
| modeling_gemma4.Gemma4TextMLP = _LigerGemma4MLP | ||
| if cfg.liger_rope: | ||
| LOG.warning( | ||
| "Liger RoPE is not compatible with Gemma4 (separate q/k application). Skipping." | ||
| ) | ||
| if cfg.liger_layer_norm: | ||
| modeling_gemma4.nn.LayerNorm = LigerLayerNorm | ||
| if cfg.liger_cross_entropy: | ||
| modeling_gemma4.nn.CrossEntropyLoss = LigerCrossEntropyLoss | ||
| LOG.info( | ||
| f"Applied Liger kernels for gemma4: " | ||
| f"rms_norm={cfg.liger_rms_norm}, glu={cfg.liger_glu_activation}, " | ||
| f"rope=False (incompatible), layer_norm={cfg.liger_layer_norm}" | ||
| ) |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
rg -n -C3 'gemma4|fused_linear_cross_entropy|lce_forward' src/axolotl/integrations/ligerRepository: axolotl-ai-cloud/axolotl
Length of output: 40617
liger_fused_linear_cross_entropy is silently ignored for Gemma4.
The Gemma4 branch patches RMSNorm, GLU, LayerNorm, and CrossEntropy, but does not handle cfg.liger_fused_linear_cross_entropy. Unlike the RoPE incompatibility (which raises a warning), enabling this config for Gemma4 receives no warning or implementation. The elif fallback at line 271 cannot apply a generic FLCE patch because the Gemma4 condition already matched. This creates a silent config mismatch where the setting appears valid but produces no effect.
Add explicit handling: either implement Gemma4 FLCE or log a warning that it is unsupported, consistent with the RoPE incompatibility message.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/integrations/liger/plugin.py` around lines 225 - 270, The Gemma4
branch currently ignores cfg.liger_fused_linear_cross_entropy; add explicit
handling inside the gemma4 branch (where cfg.liger_rms_norm,
cfg.liger_glu_activation, cfg.liger_rope, cfg.liger_layer_norm,
cfg.liger_cross_entropy are processed) to either apply the fused linear
cross-entropy replacement on modeling_gemma4 (e.g., set
modeling_gemma4.nn.FusedLinearCrossEntropy = <appropriate Liger FLCE class> or
modeling_gemma4.nn.CrossEntropyLoss = LigerFusedLinearCrossEntropy if that’s the
intended symbol) or, if FLCE is unsupported for Gemma4, emit a clear LOG.warning
like the RoPE case stating "Liger fused linear cross entropy is not compatible
with Gemma4; skipping." Reference cfg.liger_fused_linear_cross_entropy,
modeling_gemma4, LigerCrossEntropyLoss (or LigerFusedLinearCrossEntropy), and
LOG.warning to locate where to add this check.
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
Summary by CodeRabbit
Release Notes
New Features
uvpackage managerDocumentation
uv-based setupPerformance