38 changes: 29 additions & 9 deletions tensorrt_llm/inputs/registry.py
@@ -448,20 +448,40 @@ def wrapper(model_cls: N) -> N:
     return wrapper
 
 
-def create_input_processor(model_path_or_dir: str, tokenizer):
-    """
-    Create an input processor for a specific model.
+def create_input_processor(
+    model_path_or_dir: str,
+    tokenizer,
+    checkpoint_format: Optional[str] = "HF",
+) -> InputProcessor:
+    """Create an input processor for a specific model.
+
+    Args:
+        model_path_or_dir: Path or repo id used to locate pretrained config/tokenizer.
+        tokenizer: Tokenizer instance.
+        checkpoint_format: Checkpoint format identifier. "HF" uses Hugging Face-style
+            config loading; any other value skips HF config loading. Default is "HF".
+
+    Returns:
+        An InputProcessor implementation (model-specific if registered; otherwise DefaultInputProcessor).
     """
     from tensorrt_llm._torch.model_config import ModelConfig
     from tensorrt_llm._torch.models import get_model_architecture
 
     model_config = None
-    try:
-        config = ModelConfig.from_pretrained(model_path_or_dir,
-                                             trust_remote_code=True)
-        model_config = config.pretrained_config
-    except (ValueError, EnvironmentError):
-        config = None
+    if checkpoint_format == "HF":
+        try:
+            config = ModelConfig.from_pretrained(model_path_or_dir,
+                                                 trust_remote_code=True)
+            model_config = config.pretrained_config
+        except (ValueError, EnvironmentError) as e:
+            config = None
+            logger.debug(
+                f"Unable to load HF config from {model_path_or_dir}: {e}. Falling back."
+            )
+    else:
+        logger.debug(
+            f"checkpoint_format={checkpoint_format}; skipping HF config load.")
 
     if model_config is not None:
         try:
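Reviewer note: a minimal sketch of how the new parameter changes the call site's behavior, assuming a local checkpoint directory with a Hugging Face tokenizer; the path and the "CUSTOM" format string below are hypothetical.

```python
from transformers import AutoTokenizer

from tensorrt_llm.inputs.registry import create_input_processor

model_dir = "/path/to/checkpoint"  # hypothetical checkpoint directory
tokenizer = AutoTokenizer.from_pretrained(model_dir)

# Default behavior: attempt HF-style config loading; on ValueError or
# EnvironmentError, log at debug level and fall back to DefaultInputProcessor.
processor = create_input_processor(model_dir, tokenizer, checkpoint_format="HF")

# Any non-"HF" value now skips HF config loading entirely, instead of
# attempting it and silently swallowing the failure.
processor = create_input_processor(model_dir, tokenizer, checkpoint_format="CUSTOM")
```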
4 changes: 3 additions & 1 deletion tensorrt_llm/llmapi/llm.py
@@ -1036,8 +1036,10 @@ def _build_model(self):
         # Multimodal special handling:
         # 1. Default load_tokenizer may fail because MM has different tokenizer configuration. Hence we initialize it inside input processor
         # 2. May need to modify model weights for MM (e.g., resize vocab embedding). We must do such operation via input processor's __init__
+        checkpoint_format = getattr(self.args, "checkpoint_format", None)
         self.input_processor = create_input_processor(self._hf_model_dir,
-                                                      self.tokenizer)
+                                                      self.tokenizer,
+                                                      checkpoint_format)
         self._tokenizer = self.input_processor.tokenizer
 
         # TODO: revisit gather_context_logits
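The `getattr` guard above deserves a note: `checkpoint_format` is a `TorchLlmArgs` field, so reading it defensively keeps args classes that lack the field working. A self-contained sketch of the pattern; both classes are hypothetical stand-ins:

```python
class ArgsWithoutFormat:  # hypothetical stand-in for an args class lacking the field
    pass


class ArgsWithFormat:  # hypothetical stand-in for TorchLlmArgs
    checkpoint_format = "HF"


# getattr with a None default avoids AttributeError on the former and
# reads the field normally on the latter.
assert getattr(ArgsWithoutFormat(), "checkpoint_format", None) is None
assert getattr(ArgsWithFormat(), "checkpoint_format", None) == "HF"
```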
15 changes: 13 additions & 2 deletions tensorrt_llm/llmapi/llm_args.py
@@ -2518,7 +2518,13 @@ class TorchLlmArgs(BaseLlmArgs):
status="beta")
checkpoint_loader: Optional[object] = Field(
default=None,
description="The checkpoint loader to use for this LLM instance.",
description=
"The checkpoint loader to use for this LLM instance. You may use a custom checkpoint loader by subclassing "
"`BaseCheckpointLoader` and providing an instance of the subclass here to load weights from a custom "
"checkpoint format.\n"
"If neither checkpoint_format nor checkpoint_loader are provided, checkpoint_format will be set to HF "
"and the default HfCheckpointLoader will be used.\n"
"If checkpoint_format and checkpoint_loader are both provided, checkpoint_loader will be ignored.",
json_schema_extra={
"type":
"Optional[tensorrt_llm._torch.models.checkpoints.BaseCheckpointLoader]"
@@ -2528,7 +2534,12 @@

     checkpoint_format: Optional[str] = Field(
         default=None,
-        description="The format of the provided checkpoint.",
+        description=
+        "The format of the provided checkpoint. You may use a custom checkpoint format by subclassing "
+        "`BaseCheckpointLoader` and registering it with `register_checkpoint_loader`.\n"
+        "If neither checkpoint_format nor checkpoint_loader are provided, checkpoint_format will be set to HF "
+        "and the default HfCheckpointLoader will be used.\n"
+        "If checkpoint_format and checkpoint_loader are both provided, checkpoint_loader will be ignored.",
         status="prototype",
     )

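The two field descriptions encode a precedence rule. A short usage sketch against the LLM API, assuming a hypothetical local checkpoint path (and noting that `checkpoint_format` is still marked prototype):

```python
from tensorrt_llm import LLM

# Neither field set: checkpoint_format defaults to "HF" and the default
# HfCheckpointLoader is used.
llm = LLM(model="/path/to/checkpoint")

# Format set explicitly. If a checkpoint_loader were also passed, the
# loader would be ignored, per the documented precedence.
llm = LLM(model="/path/to/checkpoint", checkpoint_format="HF")
```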
4 changes: 3 additions & 1 deletion tensorrt_llm/llmapi/mm_encoder.py
@@ -51,8 +51,10 @@ def _build_model(self):
         # Multimodal special handling:
         # 1. Default load_tokenizer may fail because MM has different tokenizer configuration. Hence we initialize it inside input processor
         # 2. May need to modify model weights for MM (e.g., resize vocab embedding). We must do such operation via input processor's __init__
+        checkpoint_format = getattr(self.args, "checkpoint_format", None)
         self.input_processor = create_input_processor(self._hf_model_dir,
-                                                      self.tokenizer)
+                                                      self.tokenizer,
+                                                      checkpoint_format)
         self._tokenizer = self.input_processor.tokenizer
 
         assert isinstance(self.args, TorchLlmArgs)