diff --git a/docs/source/features/checkpoint-loading.md b/docs/source/features/checkpoint-loading.md
index 4a37ef76234..41699b49000 100644
--- a/docs/source/features/checkpoint-loading.md
+++ b/docs/source/features/checkpoint-loading.md
@@ -31,7 +31,7 @@ The `BaseCheckpointLoader` is the central base interface for all checkpoint load
 **Key Methods:**
 
 - `load_config(checkpoint_dir, **kwargs)`: Loads and returns a `ModelConfig` object
-- `load_weights(checkpoint_dir, **kwargs)`: Loads and returns a dictionary of weights
+- `load_weights(checkpoint_dir, mapping, **kwargs)`: Loads and returns a dictionary of weights
 - `get_initialized_weight_mapper(model, config)`: Returns a runtime initialized weight mapper for the model
 - `cleanup()`: Releases resources and cleans up internal state
 
@@ -63,7 +63,7 @@ Handles the loading of model weights from storage:
 from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader
 
 class CustomWeightLoader(BaseWeightLoader):
-    def load_weights(self, checkpoint_dir: str) -> dict[str, Any]:
+    def load_weights(self, checkpoint_dir: str, mapping: Mapping) -> dict[str, Any]:
         # Load weights from your custom format
         # Return a dictionary mapping parameter names to tensors
         return weights_dict
@@ -186,11 +186,12 @@ from tensorrt_llm._torch.models.modeling_utils import register_checkpoint_weight
 
 @register_checkpoint_weight_loader("CUSTOM_FORMAT")
 class CustomWeightLoader(BaseWeightLoader):
-    def load_weights(self, checkpoint_dir: str, **kwargs) -> dict[str, Any]:
+    def load_weights(self, checkpoint_dir: str, mapping: Mapping, **kwargs) -> dict[str, Any]:
         """
         Load weights from your custom format.
         Args:
             checkpoint_dir: Directory containing checkpoint files
+            mapping: A mapping object containing the distributed configuration.
             **kwargs: Additional loading parameters
         Returns:
             Dictionary mapping parameter names to tensors
diff --git a/tensorrt_llm/_torch/models/checkpoints/base_checkpoint_loader.py b/tensorrt_llm/_torch/models/checkpoints/base_checkpoint_loader.py
index 10130b087e4..2567571f585 100644
--- a/tensorrt_llm/_torch/models/checkpoints/base_checkpoint_loader.py
+++ b/tensorrt_llm/_torch/models/checkpoints/base_checkpoint_loader.py
@@ -14,6 +14,8 @@
     BaseWeightMapper
 from tensorrt_llm._torch.models.modeling_utils import \
     CHECKPOINT_LOADER_FORMAT_DEFAULT_MAPPING
+from tensorrt_llm.logger import logger
+from tensorrt_llm.mapping import Mapping
 
 
 class BaseCheckpointLoader(ABC):
@@ -51,10 +53,17 @@ def checkpoint_format(self) -> str:
         ...
 
     def load_config(self, checkpoint_dir: str, **kwargs) -> ModelConfig:
+        logger.debug(f"Loading config from {checkpoint_dir}")
         return self.config_loader.load(checkpoint_dir, **kwargs)
 
-    def load_weights(self, checkpoint_dir: str, **kwargs) -> dict[str, Any]:
-        return self.weight_loader.load_weights(checkpoint_dir, **kwargs)
+    def load_weights(self, checkpoint_dir: str, mapping: Mapping,
+                     **kwargs) -> dict[str, Any]:
+        logger.debug(
+            f"Loading weights from {checkpoint_dir} with mapping {mapping.to_dict()}"
+        )
+        return self.weight_loader.load_weights(checkpoint_dir,
+                                               mapping=mapping,
+                                               **kwargs)
 
     @classmethod
     def get(cls, checkpoint_format: str, **kwargs) -> "BaseCheckpointLoader":
diff --git a/tensorrt_llm/_torch/models/checkpoints/base_weight_loader.py b/tensorrt_llm/_torch/models/checkpoints/base_weight_loader.py
index c6c88d16bdc..2c829e027af 100644
--- a/tensorrt_llm/_torch/models/checkpoints/base_weight_loader.py
+++ b/tensorrt_llm/_torch/models/checkpoints/base_weight_loader.py
@@ -1,16 +1,20 @@
 from abc import ABC, abstractmethod
 from typing import Any
 
+from tensorrt_llm.mapping import Mapping
+
 
 class BaseWeightLoader(ABC):
 
     @abstractmethod
-    def load_weights(self, checkpoint_dir: str) -> dict[str, Any]:
+    def load_weights(self, checkpoint_dir: str,
+                     mapping: Mapping) -> dict[str, Any]:
         """
         Loads weights from a checkpoint directory.
 
         Args:
             checkpoint_dir: A path to the checkpoint directory.
+            mapping: A mapping object containing the distributed configuration.
 
         Returns:
             A dictionary where keys are tensor names and values are the tensors.
diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py
index 7a631f45f6b..7c24f19ae73 100644
--- a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py
+++ b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py
@@ -16,6 +16,7 @@
 from tensorrt_llm._utils import (local_mpi_barrier, local_mpi_rank,
                                  local_mpi_size)
 from tensorrt_llm.logger import logger
+from tensorrt_llm.mapping import Mapping
 
 
 @register_checkpoint_weight_loader("HF")
@@ -24,7 +25,8 @@ class HfWeightLoader(BaseWeightLoader):
     Loads weights from SafeTensors/bin/pth files.
     """
 
-    def load_weights(self, checkpoint_dir: str) -> dict[str, Any]:
+    def load_weights(self, checkpoint_dir: str,
+                     mapping: Mapping) -> dict[str, Any]:
         weight_files = glob.glob(f"{checkpoint_dir}/*.safetensors")
         # Some model checkpoint directories contain not only the sharded safetensors, but one
         # consolidated tensor. In the presence of both, we favor the former, as there really is no need
diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py
index 679d50615e8..1fb3993a2a3 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_loader.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py
@@ -265,9 +265,10 @@ def init_meta_tensor(t: torch.Tensor):
             if load_format == LoadFormat.AUTO:
                 if hasattr(model, 'llm_checkpoint_dir'):
                     weights = checkpoint_loader.load_weights(
-                        model.llm_checkpoint_dir)
+                        model.llm_checkpoint_dir, mapping=self.mapping)
                 else:
-                    weights = checkpoint_loader.load_weights(checkpoint_dir)
+                    weights = checkpoint_loader.load_weights(
+                        checkpoint_dir, mapping=self.mapping)
 
                 self.weight_mapper = checkpoint_loader.get_initialized_weight_mapper(
                     model, config)
@@ -277,7 +278,8 @@ def init_meta_tensor(t: torch.Tensor):
             if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights(
             ):
                 weights = checkpoint_loader.load_weights(
-                    self.spec_config.speculative_model_dir)
+                    self.spec_config.speculative_model_dir,
+                    mapping=self.mapping)
                 draft_model_arch = model.draft_config.pretrained_config.architectures[
                     0]
 
diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml
index 591edb90da2..dfbe75d617d 100644
--- a/tests/integration/test_lists/test-db/l0_a10.yml
+++ b/tests/integration/test_lists/test-db/l0_a10.yml
@@ -23,6 +23,7 @@ l0_a10:
   # NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no
   # test list either).
   - unittest/_torch/models/checkpoints/hf/test_weight_loader.py
+  - unittest/_torch/models/checkpoints/hf/test_checkpoint_loader.py
   - unittest/others/test_time_breakdown.py
   - unittest/disaggregated/test_disagg_utils.py
   - unittest/disaggregated/test_router.py
diff --git a/tests/unittest/_torch/models/checkpoints/hf/test_checkpoint_loader.py b/tests/unittest/_torch/models/checkpoints/hf/test_checkpoint_loader.py
new file mode 100644
index 00000000000..6e1ab924a13
--- /dev/null
+++ b/tests/unittest/_torch/models/checkpoints/hf/test_checkpoint_loader.py
@@ -0,0 +1,111 @@
+import pathlib as _pl
+from typing import Any, Optional
+
+import pytest
+import torch
+from transformers.configuration_utils import PretrainedConfig
+
+from tensorrt_llm import LLM, Mapping
+from tensorrt_llm._torch.model_config import ModelConfig
+from tensorrt_llm._torch.models.checkpoints import HfCheckpointLoader
+from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader
+from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader
+from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import BaseWeightMapper
+from tensorrt_llm._torch.models.modeling_utils import register_auto_model
+
+
+class DummyConfig(PretrainedConfig):
+    def __init__(self):
+        self.architectures: list[str] = ["DummyModel"]
+        self.dtype: torch.dtype = torch.float16
+        self.num_attention_heads: int = 16
+        self.hidden_size: int = 256
+        self.vocab_size: int = 1000
+        self.num_hidden_layers: int = 1
+
+
+@register_auto_model("DummyModel")
+class DummyModel(torch.nn.Module):
+    def __init__(self, model_config: ModelConfig):
+        super().__init__()
+        self.model_config = model_config
+
+    def infer_max_seq_len(self):
+        return 2048
+
+    @property
+    def config(self):
+        return self.model_config.pretrained_config
+
+    def forward(self, *args, input_ids: torch.Tensor, **kwargs) -> torch.Tensor:
+        num_batch_tokens = input_ids.size(0)
+        vocab_size = self.config.vocab_size
+
+        # Logits: dummy values for testing
+        logits = torch.ones((num_batch_tokens, vocab_size), device="cuda") * 0.1
+
+        return {
+            "logits": logits,
+        }
+
+    def load_weights(
+        self,
+        weights: dict,
+        weight_mapper: Optional[BaseWeightMapper] = None,
+        skip_modules: list[str] = [],
+    ):
+        pass
+
+
+class DummyWeightLoader(BaseWeightLoader):
+    def load_weights(self, checkpoint_dir: str, mapping: Mapping, **kwargs) -> dict[str, Any]:
+        """Load weights from your dummy format.
+
+        Args:
+            checkpoint_dir: Directory containing checkpoint files
+            mapping: Mapping object containing the distributed configuration
+            **kwargs: Additional loading parameters
+        Returns:
+            Dictionary mapping parameter names to tensors
+        """
+
+        assert mapping is not None
+        assert isinstance(mapping, Mapping)
+        assert mapping.world_size == 1
+        assert mapping.rank == 0
+
+        weights = {}
+
+        return weights
+
+
+class DummyConfigLoader(BaseConfigLoader):
+    def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig:
+        """Load and parse configuration from your dummy format.
+
+        Args:
+            checkpoint_dir: Directory containing configuration files
+            **kwargs: Additional loading parameters
+        Returns:
+            ModelConfig object containing parsed configuration
+        """
+        return ModelConfig(pretrained_config=DummyConfig())
+
+
+def test_weight_loader_mapping():
+    """Test that the mapping in weight loader is correct."""
+
+    # Create LLM with the provided model
+    with LLM(
+        model=_pl.Path("dummy_path"),
+        backend="pytorch",
+        cuda_graph_config=None,
+        checkpoint_loader=HfCheckpointLoader(
+            weight_loader=DummyWeightLoader(), config_loader=DummyConfigLoader()
+        ),
+    ):
+        pass
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py b/tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py
index 6447aa09a89..9989e2821ac 100644
--- a/tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py
+++ b/tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py
@@ -3,6 +3,7 @@
 import pytest
 
 from tensorrt_llm._torch.models.checkpoints import HfWeightLoader
+from tensorrt_llm.mapping import Mapping
 
 
 class MyError(Exception):
@@ -69,7 +70,7 @@ def test_load_weights_ignores_consolidated_ckpt_when_sharded_ckpt_exists(
             mock.patch.object(loader, "prefetch_files") as prefetch_files,
             pytest.raises(MyError),
     ):
-        loader.load_weights(checkpoint_dir=str(checkpoint_dir))
+        loader.load_weights(checkpoint_dir=str(checkpoint_dir), mapping=Mapping())
 
     prefetch_files.assert_called_once()
     prefetched_files = prefetch_files.call_args[0][0]
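
Note on downstream implementations: after this change, every `BaseWeightLoader` subclass receives the parallel `Mapping` when `load_weights` is called, so a custom loader can restrict its I/O to the current rank instead of always materializing the full checkpoint. The sketch below is illustrative and not part of this patch: the `ShardedWeightLoader` name, the single `model.pt` file layout, and the round-robin ownership rule are assumptions, while `BaseWeightLoader`, `Mapping`, and the `mapping.rank` / `mapping.world_size` fields are the ones exercised by the code above.

```python
from typing import Any

import torch

from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
    BaseWeightLoader
from tensorrt_llm.mapping import Mapping


class ShardedWeightLoader(BaseWeightLoader):
    """Illustrative loader that keeps only the parameters owned by this rank."""

    def load_weights(self, checkpoint_dir: str,
                     mapping: Mapping) -> dict[str, Any]:
        # Hypothetical checkpoint layout: one torch-saved state dict per directory.
        full_state = torch.load(f"{checkpoint_dir}/model.pt", map_location="cpu")
        # Round-robin the tensors across ranks using the mapping that the
        # checkpoint loader now forwards (the same rank/world_size fields the
        # new unit test asserts on).
        return {
            name: tensor
            for idx, (name, tensor) in enumerate(sorted(full_state.items()))
            if idx % mapping.world_size == mapping.rank
        }
```

Such a loader plugs in the same way as `DummyWeightLoader` in the new test, e.g. `HfCheckpointLoader(weight_loader=ShardedWeightLoader())`, and the py-executor's model loader then passes `mapping=self.mapping` to it, as shown in the `model_loader.py` hunks above.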