Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ Features:
**Requirements**:

- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11
- Python >=3.11 (3.12 recommended)
- PyTorch ≥2.9.1

### Google Colab
Expand All @@ -95,6 +95,34 @@ Features:

### Installation

#### Using uv (recommended)

```bash
# install uv if you don't already have it installed
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env

# CUDA 12.8.1 tends to have better package compatibility
export UV_TORCH_BACKEND=cu128

# create a new virtual environment
uv venv --python 3.12
source .venv/bin/activate

uv pip install torch==2.10.0 torchvision
uv pip install --no-build-isolation axolotl[deepspeed]

# recommended - install cut-cross-entropy
uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main"

# (optional) - prefetch flash-attn2 and causal-conv1d kernels
uv run --python 3.12 python -c "from kernels import get_kernel; get_kernel('kernels-community/flash-attn2'); get_kernel('kernels-community/causal-conv1d')"

# Download example axolotl configs, deepspeed configs
axolotl fetch examples
axolotl fetch deepspeed_configs # OPTIONAL
```
Comment on lines +98 to +124

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Keep the recommended Python version consistent with the quick-start requirements.

The prerequisites still say Python 3.11, but the new recommended path provisions 3.12 and hardcodes uv run --python 3.12. Please clarify whether 3.12 is required or just preferred so readers don't set up the wrong interpreter.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@README.md` around lines 98 - 124, The README has inconsistent Python
versions: the prerequisites state Python 3.11 while the "Using uv (recommended)"
section hardcodes 3.12 in the commands (uv venv --python 3.12, uv run --python
3.12); update the text to clarify whether 3.12 is required or only preferred and
make the commands match the quick-start requirement—either change the
prerequisites to 3.12 or replace both uv commands to use 3.11 and add a
parenthetical note like "or 3.12 preferred" if appropriate so all references (uv
venv --python, uv run --python) are consistent with the declared required
version.


#### Using pip

```bash
Expand Down
88 changes: 88 additions & 0 deletions docs/agents/model_architectures.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,64 @@

Model-specific quirks, required settings, and known issues. Check this before debugging training failures on specific model families.

## VLM (Vision Language Model) Quick Start

All VLM configs require these four lines:
```yaml
processor_type: AutoProcessor
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
```

Decision tree for VLM config:
```text
Is the model multimodal (has vision/audio encoder)?
├─ YES: Add `freeze_mm_modules: true` if training text only
│ Add `chat_template: <model_template>` (e.g. gemma4, qwen3_5, gemma3)
│ LoRA: use regex `lora_target_modules` to restrict to language model
└─ NO: Train as a regular text model

Is the model MoE (e.g. Gemma4 26B-A4B, Qwen3.5 35B-A3B)?
├─ YES: Add `lora_target_parameters` for expert LoRA
│ Consider ScatterMoE kernels (see Plugins section)
└─ NO: Standard LoRA config
```

## Plugins & Optimizations

### Cut Cross Entropy (CCE)

Computes loss from hidden states + lm_head weight without materializing the full logits tensor, saving significant VRAM. Install if not already present:

```bash
uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main"
```

```yaml
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
```

### ScatterMoE Kernels

Fuses expert + LoRA computation into a single kernel for MoE models. Significant speedup for models with many experts.

```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe

# Expert LoRA targets (3D parameter tensors, not nn.Linear):
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
```

Supported: Gemma4 (`gemma4_text`), Mixtral, Qwen MoE variants. The plugin auto-detects model type and routing function. Without ScatterMoE, expert LoRA still works but runs base expert matmul and LoRA as separate operations.

## Gemma 4

**Models**: `google/gemma-4-26B-A4B` (MoE), `google/gemma-4-31B` (dense), `google/gemma-4-E2B`, `google/gemma-4-E4B`
Expand Down Expand Up @@ -66,6 +124,36 @@ fsdp_config:
experts_implementation: scattermoe
```

### VLM (Vision) Training

All Gemma4 models load as `Gemma4ForConditionalGeneration` with a vision tower. No custom `ProcessingStrategy` needed — the base class auto-detects the image token.

```yaml
base_model: google/gemma-4-E2B-it # or E4B-it, 26B-A4B
processor_type: AutoProcessor
freeze_mm_modules: true
chat_template: gemma4

skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
```

A starting VLM loss of ~8-15 is typical. In most runs, loss converges below 1.0 within ~30-50 steps, though results may vary across configurations.

For the 26B-A4B MoE variant with ScatterMoE + expert LoRA + CCE, add:
```yaml
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
```

### Common issues

| Symptom | Cause | Fix |
Expand Down
24 changes: 24 additions & 0 deletions docs/agents/sft.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,30 @@ Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss go
| FSDP save hangs | Use `fsdp_state_dict_type: FULL_STATE_DICT` |
| DeepSpeed CheckpointError | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |

## Profiling

To profile training and identify optimization opportunities:

```yaml
# Profile steps 3-7 (after warmup/autotuning settles)
profiler_steps_start: 3
profiler_steps: 5
```

This produces `profiler_trace.json` (Chrome trace) and `snapshot.pickle` (memory snapshot) in `output_dir`.
View the Chrome trace at `chrome://tracing`.

To programmatically inspect the trace:
```bash
python scripts/analyze_profile.py output_dir/
```
Comment on lines +104 to +110

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Point the example at the run directory if you want memory analysis too.

Passing output_dir/profiler_trace.json only analyzes the trace. Using output_dir/ matches the script's own usage examples and lets it load snapshot.pickle as well.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/agents/sft.md` around lines 104 - 110, The example command only points
at the trace file, which prevents the analyzer from also loading the memory
snapshot; update the docs so the example calls the analyzer with the run
directory (e.g., pass output_dir/ instead of output_dir/profiler_trace.json) so
scripts/analyze_profile.py can locate both profiler_trace.json and
snapshot.pickle when inspecting a run.


The trace shows per-kernel CUDA times, memory allocations, and operator-level breakdown. Look for:
- **Large matmul kernels**: candidates for fusion or quantization
- **Memory copies (H2D/D2H)**: unnecessary data movement
- **Small frequent kernels**: candidates for kernel fusion
- **Gaps between kernels**: pipeline bubbles from CPU overhead

Full troubleshooting: [training_stability.qmd](../training_stability.qmd), [debugging.qmd](../debugging.qmd)

## File Map
Expand Down
35 changes: 35 additions & 0 deletions docs/multimodal.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ format:

## Supported Models

- [Gemma-4](#sec-gemma-4) *(NEW)*
- [Mllama](#sec-mllama)
- [Llama4](#sec-llama4)
- [Pixtral](#sec-pixtral)
Expand Down Expand Up @@ -138,6 +139,40 @@ base_model: mistralai/Voxtral-Mini-3B-2507
processor_type: VoxtralProcessor
```

### Gemma-4 {#sec-gemma-4}

All Gemma 4 variants (E2B, E4B, 26B-A4B, 31B) load as multimodal models even for text-only training.

```yaml
base_model: google/gemma-4-E2B-it # or E4B-it, 26B-A4B, 31B

chat_template: gemma4
freeze_mm_modules: true # freeze vision/audio encoders for text-only or vision LoRA

# For the 26B-A4B MoE model, enable ScatterMoE and expert LoRA:
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe

lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'

# MoE expert LoRA (3D tensors, not nn.Linear) — only for 26B-A4B:
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
```

::: {.callout-warning}
Gemma 4 VLM training starts with high loss (~8-15). This is expected — see the [training stability guide](training_stability.qmd) for details.
:::

::: {.callout-tip}
For DDP training, axolotl auto-detects Gemma4 and sets `use_reentrant=False` and `ddp_find_unused_parameters=True`. However, when `activation_offloading: true`, `ddp_find_unused_parameters` is skipped (checkpoint wrappers conflict with it); use `freeze_mm_modules: true` instead to handle unused vision/audio params. For FSDP2, use `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer`.
:::

### Gemma-3 {#sec-gemma-3}

::: {.callout-tip}
Expand Down
62 changes: 62 additions & 0 deletions examples/gemma4/e2b-vision-lora.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Gemma 4 E2B Vision LoRA
#
# Fine-tuning LM LoRA adapters on multimodal Gemma4 with vision/multimodal modules frozen.
# Uses the base ProcessingStrategy (auto-detects image_token from processor).

base_model: google/gemma-4-E2B-it
processor_type: AutoProcessor
freeze_mm_modules: true

plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false

# Required for vision/multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

chat_template: gemma4
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:100]

val_set_size: 0
output_dir: ./outputs/gemma4-e2b-vision-lora

adapter: lora
sequence_len: 2048
pad_to_sequence_len: false

lora_r: 16
lora_alpha: 32
lora_dropout: 0
# Target language model only — vision encoder is frozen via freeze_mm_modules
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
max_steps: 10
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
sdp_attention: true

warmup_ratio: 0.1
weight_decay: 0.0

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
62 changes: 62 additions & 0 deletions examples/qwen3.5/35b-a3b-moe-vision-lora.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Qwen 3.5 35B-A3B MoE Vision LoRA
#
# Vision fine-tuning of the hybrid DeltaNet + Attention MoE model.
# 256 experts, 8 active per token, with early-fusion vision support.

base_model: Qwen/Qwen3.5-35B-A3B
processor_type: AutoProcessor

# Required for vision/multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

chat_template: qwen3_5
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:100]

val_set_size: 0
output_dir: ./outputs/qwen35-35b-a3b-vision-lora

adapter: lora
sequence_len: 4096
pad_to_sequence_len: false

lora_r: 16
lora_alpha: 32
lora_dropout: 0
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
max_steps: 10
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
weight_decay: 0.0

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
Loading
Loading