7 changes: 4 additions & 3 deletions docs/source/features/checkpoint-loading.md
@@ -31,7 +31,7 @@ The `BaseCheckpointLoader` is the central base interface for all checkpoint loaders

**Key Methods:**
- `load_config(checkpoint_dir, **kwargs)`: Loads and returns a `ModelConfig` object
-- `load_weights(checkpoint_dir, **kwargs)`: Loads and returns a dictionary of weights
+- `load_weights(checkpoint_dir, mapping, **kwargs)`: Loads and returns a dictionary of weights
- `get_initialized_weight_mapper(model, config)`: Returns a runtime-initialized weight mapper for the model
- `cleanup()`: Releases resources and cleans up internal state
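
For orientation, the sketch below strings these methods together. It assumes the stock `HfCheckpointLoader` with its default weight and config loaders, a single-rank `Mapping`, and a placeholder checkpoint path:

```python
from tensorrt_llm._torch.models.checkpoints import HfCheckpointLoader
from tensorrt_llm.mapping import Mapping

loader = HfCheckpointLoader()                      # built-in loader for the "HF" format
config = loader.load_config("/ckpts/my-model")     # placeholder checkpoint directory
weights = loader.load_weights("/ckpts/my-model",
                              mapping=Mapping())   # single-rank default (world_size=1, rank=0)
loader.cleanup()                                   # release any cached state
```

`get_initialized_weight_mapper` is omitted here because it needs an instantiated model alongside the config.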

@@ -63,7 +63,7 @@ Handles the loading of model weights from storage:
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader

class CustomWeightLoader(BaseWeightLoader):
-    def load_weights(self, checkpoint_dir: str) -> dict[str, Any]:
+    def load_weights(self, checkpoint_dir: str, mapping: Mapping) -> dict[str, Any]:
        # Load weights from your custom format
        # Return a dictionary mapping parameter names to tensors
        return weights_dict
@@ -186,11 +186,12 @@ from tensorrt_llm._torch.models.modeling_utils import register_checkpoint_weight

@register_checkpoint_weight_loader("CUSTOM_FORMAT")
class CustomWeightLoader(BaseWeightLoader):
def load_weights(self, checkpoint_dir: str, **kwargs) -> dict[str, Any]:
def load_weights(self, checkpoint_dir: str, mapping: Mapping, **kwargs) -> dict[str, Any]:
"""
Load weights from your custom format.
Args:
checkpoint_dir: Directory containing checkpoint files
mapping: A mapping object containing the distributed configuration.
**kwargs: Additional loading parameters
Returns:
Dictionary mapping parameter names to tensors
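To make the role of the new `mapping` parameter concrete, here is a hedged sketch of a custom loader that uses it to pick a per-rank shard. The one-file-per-rank layout, the format name, and the class name are illustrative assumptions, not part of the documented API:

```python
import os
from typing import Any

import torch

from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader
from tensorrt_llm._torch.models.modeling_utils import register_checkpoint_weight_loader
from tensorrt_llm.mapping import Mapping


@register_checkpoint_weight_loader("RANK_SHARDED")  # illustrative format name
class RankShardedWeightLoader(BaseWeightLoader):
    """Assumes the checkpoint directory holds one torch-saved file per rank."""

    def load_weights(self, checkpoint_dir: str, mapping: Mapping,
                     **kwargs) -> dict[str, Any]:
        # Hypothetical layout: weights_rank0.pt, weights_rank1.pt, ...
        shard = os.path.join(checkpoint_dir, f"weights_rank{mapping.rank}.pt")
        # Keep tensors on CPU; device placement happens later during model loading.
        return torch.load(shard, map_location="cpu")
```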
@@ -14,6 +14,8 @@
    BaseWeightMapper
from tensorrt_llm._torch.models.modeling_utils import \
    CHECKPOINT_LOADER_FORMAT_DEFAULT_MAPPING
+from tensorrt_llm.logger import logger
+from tensorrt_llm.mapping import Mapping


class BaseCheckpointLoader(ABC):
@@ -51,10 +53,17 @@ def checkpoint_format(self) -> str:
        ...

    def load_config(self, checkpoint_dir: str, **kwargs) -> ModelConfig:
+        logger.debug(f"Loading config from {checkpoint_dir}")
        return self.config_loader.load(checkpoint_dir, **kwargs)

-    def load_weights(self, checkpoint_dir: str, **kwargs) -> dict[str, Any]:
-        return self.weight_loader.load_weights(checkpoint_dir, **kwargs)
+    def load_weights(self, checkpoint_dir: str, mapping: Mapping,
+                     **kwargs) -> dict[str, Any]:
+        logger.debug(
+            f"Loading weights from {checkpoint_dir} with mapping {mapping.to_dict()}"
+        )
+        return self.weight_loader.load_weights(checkpoint_dir,
+                                               mapping=mapping,
+                                               **kwargs)

    @classmethod
    def get(cls, checkpoint_format: str, **kwargs) -> "BaseCheckpointLoader":
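The `mapping` threaded through `load_weights` above is the standard `tensorrt_llm.mapping.Mapping` describing the parallel layout. A small sketch of constructing one for a 2-way tensor-parallel run; the keyword arguments shown are the commonly used subset, and the defaults cover single-GPU runs:

```python
from tensorrt_llm.mapping import Mapping

# Rank 0 of a two-GPU tensor-parallel job (pipeline parallelism left at 1).
mapping = Mapping(world_size=2, rank=0, tp_size=2, pp_size=1)

print(mapping.tp_rank, mapping.pp_rank)  # per-rank coordinates derived from the sizes
print(mapping.to_dict())                 # the same dict the debug log above prints
```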
@@ -1,16 +1,20 @@
from abc import ABC, abstractmethod
from typing import Any

+from tensorrt_llm.mapping import Mapping


class BaseWeightLoader(ABC):

    @abstractmethod
-    def load_weights(self, checkpoint_dir: str) -> dict[str, Any]:
+    def load_weights(self, checkpoint_dir: str,
+                     mapping: Mapping) -> dict[str, Any]:
        """
        Loads weights from a checkpoint directory.

        Args:
            checkpoint_dir: A path to the checkpoint directory.
+            mapping: A mapping object containing the distributed configuration.

        Returns:
            A dictionary where keys are tensor names and values are the tensors.
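Since every concrete loader now shares this signature, code that consumes the interface can stay format-agnostic. A minimal sketch; the helper name and the summary it prints are made up for illustration:

```python
from typing import Any

from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader
from tensorrt_llm.mapping import Mapping


def load_and_summarize(loader: BaseWeightLoader, checkpoint_dir: str,
                       mapping: Mapping) -> dict[str, Any]:
    # Works with any BaseWeightLoader subclass, HF or custom.
    weights = loader.load_weights(checkpoint_dir, mapping=mapping)
    print(f"rank {mapping.rank}/{mapping.world_size} loaded {len(weights)} tensors")
    return weights
```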
@@ -16,6 +16,7 @@
from tensorrt_llm._utils import (local_mpi_barrier, local_mpi_rank,
                                 local_mpi_size)
from tensorrt_llm.logger import logger
+from tensorrt_llm.mapping import Mapping


@register_checkpoint_weight_loader("HF")
@@ -24,7 +25,8 @@ class HfWeightLoader(BaseWeightLoader):
    Loads weights from SafeTensors/bin/pth files.
    """

-    def load_weights(self, checkpoint_dir: str) -> dict[str, Any]:
+    def load_weights(self, checkpoint_dir: str,
+                     mapping: Mapping) -> dict[str, Any]:
        weight_files = glob.glob(f"{checkpoint_dir}/*.safetensors")
        # Some model checkpoint directories contain not only the sharded safetensors, but one
        # consolidated tensor. In the presence of both, we favor the former, as there really is no need
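Called directly, the updated loader looks like the sketch below. It assumes `HfWeightLoader` can be constructed without arguments and that the directory holds safetensors files; the path is a placeholder:

```python
from tensorrt_llm._torch.models.checkpoints import HfWeightLoader
from tensorrt_llm.mapping import Mapping

loader = HfWeightLoader()
weights = loader.load_weights("/ckpts/my-hf-model",  # placeholder directory
                              mapping=Mapping())     # single-rank default
print(f"loaded {len(weights)} tensors")
```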
8 changes: 5 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/model_loader.py
@@ -265,9 +265,10 @@ def init_meta_tensor(t: torch.Tensor):
        if load_format == LoadFormat.AUTO:
            if hasattr(model, 'llm_checkpoint_dir'):
                weights = checkpoint_loader.load_weights(
-                    model.llm_checkpoint_dir)
+                    model.llm_checkpoint_dir, mapping=self.mapping)
            else:
-                weights = checkpoint_loader.load_weights(checkpoint_dir)
+                weights = checkpoint_loader.load_weights(
+                    checkpoint_dir, mapping=self.mapping)

            self.weight_mapper = checkpoint_loader.get_initialized_weight_mapper(
                model, config)
@@ -277,7 +278,8 @@ def init_meta_tensor(t: torch.Tensor):
        if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights(
        ):
            weights = checkpoint_loader.load_weights(
-                self.spec_config.speculative_model_dir)
+                self.spec_config.speculative_model_dir,
+                mapping=self.mapping)

            draft_model_arch = model.draft_config.pretrained_config.architectures[
                0]
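In the normal serving path the mapping is not supplied by the user: the model loader forwards its own `self.mapping`, which is derived from the parallelism settings given to the `LLM`. A hedged sketch of that entry point, with a placeholder model path:

```python
from tensorrt_llm import LLM

llm = LLM(model="/ckpts/my-model",   # placeholder checkpoint directory
          backend="pytorch",
          tensor_parallel_size=2)    # load_weights above then sees tp_size=2 in its mapping
```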
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_a10.yml
@@ -23,6 +23,7 @@ l0_a10:
# NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no
# test list either).
- unittest/_torch/models/checkpoints/hf/test_weight_loader.py
- unittest/_torch/models/checkpoints/hf/test_checkpoint_loader.py
- unittest/others/test_time_breakdown.py
- unittest/disaggregated/test_disagg_utils.py
- unittest/disaggregated/test_router.py
111 changes: 111 additions & 0 deletions tests/unittest/_torch/models/checkpoints/hf/test_checkpoint_loader.py
@@ -0,0 +1,111 @@
import pathlib as _pl
from typing import Any, Optional

import pytest
import torch
from transformers.configuration_utils import PretrainedConfig

from tensorrt_llm import LLM, Mapping
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.checkpoints import HfCheckpointLoader
from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader
from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import BaseWeightMapper
from tensorrt_llm._torch.models.modeling_utils import register_auto_model


class DummyConfig(PretrainedConfig):
    def __init__(self):
        self.architectures: list[str] = ["DummyModel"]
        self.dtype: torch.dtype = torch.float16
        self.num_attention_heads: int = 16
        self.hidden_size: int = 256
        self.vocab_size: int = 1000
        self.num_hidden_layers: int = 1


@register_auto_model("DummyModel")
class DummyModel(torch.nn.Module):
    def __init__(self, model_config: ModelConfig):
        super().__init__()
        self.model_config = model_config

    def infer_max_seq_len(self):
        return 2048

    @property
    def config(self):
        return self.model_config.pretrained_config

    def forward(self, *args, input_ids: torch.Tensor, **kwargs) -> torch.Tensor:
        num_batch_tokens = input_ids.size(0)
        vocab_size = self.config.vocab_size

        # Logits: dummy values for testing
        logits = torch.ones((num_batch_tokens, vocab_size), device="cuda") * 0.1

        return {
            "logits": logits,
        }

    def load_weights(
        self,
        weights: dict,
        weight_mapper: Optional[BaseWeightMapper] = None,
        skip_modules: list[str] = [],
    ):
        pass


class DummyWeightLoader(BaseWeightLoader):
    def load_weights(self, checkpoint_dir: str, mapping: Mapping, **kwargs) -> dict[str, Any]:
        """Load weights from your dummy format.
        Args:
            checkpoint_dir: Directory containing checkpoint files
            mapping: Mapping object containing the distributed configuration
            **kwargs: Additional loading parameters
        Returns:
            Dictionary mapping parameter names to tensors
        """

        assert mapping is not None
        assert isinstance(mapping, Mapping)
        assert mapping.world_size == 1
        assert mapping.rank == 0

        weights = {}

        return weights


class DummyConfigLoader(BaseConfigLoader):
    def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig:
        """Load and parse configuration from your dummy format.
        Args:
            checkpoint_dir: Directory containing configuration files
            **kwargs: Additional loading parameters
        Returns:
            ModelConfig object containing parsed configuration
        """
        return ModelConfig(pretrained_config=DummyConfig())


def test_weight_loader_mapping():
    """Test that the weight loader receives the expected mapping."""

    # Constructing the LLM is enough to trigger weight loading and the
    # assertions inside DummyWeightLoader.
    with LLM(
        model=_pl.Path("dummy_path"),
        backend="pytorch",
        cuda_graph_config=None,
        checkpoint_loader=HfCheckpointLoader(
            weight_loader=DummyWeightLoader(), config_loader=DummyConfigLoader()
        ),
    ):
        pass


if __name__ == "__main__":
    pytest.main([__file__])
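The assertions in `DummyWeightLoader` above hard-code the single-process case (`world_size == 1`, `rank == 0`). If the test were ever parametrized over parallel configurations, a rank-agnostic variant might check invariants instead; a small sketch, purely illustrative:

```python
class RankAgnosticDummyWeightLoader(DummyWeightLoader):
    def load_weights(self, checkpoint_dir: str, mapping: Mapping, **kwargs) -> dict[str, Any]:
        # Invariants that hold on every rank, not just rank 0 of a 1-GPU run.
        assert 0 <= mapping.rank < mapping.world_size
        assert 0 <= mapping.tp_rank < mapping.tp_size
        return {}
```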
@@ -3,6 +3,7 @@
import pytest

from tensorrt_llm._torch.models.checkpoints import HfWeightLoader
+from tensorrt_llm.mapping import Mapping


class MyError(Exception):
@@ -69,7 +70,7 @@ def test_load_weights_ignores_consolidated_ckpt_when_sharded_ckpt_exists(
        mock.patch.object(loader, "prefetch_files") as prefetch_files,
        pytest.raises(MyError),
    ):
-        loader.load_weights(checkpoint_dir=str(checkpoint_dir))
+        loader.load_weights(checkpoint_dir=str(checkpoint_dir), mapping=Mapping())

    prefetch_files.assert_called_once()
    prefetched_files = prefetch_files.call_args[0][0]