Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
c8e73b6
Update transformers upperbound to v5 rc0
thomasdhc Dec 2, 2025
f32e96e
fix import
akoumpa Dec 2, 2025
c0d7724
fix
akoumpa Dec 2, 2025
c20e37b
fix import
akoumpa Dec 2, 2025
23d9387
fix
akoumpa Dec 2, 2025
6a5e4eb
hf home
adil-a Dec 2, 2025
9a10703
fix
akoumpa Dec 2, 2025
44aa727
remove debug
akoumpa Dec 2, 2025
889ab9d
fix
akoumpa Dec 2, 2025
e8b9626
fix
akoumpa Dec 2, 2025
63fd0bd
fix
akoumpa Dec 2, 2025
cdde6d6
fix import
akoumpa Dec 2, 2025
4f155ea
fix
akoumpa Dec 2, 2025
abe2485
meta device and ckptr fix
adil-a Dec 2, 2025
6d9d79b
fix
adil-a Dec 2, 2025
5eb49f2
guard
akoumpa Dec 2, 2025
6542aa9
rope_param
akoumpa Dec 2, 2025
4a3cf20
support older version
akoumpa Dec 2, 2025
f404c91
fix import
akoumpa Dec 2, 2025
82dafc3
fmt
akoumpa Dec 2, 2025
203a59a
update rope to v5
akoumpa Dec 2, 2025
57b4747
lint
akoumpa Dec 2, 2025
57138ae
lint
akoumpa Dec 2, 2025
a055437
lint
akoumpa Dec 2, 2025
df51b8a
fix
akoumpa Dec 2, 2025
3677979
lint
akoumpa Dec 2, 2025
f942462
use get_as_string
akoumpa Dec 2, 2025
0dacc46
update test
akoumpa Dec 2, 2025
9685a5b
fix rope
akoumpa Dec 3, 2025
ea4044a
hf_home -> hf_hub_cache
adil-a Dec 3, 2025
3dbdb84
fmt
akoumpa Dec 3, 2025
cab758c
lint + fixing custom model download
adil-a Dec 3, 2025
f54d7a0
fixing unit test
adil-a Dec 3, 2025
1ff0e8d
fix
akoumpa Dec 9, 2025
1dd9c96
update
akoumpa Dec 9, 2025
233245e
add remapping
akoumpa Dec 9, 2025
95f00d8
add file
akoumpa Dec 16, 2025
5311108
lint
akoumpa Dec 16, 2025
5fa4911
lint
akoumpa Dec 16, 2025
a6c9c97
fix ruff
akoumpa Dec 16, 2025
b9d65c3
Update uv lock
akoumpa Dec 16, 2025
3d86667
fix
akoumpa Dec 16, 2025
6f9cc25
lower-bound transformers
akoumpa Dec 16, 2025
be23c62
Update uv lock
akoumpa Dec 16, 2025
b03209a
fix
akoumpa Dec 17, 2025
fe8437d
fix?
akoumpa Dec 18, 2025
429acad
Merge branch 'main' into transformers_v5_rc0
akoumpa Dec 19, 2025
bd22ab7
Merge branch 'transformers_v5_rc0' into huiyingl/transformers_v5_omni
HuiyingLi Jan 1, 2026
0c6801a
transformers v5rc1 changes
HuiyingLi Jan 2, 2026
dfe7e14
add transformers v5 AutomodelForMultimodalLM
HuiyingLi Jan 2, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions .github/workflows/cicd-main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,21 @@ jobs:
- name: Set up UV
uses: astral-sh/setup-uv@v1
with:
version: 0.7.2
version: 0.8.22
- name: Install ruff
env:
UV_PROJECT_ENVIRONMENT: ./venv
run: |
uv venv ${UV_PROJECT_ENVIRONMENT} --system-site-packages
source ./venv/bin/activate
uv venv ${UV_PROJECT_ENVIRONMENT}

export PATH="./bin/:$PATH"

uv sync --link-mode copy --locked --group linting

- name: Run ruff
env:
UV_PROJECT_ENVIRONMENT: ./venv
run: |
source ./venv/bin/activate
uv run ruff check . --verbose
uv run ruff format --check . --verbose

Expand All @@ -80,16 +80,16 @@ jobs:
env:
UV_PROJECT_ENVIRONMENT: ./venv
run: |
uv venv ${UV_PROJECT_ENVIRONMENT} --system-site-packages
source ./venv/bin/activate
uv venv ${UV_PROJECT_ENVIRONMENT}

export PATH="./bin/:$PATH"

uv sync --link-mode copy --locked --group linting

- name: Run import-linter
env:
UV_PROJECT_ENVIRONMENT: ./venv
run: |
source ./venv/bin/activate
uv run lint-imports --debug --verbose --no-cache

Nemo_Linting_Test:
Expand Down
76 changes: 56 additions & 20 deletions docker/common/uv-pytorch.lock

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ parallelizer:
activation_checkpointing: false

model:
_target_: nemo_automodel.NeMoAutoModelForImageTextToText.from_pretrained
_target_: nemo_automodel.NeMoAutoModelForMultimodalLM.from_pretrained
pretrained_model_name_or_path: Qwen/Qwen3-Omni-30B-A3B-Instruct
# Customize this backend for fine grained control
# backend:
Expand Down
3 changes: 3 additions & 0 deletions nemo_automodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,19 @@
from nemo_automodel._transformers.auto_model import (
NeMoAutoModelForCausalLM,
NeMoAutoModelForImageTextToText,
NeMoAutoModelForMultimodalLM,
NeMoAutoModelForSequenceClassification,
NeMoAutoModelForTextToWaveform,
) # noqa: I001

globals()["NeMoAutoModelForCausalLM"] = NeMoAutoModelForCausalLM
globals()["NeMoAutoModelForImageTextToText"] = NeMoAutoModelForImageTextToText
globals()["NeMoAutoModelForMultimodalLM"] = NeMoAutoModelForMultimodalLM
globals()["NeMoAutoModelForSequenceClassification"] = NeMoAutoModelForSequenceClassification
globals()["NeMoAutoModelForTextToWaveform"] = NeMoAutoModelForTextToWaveform
__all__.append("NeMoAutoModelForCausalLM")
__all__.append("NeMoAutoModelForImageTextToText")
__all__.append("NeMoAutoModelForMultimodalLM")
__all__.append("NeMoAutoModelForSequenceClassification")
__all__.append("NeMoAutoModelForTextToWaveform")
except:
Expand Down
31 changes: 20 additions & 11 deletions nemo_automodel/_transformers/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@
AutoConfig,
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoModelForMultimodalLM,
AutoModelForSequenceClassification,
AutoModelForTextToWaveform,
PreTrainedModel,
)
from transformers.modeling_utils import _get_resolved_checkpoint_files
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from transformers.utils.hub import DownloadKwargs

import nemo_automodel.components.distributed.utils as dist_utils
from nemo_automodel import __version__
Expand Down Expand Up @@ -227,24 +229,25 @@ def _download_model_weights(hf_config, pretrained_model_name_or_path):
# Import via module reference (vs bound name) so unit tests can patch
# `nemo_automodel.components.distributed.utils.FirstRankPerNode`.
with dist_utils.FirstRankPerNode():
download_kwargs = {
"cache_dir": None,
"force_download": False,
"proxies": None,
"local_files_only": False,
"token": None,
"revision": "main",
"subfolder": "",
"commit_hash": getattr(hf_config, "_commit_hash", None),
}
_get_resolved_checkpoint_files(
pretrained_model_name_or_path=pretrained_model_name_or_path,
subfolder="",
variant=None,
gguf_file=None,
from_tf=False,
from_flax=False,
use_safetensors=None,
cache_dir=None,
force_download=False,
proxies=None,
local_files_only=False,
token=None,
download_kwargs=download_kwargs,
user_agent={"file_type": "model", "framework": "pytorch", "from_auto_class": False},
revision="main",
commit_hash=getattr(hf_config, "_commit_hash", None),
is_remote_code=False,
transformers_explicit_filename=None,
transformers_explicit_filename=getattr(hf_config, "transformers_weights", None),
)


Expand Down Expand Up @@ -652,6 +655,12 @@ class NeMoAutoModelForImageTextToText(_BaseNeMoAutoModelClass, AutoModelForImage
pass


class NeMoAutoModelForMultimodalLM(_BaseNeMoAutoModelClass, AutoModelForMultimodalLM):
    """Drop-in replacement for ``transformers.AutoModelForMultimodalLM`` with custom-kernels.

    All behavior comes from the two bases: ``_BaseNeMoAutoModelClass`` supplies the
    NeMo-specific kernel patching, while ``AutoModelForMultimodalLM`` (transformers v5)
    supplies model resolution/instantiation. No additional logic is defined here.
    """

    pass


class NeMoAutoModelForSequenceClassification(_BaseNeMoAutoModelClass, AutoModelForSequenceClassification):
"""Drop-in replacement for ``transformers.AutoModelForSequenceClassification`` with custom-kernels.

Expand Down
176 changes: 171 additions & 5 deletions nemo_automodel/components/checkpoint/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
import torch
import torch.distributed.checkpoint as dcp
import yaml
from huggingface_hub import constants as hf_constants
from packaging.version import parse
from safetensors.torch import load_file, save_file
from torch import nn
from torch.distributed.device_mesh import DeviceMesh
from transformers.utils import TRANSFORMERS_CACHE

from nemo_automodel.components.checkpoint._backports.consolidate_hf_safetensors import (
consolidate_safetensors_files_on_every_rank,
Expand All @@ -38,12 +38,16 @@
get_fqn_to_file_index_mapping,
)
from nemo_automodel.components.checkpoint.addons import ConsolidatedHFAddon, PeftAddon
from nemo_automodel.components.checkpoint.conversion_mapping import (
get_combined_key_mapping,
requires_tensor_merging,
)
from nemo_automodel.components.checkpoint.stateful_wrappers import ModelState, OptimizerState
from nemo_automodel.components.checkpoint.utils import is_tied_word_embeddings

if TYPE_CHECKING:
from peft import PeftConfig
from transformers.tokenization_utils import PreTrainedTokenizerBase
from transformers.tokenization_utils_base import PreTrainedTokenizerBase


def _is_geq_torch_2_9() -> bool:
Expand Down Expand Up @@ -284,6 +288,7 @@ def load_model(
- For PEFT (non-init): rank 0 reads `adapter_model.safetensors`, then broadcasts.
- Otherwise: use DCP with a Hugging Face or default storage reader to populate the state dict.
- If the model exposes a `state_dict_adapter`, convert to/from HF format as needed.
- For models requiring tensor merging (e.g., Mixtral), uses transformers' conversion mapping.

Args:
model: Model or parallelized model parts to load into.
Expand All @@ -301,6 +306,20 @@ def load_model(
is_init_step=is_init_step,
skip_task_head_prefixes=getattr(self.config, "skip_task_head_prefixes_for_base_model", None),
)

# Check if this model requires tensor merging (e.g., Mixtral with grouped experts)
model_type = getattr(getattr(model_state.model[0], "config", None), "model_type", None)
has_state_dict_adapter = hasattr(model_state.model[0], "state_dict_adapter")

# For models that need tensor merging and don't have an adapter, try using transformers' conversion
if is_init_step and model_type and requires_tensor_merging(model_type) and not has_state_dict_adapter:
converted_state_dict = _convert_checkpoint_with_transformers(model_state.model[0], model_path, key_mapping)
if converted_state_dict is not None:
# Load using full_state_dict=True to properly convert tensors to DTensors for FSDP
_load_full_state_dict_into_model(model_state.model, converted_state_dict)
return

# Standard loading path
state_dict = model_state.state_dict()
storage_reader = self._get_storage_reader(model_path, key_mapping, is_init_step=is_init_step)

Expand All @@ -310,7 +329,6 @@ def load_model(

state_dict = self._do_load(state_dict, model_path, storage_reader, is_init_step=is_init_step)

has_state_dict_adapter = hasattr(model_state.model[0], "state_dict_adapter")
state_dict = _maybe_adapt_state_dict_from_hf(model_state.model[0], state_dict, moe_mesh=self.moe_mesh)
model_state.load_state_dict(state_dict, strict=not (len(model_state.model) > 1 or has_state_dict_adapter))

Expand Down Expand Up @@ -366,13 +384,17 @@ def load_base_model(

if load_base_model:
assert model_name is not None, "model_name is required when loading base model"
# Get combined key mapping from model attribute and model-type specific conversions
model_type = getattr(getattr(model, "config", None), "model_type", None)
model_key_mapping = getattr(model, "_checkpoint_conversion_mapping", None)
key_mapping = get_combined_key_mapping(model_type, model_key_mapping)
self.load_model(
model,
model_path=model_name
if os.path.exists(model_name)
else get_safetensors_index_path(root_dir, model_name),
is_init_step=True,
key_mapping=getattr(model, "_checkpoint_conversion_mapping", None),
key_mapping=key_mapping,
)

is_tied_lm_head = is_tied_word_embeddings(model)
Expand Down Expand Up @@ -635,7 +657,8 @@ def _get_original_model_path(self, model_state: ModelState) -> str | None:
return None
pretrained_model_name_or_path = getattr(model_state.model[0], "name_or_path")
return get_safetensors_index_path(
getattr(self.config, "original_model_root_dir", None) or TRANSFORMERS_CACHE, pretrained_model_name_or_path
getattr(self.config, "original_model_root_dir", hf_constants.HF_HOME),
pretrained_model_name_or_path,
)


Expand Down Expand Up @@ -847,6 +870,149 @@ def compute_should_use_set_data(tensor, tensor_applied):
return module


def _load_full_state_dict_into_model(
    model_parts: list[nn.Module],
    state_dict: dict[str, torch.Tensor],
) -> None:
    """
    Load a full (non-sharded) state dict into a potentially FSDP-wrapped model.

    Uses PyTorch's ``set_model_state_dict`` with ``full_state_dict=True`` to properly
    shard the tensors when loading into DTensors.

    Args:
        model_parts: List of model parts (for pipeline parallelism). Each part is
            loaded independently from the same full state dict.
        state_dict: Full state dict with regular (non-DTensor) tensors. Expected to
            be fully materialized on rank 0; other ranks receive via broadcast.

    Returns:
        None. The model parts are mutated in place.
    """
    # Imported lazily to keep torch.distributed.checkpoint out of the module import path.
    from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

    # full_state_dict=True tells PyTorch this is a complete, non-sharded state dict;
    # it will shard the tensors to match each model part's DTensor layout.
    # broadcast_from_rank0=True distributes the weights from rank 0 to the other ranks,
    # so only rank 0 needs to hold the full checkpoint in memory.
    options = StateDictOptions(
        strict=False,
        full_state_dict=True,
        broadcast_from_rank0=True,
    )

    # Plain loop instead of list(map(partial(...), ...)): the call is executed purely
    # for its side effect, so materializing a throwaway list of None results is
    # wasteful and obscures intent (Perflint C417/PERF401 family).
    for model_part in model_parts:
        set_model_state_dict(model_part, model_state_dict=state_dict, options=options)


def _convert_checkpoint_with_transformers(
    model: nn.Module,
    model_path: str,
    key_mapping: Optional[dict[str, str]] = None,
) -> Optional[dict[str, torch.Tensor]]:
    """
    Convert a checkpoint using transformers' conversion mapping for models that need tensor merging.

    This handles MoE models like Mixtral where the checkpoint has individual expert weights
    but the model uses grouped expert tensors. The transformers library's WeightConverter
    operations handle the tensor merging (MergeModulelist, Concatenate).

    This function converts the state dict WITHOUT loading it into the model, so it can be
    used with FSDP-aware loading mechanisms.

    Best-effort by design: every failure path (missing transformers APIs, no mapping,
    no safetensors files, or any exception during conversion) logs a warning and
    returns ``None`` so the caller can fall back to the standard loading path.

    Args:
        model: The model (used to get conversion mapping and target keys).
        model_path: Path to the HuggingFace checkpoint directory.
        key_mapping: Optional additional key mapping.

    Returns:
        Converted state dict ready for loading, or None if conversion failed.
    """
    try:
        # NOTE(review): conversion_mapping / core_model_loading are transformers-v5
        # internals; the ImportError guard keeps older transformers versions working.
        from copy import deepcopy

        from safetensors import safe_open
        from transformers.conversion_mapping import get_model_conversion_mapping
        from transformers.core_model_loading import (
            WeightConverter,
            WeightRenaming,
            dot_natural_key,
            rename_source_key,
        )
    except ImportError:
        logging.warning(
            "transformers library with conversion_mapping not available. "
            "Cannot use transformers' WeightConverter for tensor merging."
        )
        return None

    try:
        # Get the weight conversion mapping from transformers
        weight_mapping = get_model_conversion_mapping(model, key_mapping=key_mapping, add_legacy=True)
        if not weight_mapping:
            logging.warning(
                f"No conversion mapping found for model type {getattr(model.config, 'model_type', 'unknown')}"
            )
            return None

        # Load the safetensors files (non-recursive: only the checkpoint root is scanned)
        safetensors_files = glob.glob(os.path.join(model_path, "*.safetensors"))
        if not safetensors_files:
            logging.warning(f"No safetensors files found in {model_path}")
            return None

        # Load checkpoint state dict. Everything is materialized on CPU up front;
        # presumably acceptable for the target model sizes — confirm for very large checkpoints.
        checkpoint_state_dict = {}
        for sf_path in safetensors_files:
            with safe_open(sf_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    checkpoint_state_dict[key] = f.get_tensor(key)

        # Separate renamings (pure key renames) from converters (tensor-merging ops)
        renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]
        converters = [entry for entry in weight_mapping if isinstance(entry, WeightConverter)]
        pattern_to_converter = {k: converter for converter in converters for k in converter.source_patterns}

        # Process checkpoint keys and apply conversions
        converted_state_dict = {}
        param_name_to_mapping: dict[str, WeightRenaming | WeightConverter] = {}

        # Sort by key for consistent ordering — converters accumulate tensors across
        # multiple source keys, so the feed order must be deterministic.
        sorted_items = sorted(checkpoint_state_dict.items(), key=lambda kv: dot_natural_key(kv[0]))

        for original_key, tensor in sorted_items:
            # Rename the key
            renamed_key, source_pattern = rename_source_key(original_key, renamings, converters)

            # Check if this needs conversion
            if source_pattern is not None:
                # This key is part of a WeightConverter operation.
                # deepcopy gives each target key its own converter instance; setdefault
                # reuses the first one created so subsequent source keys accumulate into it
                # (the fresh copy is discarded when the key was already seen).
                new_converter = deepcopy(pattern_to_converter[source_pattern])
                mapping = param_name_to_mapping.setdefault(renamed_key, new_converter)
                mapping.add_tensor(renamed_key, original_key, source_pattern, tensor)
            else:
                # Simple rename or pass-through
                mapping = param_name_to_mapping.setdefault(renamed_key, WeightRenaming(original_key, renamed_key))
                mapping.add_tensor(renamed_key, original_key, original_key, tensor)

        # Now apply all the conversions. Failures are per-key: a single bad conversion
        # is logged and skipped rather than aborting the whole checkpoint.
        for first_param_name, mapping in param_name_to_mapping.items():
            try:
                realized_value, _ = mapping.convert(first_param_name, model=model, config=model.config)
                for target_name, param in realized_value.items():
                    param = param[0] if isinstance(param, list) else param
                    converted_state_dict[target_name] = param
                mapping.reset()
            except Exception as e:
                logging.warning(f"Conversion failed for {first_param_name}: {e}")
                continue

        logging.info(f"Converted {len(converted_state_dict)} keys using transformers conversion mapping")
        return converted_state_dict

    except Exception as e:
        logging.warning(f"Failed to convert checkpoint with transformers: {e}")
        import traceback

        traceback.print_exc()
        return None


def _maybe_adapt_state_dict_to_hf(
model_part: nn.Module, state_dict: dict[str, torch.Tensor], quantization: bool = False
) -> dict[str, torch.Tensor]:
Expand Down
Loading
Loading