diff --git a/tensorrt_llm/inputs/registry.py b/tensorrt_llm/inputs/registry.py
index 6ac3d4ee657..7165b392d55 100644
--- a/tensorrt_llm/inputs/registry.py
+++ b/tensorrt_llm/inputs/registry.py
@@ -448,20 +448,40 @@ def wrapper(model_cls: N) -> N:
     return wrapper
 
 
-def create_input_processor(model_path_or_dir: str, tokenizer):
-    """
-    Create an input processor for a specific model.
+def create_input_processor(
+    model_path_or_dir: str,
+    tokenizer,
+    checkpoint_format: Optional[str] = "HF",
+) -> InputProcessor:
+    """Create an input processor for a specific model.
+
+    Args:
+        model_path_or_dir: Path or repo id used to locate pretrained config/tokenizer.
+        tokenizer: Tokenizer instance.
+        checkpoint_format: Checkpoint format identifier. "HF" uses Hugging Face-style
+            config loading; any other value skips HF config loading. Default is "HF".
+
+    Returns:
+        An InputProcessor implementation (model-specific if registered; otherwise DefaultInputProcessor).
     """
     from tensorrt_llm._torch.model_config import ModelConfig
     from tensorrt_llm._torch.models import get_model_architecture
 
     model_config = None
-    try:
-        config = ModelConfig.from_pretrained(model_path_or_dir,
-                                             trust_remote_code=True)
-        model_config = config.pretrained_config
-    except (ValueError, EnvironmentError):
-        config = None
+
+    if checkpoint_format == "HF":
+        try:
+            config = ModelConfig.from_pretrained(model_path_or_dir,
+                                                 trust_remote_code=True)
+            model_config = config.pretrained_config
+        except (ValueError, EnvironmentError) as e:
+            config = None
+            logger.debug(
+                f"Unable to load HF config from {model_path_or_dir}: {e}. Falling back to the default input processor."
+            )
+    else:
+        logger.debug(
+            f"checkpoint_format={checkpoint_format}; skipping HF config load.")
 
     if model_config is not None:
         try:
diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py
index c9a7aed32b3..2bbb0e2134a 100644
--- a/tensorrt_llm/llmapi/llm.py
+++ b/tensorrt_llm/llmapi/llm.py
@@ -1036,8 +1036,10 @@ def _build_model(self):
         # Multimodal special handling:
         # 1. Default load_tokenizer may fail because MM has different tokenizer configuration. Hence we initialize it inside input processor
         # 2. May need to modify model weights for MM (e.g., resize vocab embedding). We must do such operation via input processor's __init__
+        checkpoint_format = getattr(self.args, "checkpoint_format", None)
         self.input_processor = create_input_processor(self._hf_model_dir,
-                                                      self.tokenizer)
+                                                      self.tokenizer,
+                                                      checkpoint_format)
         self._tokenizer = self.input_processor.tokenizer
 
         # TODO: revisit gather_context_logits
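For reference, a minimal usage sketch of the new parameter. This is not part of the patch: the model paths, the `tokenizer` object, and the `"CUSTOM"` format name are placeholders.

```python
# Minimal sketch, assuming `tokenizer` was constructed elsewhere; the paths and
# the "CUSTOM" format name are placeholders.
from tensorrt_llm.inputs.registry import create_input_processor

# Default ("HF"): tries ModelConfig.from_pretrained() and, on failure, logs a
# debug message and falls back to the default input processor.
processor = create_input_processor("/path/to/hf-model", tokenizer)

# Any non-"HF" value skips the HF config load, so the model-specific processor
# lookup is bypassed and the DefaultInputProcessor path is taken.
processor = create_input_processor("/path/to/custom-ckpt",
                                   tokenizer,
                                   checkpoint_format="CUSTOM")
```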
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 35d02350e9c..8cb0d7e0bc7 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -2518,7 +2518,13 @@ class TorchLlmArgs(BaseLlmArgs):
         status="beta")
     checkpoint_loader: Optional[object] = Field(
         default=None,
-        description="The checkpoint loader to use for this LLM instance.",
+        description=
+        "The checkpoint loader to use for this LLM instance. You may use a custom checkpoint loader by subclassing "
+        "`BaseCheckpointLoader` and providing an instance of the subclass here to load weights from a custom "
+        "checkpoint format.\n"
+        "If neither checkpoint_format nor checkpoint_loader is provided, checkpoint_format will be set to HF "
+        "and the default HfCheckpointLoader will be used.\n"
+        "If checkpoint_format and checkpoint_loader are both provided, checkpoint_loader will be ignored.",
         json_schema_extra={
             "type":
             "Optional[tensorrt_llm._torch.models.checkpoints.BaseCheckpointLoader]"
@@ -2528,7 +2534,12 @@ class TorchLlmArgs(BaseLlmArgs):
 
     checkpoint_format: Optional[str] = Field(
         default=None,
-        description="The format of the provided checkpoint.",
+        description=
+        "The format of the provided checkpoint. You may use a custom checkpoint format by subclassing "
+        "`BaseCheckpointLoader` and registering it with `register_checkpoint_loader`.\n"
+        "If neither checkpoint_format nor checkpoint_loader is provided, checkpoint_format will be set to HF "
+        "and the default HfCheckpointLoader will be used.\n"
+        "If checkpoint_format and checkpoint_loader are both provided, checkpoint_loader will be ignored.",
         status="prototype",
     )
diff --git a/tensorrt_llm/llmapi/mm_encoder.py b/tensorrt_llm/llmapi/mm_encoder.py
index af0f031fc02..8553d4678ea 100644
--- a/tensorrt_llm/llmapi/mm_encoder.py
+++ b/tensorrt_llm/llmapi/mm_encoder.py
@@ -51,8 +51,10 @@ def _build_model(self):
         # Multimodal special handling:
         # 1. Default load_tokenizer may fail because MM has different tokenizer configuration. Hence we initialize it inside input processor
         # 2. May need to modify model weights for MM (e.g., resize vocab embedding). We must do such operation via input processor's __init__
+        checkpoint_format = getattr(self.args, "checkpoint_format", None)
        self.input_processor = create_input_processor(self._hf_model_dir,
-                                                      self.tokenizer)
+                                                      self.tokenizer,
+                                                      checkpoint_format)
         self._tokenizer = self.input_processor.tokenizer
 
         assert isinstance(self.args, TorchLlmArgs)
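The precedence rules stated in both field descriptions can be illustrated with a short sketch. This assumes the usual LLM-API behavior of forwarding keyword arguments to `TorchLlmArgs`; `MyCheckpointLoader` is a hypothetical `BaseCheckpointLoader` subclass and the model paths are placeholders.

```python
# Hedged sketch of the documented checkpoint_format/checkpoint_loader precedence.
from tensorrt_llm import LLM

# 1. Neither field set: checkpoint_format resolves to "HF" and the default
#    HfCheckpointLoader is used.
llm = LLM(model="/path/to/hf-model")

# 2. Only checkpoint_loader set: weights are loaded through the custom loader
#    instance (MyCheckpointLoader is hypothetical).
llm = LLM(model="/path/to/custom-ckpt", checkpoint_loader=MyCheckpointLoader())

# 3. Both set: checkpoint_loader is ignored; the loader registered for
#    checkpoint_format wins.
llm = LLM(model="/path/to/hf-model",
          checkpoint_format="HF",
          checkpoint_loader=MyCheckpointLoader())
```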