Skip to content

super nemo support#3508

Merged
winglian merged 20 commits into
axolotl-ai-cloud:mainfrom
ved1beta:nemo-nemo
Mar 30, 2026
Merged

super nemo support#3508
winglian merged 20 commits into
axolotl-ai-cloud:mainfrom
ved1beta:nemo-nemo

Conversation

@ved1beta

@ved1beta ved1beta commented Mar 18, 2026

Copy link
Copy Markdown
Member

super nemo support

https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16

no sample packing T_T'

examples/nemotron-h/nano-30b-a3b-qlora.yaml
max_vram :69Gbs

', 'grad_norm': '0.2281', 'learning_rate': '1.452e-05', 'ppl': '2.731', 'memory/max_active (GiB)': '68.22', 'memory/max_allocated (GiB)': '68.22', 'memory/device_reserved (GiB)': '77.14', 'tokens/train_per_sec_per_gpu': '189', 'tokens/total': 2623744, 'tokens/trainable': 704234, 'epoch': '0.02908'}
{'loss': '0.9717', 'grad_norm': '0.2569', 'learning_rate': '1.454e-05', 'ppl': '2.642', 'memory/max_active (GiB)': '69.89', 'memory/max_allocated (GiB)': '69.89', 'memory/device_reserved (GiB)': '77.14', 'tokens/train_per_sec_per_gpu': '145.2', 'tokens/total': 2628352, 'tokens/trainable': 705280, 'epoch': '0.02913'}
{'loss': '0.6611', 'grad_norm': '0.2496', 'learning_rate': '1.457e-05', 'ppl': '1.937', 'memory/max_active (GiB)': '67.66', 'memory/max_allocated (GiB)': '67.66', 'memory/device_reserved (GiB)': '77.14', 'tokens/train_per_sec_per_gpu': '48.39', 'tokens/total': 2633856, 'tokens/trainable': 706148, 'epoch': '0.02917'}
{'loss': '0.8849', 'grad_norm': '0.4', 'learning_rate': '1.459e-05', 'ppl': '2.423', 'memory/max_active (GiB)': '67.95', 'memory/max_allocated (GiB)': '67.95', 'memory/device_reserved (GiB)': '77.14', 'tokens/train_per_sec_per_gpu': '42.57', 'tokens/total': 2637184, 'tokens/trainable': 707015, 'epoch': '0.02922'}
{'loss': '0.7247', 'grad_norm': '0.3224', 'learning_rate': '1.461e-05', 'ppl': '2.064', 'memory/max_active (GiB)': '66.84', 'memory/max_allocated (GiB)': '66.84', 'memory/device_reserved (GiB)': '77.14', 'tokens/train_per_sec_per_gpu': '138.2', 'tokens/total': 2640640, 'tokens/trainable': 708233, 'epoch': '0.02926'}
{'loss': '0.8672', 'grad_norm': '0.1406', 'learning_rate': '1.463e-05', 'ppl': '2.38', 'memory/max_active (GiB)': '68.6', 'memory/max_allocated (GiB)': '68.6', 'memory/device_reserved (GiB)': '77.14', 'tokens/train_per_sec_per_gpu': '236.7', 'tokens/total': 2645632, 'tokens/trainable': 710846, 'epoch': '0.02931'}
{'loss': '0.6126', 'grad_norm': '0.3173', 'learning_rate': '1.465e-05', 'ppl': '1.845', 'memory/max_active (GiB)': '69.72', 'memory/max_allocated (GiB)': '69.72', 'memory/device_reserved (GiB)': '77.14', 'tokens/train_per_sec_per_gpu': '63.65', 'tokens/total': 2651136, 'tokens/trainable': 711825, 'epoch': '0.02935'}
{'loss': '0.6396', 'grad_norm': '0.3753', 'learning_rate': '1.468e-05', 'ppl': '1.896', 'memory/max_active (GiB)': '67.65', 'memory/max_allocated (GiB)': '67.65', 'memory/device_reserved (GiB)': '77.14', 'tokens/train_per_sec_per_gpu': '91.66', 'tokens/total': 2654592, 'tokens/trainable': 712405, 'epoch': '0.0294'}
{'loss': '0.7012', 'grad_norm': '0.2779', 'learning_rate': '1.47e-05', 'ppl': '2.016', 'memory/max_active (GiB)': '67.94', 'memory/max_allocated (GiB)': '67.94', 'memory/device_reserved (GiB)': '77.14', 'tokens/train_per_sec_per_gpu': '55.39', 'tokens/total': 2657536, 'tokens/trainable': 713062, 'epoch': '0.02944'}
{'loss': '0.6969', 'grad_norm': '0.3095', 'learning_rate': '1.472e-05', 'ppl': '2.007', 'memory/max_active (GiB)': '67.94', 'memory/max_allocated (GiB)': '67.94', 'memory/device_reserved (GiB)': '77.14', 'tokens/train_per_sec_per_gpu': '79.2', 'tokens/total': 2660480, 'tokens/trainable': 713689, 'epoch': '0.02949'}
{'loss': '0.6547', 'grad_norm': '0.2354', 'learning_rate': '1.474e-05', 'ppl': '1.925', 'memory/max_active (GiB)': '69.72', 'memory/max_allocated (GiB)': '69.72', 'memory/device_reserved (GiB)': '77.14', 'tokens/train_per_sec_per_gpu': '68.17', 'tokens/total': 2665216, 'tokens/trainable': 714624, 'epoch': '0.02953'}
{'loss': '0.7779', 'grad_norm': '0.2468', 'learning_rate': '1.477e-05', 'ppl': '2.177', 'memory/max_active (GiB)': '67.49', 'memory/max_allocated (GiB)': '67.49', 'memory/device_reserved (GiB)': '77.14', 'tokens/train_per_sec_per_gpu': '83.81', 'tokens/total': 2667904, 'tokens/trainable': 715566, 'epoch': '0.02957'}
{'loss': '0.705', 'grad_norm': '0.3831', 'learning_rate': '1.479e-05', 'ppl': '2.024', 'memory/max_active (GiB)': '66.65', 'memory/max_allocated (GiB)': '66.65', 'memory/device_reserved (GiB)': '77.14', 'tokens/train_per_sec_per_gpu': '44.59', 'tokens/total': 2671104, 'tokens/trainable': 716221, 'epoch': '0.02962'}
{'loss': '0.9239', 'grad_norm': '0.2548', 'learning_rate': '1.481e-05', 'ppl': '2.519', 'memory/max_

chat template : cladue

image

Summary by CodeRabbit

  • New Features

    • Added support for training Nemotron-H models with QLoRA optimization and mixed Mamba2/Attention/MoE architecture
    • New chat template support for Nemotron-H models with proper message formatting
  • Documentation

    • Added complete end-to-end training configuration example for Nemotron-H Nano 30B model featuring 4-bit quantization and gradient checkpointing

@coderabbitai

coderabbitai Bot commented Mar 18, 2026

Copy link
Copy Markdown
Contributor

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: bd030814-742d-431e-a4d9-2f2354a1b4ce

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

This PR adds comprehensive support for the Nemotron-H model architecture (a mixed Mamba2/Attention/MoE model) across the training framework. Changes include new model registration entries, model-specific loading and patching logic, a quantization defensive patch, configuration examples, and chat template support.

Changes

Cohort / File(s) Summary
Configuration & Chat Template Support
examples/nemotron-h/nano-30b-a3b-qlora.yaml, src/axolotl/utils/chat_templates/templates/nemotron.jinja
Adds a complete QLoRA training configuration for Nemotron-H and introduces a Jinja chat template that formats messages with system/user/assistant roles and generation prompt support.
Model Type Registration
src/axolotl/common/architectures.py, src/axolotl/utils/schemas/enums.py
Registers "nemotron_h" in MOE_ARCH_BLOCK mapping and ChatTemplate enum to enable model type recognition.
Model Loading & Patching
src/axolotl/loaders/model.py, src/axolotl/loaders/patch_manager.py, src/axolotl/loaders/utils.py
Adds nemotron_h-specific handling for 4-bit quantization storage, gradient checkpointing support flag, and embedding layer identification.
Expert Quantization Patching
src/axolotl/monkeypatch/moe_quant.py
Implements defensive patch to prevent PEFT from wrapping ParametrizationList modules during LoRA target module injection for quantized MoE experts.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

Possibly related PRs

Suggested labels

under review

Suggested reviewers

  • winglian
  • NanoCode012
🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 71.43% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
Title check ❓ Inconclusive The title 'super nemo support' is vague and does not clearly describe the specific changes made. While it references 'nemo,' the PR actually adds support for Nemotron-H models (nano and super variants) with specific configurations including QLoRA, quantization, and training infrastructure updates. Consider revising the title to be more specific and descriptive, such as 'Add Nemotron-H model support with QLoRA training configuration' or 'Add NVIDIA Nemotron-3 Super and Nano support with quantization and training configs.'
✅ Passed checks (1 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/nemotron-h/nano-30b-a3b-qlora.yaml`:
- Line 26: The example config enables sample_packing which contradicts the PR
objective "no sample packing"; update the config entry for sample_packing in the
nano-30b-a3b-qlora.yaml example by either removing the sample_packing line or
explicitly setting sample_packing: false and add a short inline comment
referencing the "no sample packing" constraint so users aren’t misled; locate
the sample_packing key in the YAML and change it accordingly.

In `@src/axolotl/utils/chat_templates/templates/nemotron.jinja`:
- Line 1: The template unconditionally indexes messages[0] which will raise when
messages is empty; update the nemotron.jinja condition to first check that
messages is non-empty (e.g., messages or messages|length > 0) before accessing
messages[0].role so the guard prevents rendering errors when messages is empty
and preserves the existing system-role branch behavior.

In `@src/axolotl/utils/schemas/enums.py`:
- Line 64: The ChatTemplate enum entry nemotron_h has the wrong value; change
the value of the enum member (named nemotron_h) in
src/axolotl/utils/schemas/enums.py to "nemotron" so it matches the template key
derived from the nemotron.jinja filename (see chat_templates/base.py); ensure
the enum member name can remain nemotron_h if needed but its string value must
be "nemotron" so chat_template: nemotron selects the Nemotron template.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: edd94896-cbe4-4a7a-b0a9-c95b30e78821

📥 Commits

Reviewing files that changed from the base of the PR and between 5ef3f28 and 67f2fc6.

📒 Files selected for processing (8)
  • examples/nemotron-h/nano-30b-a3b-qlora.yaml
  • src/axolotl/common/architectures.py
  • src/axolotl/loaders/model.py
  • src/axolotl/loaders/patch_manager.py
  • src/axolotl/loaders/utils.py
  • src/axolotl/monkeypatch/moe_quant.py
  • src/axolotl/utils/chat_templates/templates/nemotron.jinja
  • src/axolotl/utils/schemas/enums.py

dataset_prepared_path: last_run_prepared

sequence_len: 4096
sample_packing: true

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Example config conflicts with stated Nemotron-H limitation.

Line 26 enables sample_packing, but the PR objective explicitly states “no sample packing.” This example will likely mislead users into an unsupported setup.

Proposed fix
-sample_packing: true
+sample_packing: false
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
sample_packing: true
sample_packing: false
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/nemotron-h/nano-30b-a3b-qlora.yaml` at line 26, The example config
enables sample_packing which contradicts the PR objective "no sample packing";
update the config entry for sample_packing in the nano-30b-a3b-qlora.yaml
example by either removing the sample_packing line or explicitly setting
sample_packing: false and add a short inline comment referencing the "no sample
packing" constraint so users aren’t misled; locate the sample_packing key in the
YAML and change it accordingly.

@@ -0,0 +1,16 @@
{%- if messages[0].role == 'system' %}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Guard against empty messages before indexing.

Line 1 accesses messages[0] unconditionally; empty inputs will fail at render time.

Proposed fix
-{%- if messages[0].role == 'system' %}
+{%- if messages and messages[0].role == 'system' %}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
{%- if messages[0].role == 'system' %}
{%- if messages and messages[0].role == 'system' %}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/utils/chat_templates/templates/nemotron.jinja` at line 1, The
template unconditionally indexes messages[0] which will raise when messages is
empty; update the nemotron.jinja condition to first check that messages is
non-empty (e.g., messages or messages|length > 0) before accessing
messages[0].role so the guard prevents rendering errors when messages is empty
and preserves the existing system-role branch behavior.

qwen3 = "qwen3"
qwen3_5 = "qwen3_5"
falcon_h1 = "falcon_h1"
nemotron_h = "nemotron_h"

@coderabbitai coderabbitai Bot Mar 18, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

ChatTemplate value does not match the Nemotron template key.

Line 64 introduces nemotron_h, but template keys are derived from filename stems (see src/axolotl/utils/chat_templates/base.py), and the added file is nemotron.jinja (key nemotron). This prevents chat_template: nemotron_h from selecting the Nemotron template.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/utils/schemas/enums.py` at line 64, The ChatTemplate enum entry
nemotron_h has the wrong value; change the value of the enum member (named
nemotron_h) in src/axolotl/utils/schemas/enums.py to "nemotron" so it matches
the template key derived from the nemotron.jinja filename (see
chat_templates/base.py); ensure the enum member name can remain nemotron_h if
needed but its string value must be "nemotron" so chat_template: nemotron
selects the Nemotron template.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ved1beta make sure to verify this. thanks!

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is correct. The template file is wrong which I left a review below on

@codecov

codecov Bot commented Mar 18, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 17.64706% with 112 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
.../axolotl/monkeypatch/models/nemotron_h/modeling.py 0.00% 90 Missing ⚠️
src/axolotl/loaders/patch_manager.py 19.04% 17 Missing ⚠️
src/axolotl/monkeypatch/lora_kernels.py 0.00% 4 Missing ⚠️
src/axolotl/loaders/utils.py 50.00% 1 Missing ⚠️

📢 Thoughts on this report? Let us know!

@NanoCode012 NanoCode012 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Sample packing support

# pip install mamba-ssm causal-conv1d # for fast Mamba2 CUDA kernels

base_model: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add cut cross entropy

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't forget this?

- k_proj
- v_proj
- o_proj

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add commented out section in case they want to train on experts

if self.cfg.model_config_type in [
"jamba",
"qwen2_moe",
"nemotron_h",

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this explicitly needed?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ved1beta check

Comment thread src/axolotl/monkeypatch/moe_quant.py Outdated
Comment on lines +256 to +261
# Patch _check_target_module_exists to skip ParametrizationList modules.
# After quantize_moe_experts runs, expert params become ParametrizationList
# modules at paths like "...experts.parametrizations.up_proj". Without this
# patch, lora_target_modules name-suffix matching finds "up_proj" there and
# tries to wrap it in LoRA, which PEFT rejects.
_original_check = BaseTuner._check_target_module_exists

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain here what this means or why this happens? Is this for this model specifically or did we miss it earlier?

@@ -0,0 +1,16 @@
{%- if messages[0].role == 'system' %}

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This model's arch is nemotron_h . we should keep file consistent

@winglian winglian added the scheduled_release This PR is slated for the upcoming release label Mar 21, 2026
@zerofata

Copy link
Copy Markdown

Tested this but got an error.

# nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16
# Hybrid Mamba2 / Attention / MoE architecture (model_type: nemotron_h)
# 120B total params, ~12B active per token, 88 layers
#
# Requirements:
#   pip install mamba-ssm causal-conv1d   # for fast Mamba2 CUDA kernels
#   flash-attn >= 2.0                     # for flash attention + sample packing

base_model: nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16

trust_remote_code: true

# Nemotron-H attention layers live in NemotronHBlock.mixer, not the standard
# layer.self_attn, so the lora kernel patches can't discover them.
# relu2 (mlp_hidden_act) is also unsupported by lora_mlp_kernel.
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

datasets:
  - path: ./data/nemotron_sft_1_masked_20260323_232338.jsonl

val_set_size: 0.0
output_dir: ./Nemotron-3-Super-120B-SFT-1
dataset_prepared_path: last_run_prepared

sequence_len: 10756
sample_packing: true

use_cut_cross_entropy: true

load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
peft_use_rslora: true
lora_dropout: 0.0
lora_target_modules:
  # Attention projection layers (present in ~12 attention layers out of 88)
  - q_proj
  - k_proj
  - v_proj
  - o_proj
  # Uncomment to also train MoE expert weights:
  - up_proj
  - down_proj
  - gate_proj

wandb_project: Nemotron-3-Super-120B-SFT
wandb_entity:
wandb_watch:
wandb_name: Nemotron-3-Super-120B-SFT-1
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 6e-5

bf16: auto
tf32: true

resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0

special_tokens:
  pad_token: <|end_of_text|>

fsdp_config:
  fsdp_version: 2
  offload_params: false
  cpu_ram_efficient_loading: false
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: NemotronHBlock
  state_dict_type: FULL_STATE_DICT
  sharding_strategy: FULL_SHARD
  reshard_after_forward: true
  activation_checkpointing: true
wandb: WARNING Symlinked 1 file into the W&B run directory; call wandb.save again to sync new files.
[2026-03-23 21:14:10,929] [INFO] [axolotl.utils.callbacks] The Axolotl config has been saved to the WandB run under files.
  0%|                                                                                                                              | 0/69 [00:00<?, ?it/s][rank3]: Traceback (most recent call last):
[rank3]:   File "<frozen runpy>", line 198, in _run_module_as_main
[rank3]:   File "<frozen runpy>", line 88, in _run_code
[rank3]:   File "/workspace/axolotl/src/axolotl/cli/train.py", line 121, in <module>
[rank3]:     fire.Fire(do_cli)
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/fire/core.py", line 135, in Fire
[rank3]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank3]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/fire/core.py", line 468, in _Fire
[rank3]:     component, remaining_args = _CallAndUpdateTrace(
[rank3]:                                 ^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/fire/core.py", line 684, in _CallAndUpdateTrace
[rank3]:     component = fn(*varargs, **kwargs)
[rank3]:                 ^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/workspace/axolotl/src/axolotl/cli/train.py", line 88, in do_cli
[rank3]:     return do_train(parsed_cfg, parsed_cli_args)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/workspace/axolotl/src/axolotl/cli/train.py", line 45, in do_train
[rank3]:     model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
[rank3]:                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/workspace/axolotl/src/axolotl/telemetry/errors.py", line 124, in wrapper
[rank3]:     return func(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/workspace/axolotl/src/axolotl/train.py", line 591, in train
[rank3]:     execute_training(cfg, trainer, resume_from_checkpoint)
[rank3]:   File "/workspace/axolotl/src/axolotl/train.py", line 219, in execute_training
[rank3]:     trainer.train(resume_from_checkpoint=resume_from_checkpoint)
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1424, in train
[rank3]:     return inner_training_loop(
[rank3]:            ^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1506, in _inner_training_loop
[rank3]:     self._run_epoch(
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1734, in _run_epoch
[rank3]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank3]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/workspace/axolotl/src/axolotl/core/trainers/mixins/activation_checkpointing.py", line 46, in training_step
[rank3]:     return super().training_step(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1906, in training_step
[rank3]:     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/workspace/axolotl/src/axolotl/core/trainers/base.py", line 391, in compute_loss
[rank3]:     return super().compute_loss(
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1978, in compute_loss
[rank3]:     outputs = model(**inputs)
[rank3]:               ^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1881, in _call_impl
[rank3]:     return inner()
[rank3]:            ^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1829, in inner
[rank3]:     result = forward_call(*args, **kwargs)
[rank3]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/peft/peft_model.py", line 1923, in forward
[rank3]:     return self.base_model(
[rank3]:            ^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/peft/tuners/tuners_utils.py", line 311, in forward
[rank3]:     return self.model.forward(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/root/.cache/huggingface/modules/transformers_modules/nvidia/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Super_hyphen_120B_hyphen_A12B_hyphen_BF16/c0c012c1aaa62bba6493897796460eec83902141/modeling_nemotron_h.py", line 1718, in forward
[rank3]:     nemotron_h_outputs = self.model(
[rank3]:                          ^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/root/.cache/huggingface/modules/transformers_modules/nvidia/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Super_hyphen_120B_hyphen_A12B_hyphen_BF16/c0c012c1aaa62bba6493897796460eec83902141/modeling_nemotron_h.py", line 1489, in forward
[rank3]:     hidden_states = mixer_block(
[rank3]:                     ^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
[rank3]:     return self.checkpoint_fn(  # type: ignore[misc]
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/_compile.py", line 53, in inner
[rank3]:     return disable_fn(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
[rank3]:     return fn(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/utils/checkpoint.py", line 503, in checkpoint
[rank3]:     ret = function(*args, **kwargs)
[rank3]:           ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1881, in _call_impl
[rank3]:     return inner()
[rank3]:            ^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1829, in inner
[rank3]:     result = forward_call(*args, **kwargs)
[rank3]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/root/.cache/huggingface/modules/transformers_modules/nvidia/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Super_hyphen_120B_hyphen_A12B_hyphen_BF16/c0c012c1aaa62bba6493897796460eec83902141/modeling_nemotron_h.py", line 822, in forward
[rank3]:     hidden_states = self.mixer(
[rank3]:                     ^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
[rank3]:     return self.checkpoint_fn(  # type: ignore[misc]
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/_compile.py", line 53, in inner
[rank3]:     return disable_fn(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
[rank3]:     return fn(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/utils/checkpoint.py", line 503, in checkpoint
[rank3]:     ret = function(*args, **kwargs)
[rank3]:           ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/root/.cache/huggingface/modules/transformers_modules/nvidia/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Super_hyphen_120B_hyphen_A12B_hyphen_BF16/c0c012c1aaa62bba6493897796460eec83902141/modeling_nemotron_h.py", line 736, in forward
[rank3]:     return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/root/.cache/huggingface/modules/transformers_modules/nvidia/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Super_hyphen_120B_hyphen_A12B_hyphen_BF16/c0c012c1aaa62bba6493897796460eec83902141/modeling_nemotron_h.py", line 456, in cuda_kernels_forward
[rank3]:     out, ssm_state = mamba_split_conv1d_scan_combined(
[rank3]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 997, in mamba_split_conv1d_scan_combined
[rank3]:     return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/autograd/function.py", line 581, in apply
[rank3]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/torch/amp/autocast_mode.py", line 527, in decorate_fwd
[rank3]:     return fwd(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 881, in forward
[rank3]:     out = F.linear(out, outproj_weight, outproj_bias)
[rank3]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/dist-packages/bitsandbytes/nn/modules.py", line 402, in __torch_function__
[rank3]:     return super().__torch_function__(func, types, args, kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: RuntimeError: mat1 and mat2 shapes cannot be multiplied (10816x8192 and 1x8388608)
[rank1]: Traceback (most recent call last):
[rank1]:   File "<frozen runpy>", line 198, in _run_module_as_main
[rank1]:   File "<frozen runpy>", line 88, in _run_code
[rank1]:   File "/workspace/axolotl/src/axolotl/cli/train.py", line 121, in <module>
[rank1]:     fire.Fire(do_cli)
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/fire/core.py", line 135, in Fire
[rank1]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank1]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/fire/core.py", line 468, in _Fire
[rank1]:     component, remaining_args = _CallAndUpdateTrace(
[rank1]:                                 ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/fire/core.py", line 684, in _CallAndUpdateTrace
[rank1]:     component = fn(*varargs, **kwargs)
[rank1]:                 ^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/axolotl/src/axolotl/cli/train.py", line 88, in do_cli
[rank1]:     return do_train(parsed_cfg, parsed_cli_args)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/axolotl/src/axolotl/cli/train.py", line 45, in do_train
[rank1]:     model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
[rank1]:                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/axolotl/src/axolotl/telemetry/errors.py", line 124, in wrapper
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/axolotl/src/axolotl/train.py", line 591, in train
[rank1]:     execute_training(cfg, trainer, resume_from_checkpoint)
[rank1]:   File "/workspace/axolotl/src/axolotl/train.py", line 219, in execute_training
[rank1]:     trainer.train(resume_from_checkpoint=resume_from_checkpoint)
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1424, in train
[rank1]:     return inner_training_loop(
[rank1]:            ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1506, in _inner_training_loop
[rank1]:     self._run_epoch(
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1734, in _run_epoch
[rank1]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank1]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/axolotl/src/axolotl/core/trainers/mixins/activation_checkpointing.py", line 46, in training_step
[rank1]:     return super().training_step(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1906, in training_step
[rank1]:     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/axolotl/src/axolotl/core/trainers/base.py", line 391, in compute_loss
[rank1]:     return super().compute_loss(
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1978, in compute_loss
[rank1]:     outputs = model(**inputs)
[rank1]:               ^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1881, in _call_impl
[rank1]:     return inner()
[rank1]:            ^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1829, in inner
[rank1]:     result = forward_call(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/peft/peft_model.py", line 1923, in forward
[rank1]:     return self.base_model(
[rank1]:            ^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/peft/tuners/tuners_utils.py", line 311, in forward
[rank1]:     return self.model.forward(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/.cache/huggingface/modules/transformers_modules/nvidia/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Super_hyphen_120B_hyphen_A12B_hyphen_BF16/c0c012c1aaa62bba6493897796460eec83902141/modeling_nemotron_h.py", line 1718, in forward
[rank1]:     nemotron_h_outputs = self.model(
[rank1]:                          ^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/.cache/huggingface/modules/transformers_modules/nvidia/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Super_hyphen_120B_hyphen_A12B_hyphen_BF16/c0c012c1aaa62bba6493897796460eec83902141/modeling_nemotron_h.py", line 1489, in forward
[rank1]:     hidden_states = mixer_block(
[rank1]:                     ^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
[rank1]:     return self.checkpoint_fn(  # type: ignore[misc]
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/_compile.py", line 53, in inner
[rank1]:     return disable_fn(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
[rank1]:     return fn(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/utils/checkpoint.py", line 503, in checkpoint
[rank1]:     ret = function(*args, **kwargs)
[rank1]:           ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1881, in _call_impl
[rank1]:     return inner()
[rank1]:            ^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1829, in inner
[rank1]:     result = forward_call(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/.cache/huggingface/modules/transformers_modules/nvidia/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Super_hyphen_120B_hyphen_A12B_hyphen_BF16/c0c012c1aaa62bba6493897796460eec83902141/modeling_nemotron_h.py", line 822, in forward
[rank1]:     hidden_states = self.mixer(
[rank1]:                     ^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
[rank1]:     return self.checkpoint_fn(  # type: ignore[misc]
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/_compile.py", line 53, in inner
[rank1]:     return disable_fn(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
[rank1]:     return fn(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/utils/checkpoint.py", line 503, in checkpoint
[rank1]:     ret = function(*args, **kwargs)
[rank1]:           ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/.cache/huggingface/modules/transformers_modules/nvidia/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Super_hyphen_120B_hyphen_A12B_hyphen_BF16/c0c012c1aaa62bba6493897796460eec83902141/modeling_nemotron_h.py", line 736, in forward
[rank1]:     return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/.cache/huggingface/modules/transformers_modules/nvidia/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Super_hyphen_120B_hyphen_A12B_hyphen_BF16/c0c012c1aaa62bba6493897796460eec83902141/modeling_nemotron_h.py", line 456, in cuda_kernels_forward
[rank1]:     out, ssm_state = mamba_split_conv1d_scan_combined(
[rank1]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 997, in mamba_split_conv1d_scan_combined
[rank1]:     return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/autograd/function.py", line 581, in apply
[rank1]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/amp/autocast_mode.py", line 527, in decorate_fwd
[rank1]:     return fwd(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 881, in forward
[rank1]:     out = F.linear(out, outproj_weight, outproj_bias)
[rank1]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/bitsandbytes/nn/modules.py", line 402, in __torch_function__
[rank1]:     return super().__torch_function__(func, types, args, kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: RuntimeError: mat1 and mat2 shapes cannot be multiplied (10816x8192 and 1x8388608)
[2026-03-23 21:14:56,394] [ERROR] [axolotl.telemetry.errors] Error captured in telemetry. Run ID: b228720b-0625-416a-b873-724b89ee46cb
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/workspace/axolotl/src/axolotl/cli/train.py", line 121, in <module>
    fire.Fire(do_cli)
  File "/usr/local/lib/python3.11/dist-packages/fire/core.py", line 135, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/fire/core.py", line 468, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
                                ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/fire/core.py", line 684, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/axolotl/src/axolotl/cli/train.py", line 88, in do_cli
    return do_train(parsed_cfg, parsed_cli_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/axolotl/src/axolotl/cli/train.py", line 45, in do_train
    model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/axolotl/src/axolotl/telemetry/errors.py", line 127, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/axolotl/src/axolotl/train.py", line 591, in train
    execute_training(cfg, trainer, resume_from_checkpoint)
  File "/workspace/axolotl/src/axolotl/train.py", line 219, in execute_training
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1424, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1506, in _inner_training_loop
    self._run_epoch(
  File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1734, in _run_epoch
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/axolotl/src/axolotl/core/trainers/mixins/activation_checkpointing.py", line 46, in training_step
    return super().training_step(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1906, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/axolotl/src/axolotl/core/trainers/base.py", line 391, in compute_loss
    return super().compute_loss(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1978, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1881, in _call_impl
    return inner()
           ^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1829, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/peft/peft_model.py", line 1923, in forward
    return self.base_model(
           ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/peft/tuners/tuners_utils.py", line 311, in forward
    return self.model.forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/modules/transformers_modules/nvidia/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Super_hyphen_120B_hyphen_A12B_hyphen_BF16/c0c012c1aaa62bba6493897796460eec83902141/modeling_nemotron_h.py", line 1718, in forward
    nemotron_h_outputs = self.model(
                         ^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/modules/transformers_modules/nvidia/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Super_hyphen_120B_hyphen_A12B_hyphen_BF16/c0c012c1aaa62bba6493897796460eec83902141/modeling_nemotron_h.py", line 1489, in forward
    hidden_states = mixer_block(
                    ^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
    return self.checkpoint_fn(  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_compile.py", line 53, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/checkpoint.py", line 503, in checkpoint
    ret = function(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1881, in _call_impl
    return inner()
           ^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1829, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/modules/transformers_modules/nvidia/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Super_hyphen_120B_hyphen_A12B_hyphen_BF16/c0c012c1aaa62bba6493897796460eec83902141/modeling_nemotron_h.py", line 822, in forward
    hidden_states = self.mixer(
                    ^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
    return self.checkpoint_fn(  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_compile.py", line 53, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/checkpoint.py", line 503, in checkpoint
    ret = function(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/modules/transformers_modules/nvidia/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Super_hyphen_120B_hyphen_A12B_hyphen_BF16/c0c012c1aaa62bba6493897796460eec83902141/modeling_nemotron_h.py", line 736, in forward
    return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/modules/transformers_modules/nvidia/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Super_hyphen_120B_hyphen_A12B_hyphen_BF16/c0c012c1aaa62bba6493897796460eec83902141/modeling_nemotron_h.py", line 456, in cuda_kernels_forward
    out, ssm_state = mamba_split_conv1d_scan_combined(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 997, in mamba_split_conv1d_scan_combined
    return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/autograd/function.py", line 581, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/amp/autocast_mode.py", line 527, in decorate_fwd
    return fwd(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 881, in forward
    out = F.linear(out, outproj_weight, outproj_bias)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/bitsandbytes/nn/modules.py", line 402, in __torch_function__
    return super().__torch_function__(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mat1 and mat2 shapes cannot be multiplied (10816x8192 and 1x8388608)
[rank0]: Traceback (most recent call last):
[rank0]:   File "<frozen runpy>", line 198, in _run_module_as_main
[rank0]:   File "<frozen runpy>", line 88, in _run_code
[rank0]:   File "/workspace/axolotl/src/axolotl/cli/train.py", line 121, in <module>
[rank0]:     fire.Fire(do_cli)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/fire/core.py", line 135, in Fire
[rank0]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/fire/core.py", line 468, in _Fire
[rank0]:     component, remaining_args = _CallAndUpdateTrace(
[rank0]:                                 ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/fire/core.py", line 684, in _CallAndUpdateTrace
[rank0]:     component = fn(*varargs, **kwargs)
[rank0]:                 ^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/axolotl/src/axolotl/cli/train.py", line 88, in do_cli
[rank0]:     return do_train(parsed_cfg, parsed_cli_args)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/axolotl/src/axolotl/cli/train.py", line 45, in do_train
[rank0]:     model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
[rank0]:                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/axolotl/src/axolotl/telemetry/errors.py", line 127, in wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/axolotl/src/axolotl/train.py", line 591, in train
[rank0]:     execute_training(cfg, trainer, resume_from_checkpoint)
[rank0]:   File "/workspace/axolotl/src/axolotl/train.py", line 219, in execute_training
[rank0]:     trainer.train(resume_from_checkpoint=resume_from_checkpoint)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1424, in train
[rank0]:     return inner_training_loop(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1506, in _inner_training_loop
[rank0]:     self._run_epoch(
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1734, in _run_epoch
[rank0]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/axolotl/src/axolotl/core/trainers/mixins/activation_checkpointing.py", line 46, in training_step
[rank0]:     return super().training_step(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1906, in training_step
[rank0]:     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/axolotl/src/axolotl/core/trainers/base.py", line 391, in compute_loss
[rank0]:     return super().compute_loss(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1978, in compute_loss
[rank0]:     outputs = model(**inputs)
[rank0]:               ^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1881, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1829, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/peft/peft_model.py", line 1923, in forward
[rank0]:     return self.base_model(
[rank0]:            ^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/peft/tuners/tuners_utils.py", line 311, in forward
[rank0]:     return self.model.forward(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/.cache/huggingface/modules/transformers_modules/nvidia/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Super_hyphen_120B_hyphen_A12B_hyphen_BF16/c0c012c1aaa62bba6493897796460eec83902141/modeling_nemotron_h.py", line 1718, in forward
[rank0]:     nemotron_h_outputs = self.model(
[rank0]:                          ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/.cache/huggingface/modules/transformers_modules/nvidia/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Super_hyphen_120B_hyphen_A12B_hyphen_BF16/c0c012c1aaa62bba6493897796460eec83902141/modeling_nemotron_h.py", line 1489, in forward
[rank0]:     hidden_states = mixer_block(
[rank0]:                     ^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
[rank0]:     return self.checkpoint_fn(  # type: ignore[misc]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_compile.py", line 53, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/utils/checkpoint.py", line 503, in checkpoint
[rank0]:     ret = function(*args, **kwargs)
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1881, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1829, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/.cache/huggingface/modules/transformers_modules/nvidia/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Super_hyphen_120B_hyphen_A12B_hyphen_BF16/c0c012c1aaa62bba6493897796460eec83902141/modeling_nemotron_h.py", line 822, in forward
[rank0]:     hidden_states = self.mixer(
[rank0]:                     ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
[rank0]:     return self.checkpoint_fn(  # type: ignore[misc]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_compile.py", line 53, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/utils/checkpoint.py", line 503, in checkpoint
[rank0]:     ret = function(*args, **kwargs)
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/.cache/huggingface/modules/transformers_modules/nvidia/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Super_hyphen_120B_hyphen_A12B_hyphen_BF16/c0c012c1aaa62bba6493897796460eec83902141/modeling_nemotron_h.py", line 736, in forward
[rank0]:     return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/.cache/huggingface/modules/transformers_modules/nvidia/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Super_hyphen_120B_hyphen_A12B_hyphen_BF16/c0c012c1aaa62bba6493897796460eec83902141/modeling_nemotron_h.py", line 456, in cuda_kernels_forward
[rank0]:     out, ssm_state = mamba_split_conv1d_scan_combined(
[rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 997, in mamba_split_conv1d_scan_combined
[rank0]:     return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/autograd/function.py", line 581, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/amp/autocast_mode.py", line 527, in decorate_fwd
[rank0]:     return fwd(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 881, in forward
[rank0]:     out = F.linear(out, outproj_weight, outproj_bias)
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/bitsandbytes/nn/modules.py", line 402, in __torch_function__
[rank0]:     return super().__torch_function__(func, types, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: mat1 and mat2 shapes cannot be multiplied (10816x8192 and 1x8388608)
[rank2]: Traceback (most recent call last):
[rank2]:   File "<frozen runpy>", line 198, in _run_module_as_main
[rank2]:   File "<frozen runpy>", line 88, in _run_code
[rank2]:   File "/workspace/axolotl/src/axolotl/cli/train.py", line 121, in <module>
[rank2]:     fire.Fire(do_cli)
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/fire/core.py", line 135, in Fire
[rank2]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank2]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/fire/core.py", line 468, in _Fire
[rank2]:     component, remaining_args = _CallAndUpdateTrace(
[rank2]:                                 ^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/fire/core.py", line 684, in _CallAndUpdateTrace
[rank2]:     component = fn(*varargs, **kwargs)
[rank2]:                 ^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/workspace/axolotl/src/axolotl/cli/train.py", line 88, in do_cli
[rank2]:     return do_train(parsed_cfg, parsed_cli_args)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/workspace/axolotl/src/axolotl/cli/train.py", line 45, in do_train
[rank2]:     model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
[rank2]:                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/workspace/axolotl/src/axolotl/telemetry/errors.py", line 124, in wrapper
[rank2]:     return func(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/workspace/axolotl/src/axolotl/train.py", line 591, in train
[rank2]:     execute_training(cfg, trainer, resume_from_checkpoint)
[rank2]:   File "/workspace/axolotl/src/axolotl/train.py", line 219, in execute_training
[rank2]:     trainer.train(resume_from_checkpoint=resume_from_checkpoint)
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1424, in train
[rank2]:     return inner_training_loop(
[rank2]:            ^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1506, in _inner_training_loop
[rank2]:     self._run_epoch(
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1734, in _run_epoch
[rank2]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank2]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/workspace/axolotl/src/axolotl/core/trainers/mixins/activation_checkpointing.py", line 46, in training_step
[rank2]:     return super().training_step(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1906, in training_step
[rank2]:     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/workspace/axolotl/src/axolotl/core/trainers/base.py", line 391, in compute_loss
[rank2]:     return super().compute_loss(
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1978, in compute_loss
[rank2]:     outputs = model(**inputs)
[rank2]:               ^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1881, in _call_impl
[rank2]:     return inner()
[rank2]:            ^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1829, in inner
[rank2]:     result = forward_call(*args, **kwargs)
[rank2]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/peft/peft_model.py", line 1923, in forward
[rank2]:     return self.base_model(
[rank2]:            ^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank2]:     return forward_call(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/peft/tuners/tuners_utils.py", line 311, in forward
[rank2]:     return self.model.forward(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/root/.cache/huggingface/modules/transformers_modules/nvidia/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Super_hyphen_120B_hyphen_A12B_hyphen_BF16/c0c012c1aaa62bba6493897796460eec83902141/modeling_nemotron_h.py", line 1718, in forward
[rank2]:     nemotron_h_outputs = self.model(
[rank2]:                          ^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank2]:     return forward_call(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/root/.cache/huggingface/modules/transformers_modules/nvidia/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Super_hyphen_120B_hyphen_A12B_hyphen_BF16/c0c012c1aaa62bba6493897796460eec83902141/modeling_nemotron_h.py", line 1489, in forward
[rank2]:     hidden_states = mixer_block(
[rank2]:                     ^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank2]:     return forward_call(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
[rank2]:     return self.checkpoint_fn(  # type: ignore[misc]
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/_compile.py", line 53, in inner
[rank2]:     return disable_fn(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
[rank2]:     return fn(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/utils/checkpoint.py", line 503, in checkpoint
[rank2]:     ret = function(*args, **kwargs)
[rank2]:           ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1881, in _call_impl
[rank2]:     return inner()
[rank2]:            ^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1829, in inner
[rank2]:     result = forward_call(*args, **kwargs)
[rank2]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/root/.cache/huggingface/modules/transformers_modules/nvidia/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Super_hyphen_120B_hyphen_A12B_hyphen_BF16/c0c012c1aaa62bba6493897796460eec83902141/modeling_nemotron_h.py", line 822, in forward
[rank2]:     hidden_states = self.mixer(
[rank2]:                     ^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank2]:     return forward_call(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
[rank2]:     return self.checkpoint_fn(  # type: ignore[misc]
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/_compile.py", line 53, in inner
[rank2]:     return disable_fn(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
[rank2]:     return fn(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/utils/checkpoint.py", line 503, in checkpoint
[rank2]:     ret = function(*args, **kwargs)
[rank2]:           ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank2]:     return forward_call(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/root/.cache/huggingface/modules/transformers_modules/nvidia/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Super_hyphen_120B_hyphen_A12B_hyphen_BF16/c0c012c1aaa62bba6493897796460eec83902141/modeling_nemotron_h.py", line 736, in forward
[rank2]:     return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/root/.cache/huggingface/modules/transformers_modules/nvidia/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Super_hyphen_120B_hyphen_A12B_hyphen_BF16/c0c012c1aaa62bba6493897796460eec83902141/modeling_nemotron_h.py", line 456, in cuda_kernels_forward
[rank2]:     out, ssm_state = mamba_split_conv1d_scan_combined(
[rank2]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 997, in mamba_split_conv1d_scan_combined
[rank2]:     return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/autograd/function.py", line 581, in apply
[rank2]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/torch/amp/autocast_mode.py", line 527, in decorate_fwd
[rank2]:     return fwd(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 881, in forward
[rank2]:     out = F.linear(out, outproj_weight, outproj_bias)
[rank2]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.11/dist-packages/bitsandbytes/nn/modules.py", line 402, in __torch_function__
[rank2]:     return super().__torch_function__(func, types, args, kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: RuntimeError: mat1 and mat2 shapes cannot be multiplied (10816x8192 and 1x8388608)
wandb: 

@ved1beta

Copy link
Copy Markdown
Member Author

removing this shoude work
trust_remote_code: true

@zerofata

Copy link
Copy Markdown

Thanks! That did the trick.

@NanoCode012 NanoCode012 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's a dd a README

Comment thread examples/nemotron-h/120b-a12b-qlora.yaml Outdated
Comment thread examples/nemotron-h/120b-a12b-qlora.yaml Outdated
# pip install mamba-ssm causal-conv1d # for fast Mamba2 CUDA kernels

base_model: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't forget this?

Comment thread examples/nemotron-h/nano-30b-a3b-qlora.yaml
Comment thread examples/nemotron-h/nano-30b-a3b-qlora.yaml Outdated
Comment thread src/axolotl/loaders/patch_manager.py Outdated
NemotronHPreTrainedModel,
)

NemotronHPreTrainedModel.supports_gradient_checkpointing = True

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason upstream sets this to False? What did you fix to satisfy this?

Comment thread src/axolotl/monkeypatch/models/nemotron_h/modeling.py
Comment thread src/axolotl/monkeypatch/lora_kernels.py
return False
return _original_check(config, key)

BaseTuner._check_target_module_exists = _patched_check_target_module_exists

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a change that was already merged? Or maybe I misread. Any reason why this issue didn't occur for prev models we tested?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not already merged , did'nt occur for mixtral for deepseek ect store each expert as an individual nn.Linear module
nemotron is different its MoE experts are stored as 3D stacked nn.Parameter tensors (shape [num_experts, out, in]) , not individual nn.Linear modules.

@ved1beta ved1beta requested a review from NanoCode012 March 24, 2026 14:26
@zerofata

zerofata commented Mar 25, 2026

Copy link
Copy Markdown

Don't know if it's helpful but this was a config I trained on 4*H200 fairly comfortably.

I don't think the model has a gate_proj.
Swapped the pad token to an unnused one that doesn't increase the vocab size (the vocab size increase made GGUF'ing the model a headache).

Merged and did inference seemingly ok.

Random small thing I noticed, dunno if it's important since the model seems to work ok, but I saw one of the model keys got renamed.

Original nvidia model: backbone.embeddings.weight
after training and merging lora: backbone.embedding.weight

image
base_model: nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

datasets:
  - path: ./data/nemotron_sft_2_masked_20260324_190315.jsonl

val_set_size: 0.03
output_dir: ./Nemotron-3-Super-120B-SFT-5
dataset_prepared_path: last_run_prepared

sequence_len: 10756
sample_packing: true

use_cut_cross_entropy: true

load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 64
lora_alpha: 128
lora_dropout: 0.0
lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
  - up_proj #shexp
  - down_proj #shexp
  - in_proj
  - out_proj

lora_target_parameters:
  - mixer.experts.up_proj
  - mixer.experts.down_proj

wandb_project: Nemotron-3-Super-120B-SFT
wandb_entity:
wandb_watch:
wandb_name: Nemotron-3-Super-120B-SFT-5
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch_4bit
lr_scheduler: constant_with_warmup
learning_rate: 1e-5

bf16: auto
tf32: true

resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 2
weight_decay: 0.0

special_tokens:
  pad_token: <unk>

fsdp_config:
  fsdp_version: 2
  offload_params: false
  cpu_ram_efficient_loading: false
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: NemotronHBlock
  state_dict_type: FULL_STATE_DICT
  sharding_strategy: FULL_SHARD
  reshard_after_forward: true
  activation_checkpointing: true

@ved1beta

Copy link
Copy Markdown
Member Author

@zerofata thanks for the help i have changed the pading token

Random small thing I noticed, dunno if it's important since the model seems to work ok, but I saw one of the model keys got renamed.

not sure about this will investigate

@winglian

Copy link
Copy Markdown
Collaborator

the merge key renaming is likely an issue with trust_remote_code and _build_checkpoint_conversion_mapping() in transformers

@winglian

Copy link
Copy Markdown
Collaborator

@ved1beta shouldn't we not need trust_remote_code since nemotron_h is supported in transformers? what's the difference in the two?

removing this shoude work trust_remote_code: true

@winglian

Copy link
Copy Markdown
Collaborator

also, can you rebase this branch on main to get the latest framework updates ?

Comment on lines +34 to +36
# Add gate_up_proj and down_proj to also target shared experts (nn.Linear):
# - gate_up_proj
# - down_proj

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure why we have qwen 3.5 examples in this super nemotron PR. also, I don't think this is correct either

if self.cfg.model_config_type in [
"jamba",
"qwen2_moe",
"nemotron_h",

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed, is this needed here? did you test without this condition?

@winglian winglian left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

other than the stray qwen3.5 example YAML's, this looks good to go once those are removed

Comment thread examples/nemotron-h/120b-a12b-qlora.yaml Outdated
Comment thread examples/nemotron-h/120b-a12b-qlora.yaml Outdated
Comment thread examples/nemotron-h/nano-30b-a3b-qlora.yaml Outdated
lora_target_parameters:
- up_proj
- down_proj
```

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add limitation note that MoE Kernels not supported yet

if self.cfg.model_config_type in [
"jamba",
"qwen2_moe",
"nemotron_h",

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ved1beta check

Comment on lines +311 to +314
# supports_gradient_checkpointing is only enabled after
# patch_nemotron_h_modeling_packing() installs the GC-compatible
# NemotronHBlock.forward. Without the patch, upstream marks this
# False because the original block forward is not GC-safe.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is true, we need to raise error in validator that gradient checkpointing without packing for this model does not work

NemotronHPreTrainedModel.supports_gradient_checkpointing = True

@staticmethod
def _fix_nemotron_h_conversion_mapping():

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are the ramifications of this change? How is this related to LoRA? When an adapter is applied, the weight rename would've happened correctly?

What do downstream libs expect?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing the entry prevents save_pretrained() from applying the reverse rename (embeddings→embedding) on merge+save, which would corrupt the checkpoint key that transformers/vLLM expect.

Comment on lines +41 to +48
try:
import transformers.modeling_flash_attention_utils as fa_utils

from axolotl.monkeypatch.utils import get_unpad_data

fa_utils._get_unpad_data = get_unpad_data
except Exception as exc: # pragma: no cover
LOG.warning("Failed to patch _get_unpad_data for NemotronH: %s", exc)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure you need to patch this here? FA monkey patch would've handled this?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nemotron_h is excluded from SUPPORTED_MULTIPACK_MODEL_TYPES,
the standard patch_for_multipack() only patches _get_unpad_data, so the entire packing suite is handled together in _apply_model_patches() instead.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's just add nemotron_h to SUPPORTED_MULTIPACK_MODEL_TYPES , so we don't need to duplicate this code here

@ved1beta ved1beta mentioned this pull request Mar 29, 2026

@NanoCode012 NanoCode012 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you

@NanoCode012 NanoCode012 linked an issue Mar 30, 2026 that may be closed by this pull request
5 tasks
@winglian winglian merged commit bb622b8 into axolotl-ai-cloud:main Mar 30, 2026
18 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

scheduled_release This PR is slated for the upcoming release

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Support for NVIDIA-Nemotron-3 models

4 participants