diff --git a/tensorrt_llm/_torch/models/modeling_qwen2vl.py b/tensorrt_llm/_torch/models/modeling_qwen2vl.py
index 1db70fdbfb4..160ce302b87 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen2vl.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen2vl.py
@@ -2,8 +2,10 @@
 import os
 from typing import Any, Dict, List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 import torch.nn as nn
+from PIL import Image
 from torch.nn import functional as F
 from transformers import (AutoProcessor, AutoTokenizer, PretrainedConfig,
                           PreTrainedModel)
@@ -25,9 +27,11 @@ from tensorrt_llm.inputs.multimodal import MultimodalParams
 
 from ..._utils import nvtx_range, nvtx_range_debug
-from ...inputs import (BaseMultimodalInputProcessor, ExtraProcessedInputs,
-                       InputProcessor, MultimodalPlaceholderMetadata,
+from ...inputs import (BaseDummyInputsBuilder, BaseMultimodalInputProcessor,
+                       ExtraProcessedInputs, InputProcessor,
+                       MultimodalPlaceholderMetadata,
                        MultimodalPlaceholderPlacement, TextPrompt,
+                       default_multimodal_input_loader,
                        register_input_processor)
 from ...logger import logger
 from ...sampling_params import SamplingParams
@@ -83,7 +87,8 @@ def process_weights(weights: Dict,
     return filtered_weights
 
 
-class Qwen2VLInputProcessorBase(BaseMultimodalInputProcessor, InputProcessor):
+class Qwen2VLInputProcessorBase(BaseDummyInputsBuilder,
+                                BaseMultimodalInputProcessor, InputProcessor):
 
     def __init__(self,
                  model_path: str,
@@ -91,8 +96,10 @@ def __init__(self,
                  tokenizer: AutoTokenizer,
                  trust_remote_code: bool = True):
         self.model_config = model_config
-        self.tokenizer = tokenizer
+        self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
+            model_path)
         self.use_fast = True
+        self.model_path = model_path
         self.processor = AutoProcessor.from_pretrained(
             model_path,
             use_fast=self.use_fast,
@@ -277,6 +284,81 @@ def get_rope_index(
                 mrope_position_deltas, device=input_ids.device).unsqueeze(1)
             return position_ids, mrope_position_deltas
 
+    def get_dummy_text(self, input_seq_len: int) -> str:
+        ids = np.random.randint(
+            low=0,
+            high=int(
+                self.model_config.vocab_size),  # high is exclusive in NumPy
+            size=input_seq_len,
+        ).tolist()
+        return self.tokenizer.decode(ids, skip_special_tokens=True)
+
+    def get_dummy_image(self, max_width: int, max_height: int):
+        image = Image.new("RGB", (max_width, max_height), color=255)
+        return image
+
+    def get_dummy_prompt(self, input_seq_len: int):
+        text = ""
+        # We use the max resolution as a starting point.
+        img_max_dim = 3584
+        image = self.get_dummy_image(max_width=img_max_dim,
+                                     max_height=img_max_dim)
+
+        test_mm_prompt = default_multimodal_input_loader(
+            tokenizer=self.tokenizer,
+            model_dir=self.model_path,
+            model_type=self.model_config.model_type,
+            modality="image",
+            prompts=[text],
+            media=[[image]],
+            image_data_format="pt")[0]
+
+        prompt_token_ids_single_img, _ = self(test_mm_prompt, None)
+
+        # If the max image resolution results in a number of tokens greater than
+        # input_seq_len, we keep lowering the resolution to find the largest
+        # resolution whose token count does not exceed input_seq_len.
+        while len(prompt_token_ids_single_img) > input_seq_len:
+            # reduce img resolution
+            img_max_dim = img_max_dim >> 1
+
+            image = self.get_dummy_image(max_width=img_max_dim,
+                                         max_height=img_max_dim)
+
+            test_mm_prompt = default_multimodal_input_loader(
+                tokenizer=self.tokenizer,
+                model_dir=self.model_path,
+                model_type=self.model_config.model_type,
+                modality="image",
+                prompts=[text],
+                media=[[image]],
+                image_data_format="pt")[0]
+
+            prompt_token_ids_single_img, _ = self(test_mm_prompt, None)
+
+        len_prompt_tokens_ids = len(prompt_token_ids_single_img)
+        # There are corner cases where, if we strictly try to generate text based
+        # on how many tokens are needed to complete the input_seq_len, the output
+        # of default_multimodal_input_loader may contain more tokens than
+        # input_seq_len, which can lead to errors.
+        # That is why we clip text_token_left to a threshold slightly below,
+        # but close to, the actual input_seq_len.
+        text_generation_perc_threshold = 0.95
+        text_token_left = int((input_seq_len - len_prompt_tokens_ids) *
+                              text_generation_perc_threshold)
+
+        if text_token_left > 0:
+            text = self.get_dummy_text(text_token_left)
+
+        return default_multimodal_input_loader(
+            tokenizer=self.tokenizer,
+            model_dir=self.model_path,
+            model_type=self.model_config.model_type,
+            modality="image",
+            prompts=[text],
+            media=[[image]],
+            image_data_format="pt")[0]
+
     def _preprocess(self, text: dict[str, any], mm_data: dict[str, any],
                     mm_processor_kwargs: Dict[str, Any]):
         images = mm_data.get("image")
@@ -790,7 +872,7 @@ def __init__(
         **kwargs,
     ) -> None:
         model_config.pretrained_config.rope_scaling['type'] = 'mrope'
-
+        self.original_arch = model_config.pretrained_config.architectures[0]
         # NOTE: Setting disable_fuse_rope to True to do mrope fusion in the model engine by pre-computing rotary_cos_sin in the model engine
         disabble_fuse_rope = kwargs.get('disable_fuse_rope', False)
         model_config.pretrained_config.text_config.disable_fuse_rope = disabble_fuse_rope
@@ -979,7 +1061,7 @@ def multimodal_data_device_paths(self) -> List[str]:
         return [
             "image.pixel_values", "image.image_grid_thw",
             "video.pixel_values_videos", "video.video_grid_thw",
-            "multimodal_embedding", "mrope_config.mrope_position_ids"
+            "multimodal_embedding"
         ]
 
     def load_weights(self, weights, weight_mapper: BaseWeightMapper):
@@ -1032,12 +1114,12 @@ def multimodal_data_device_paths(self) -> List[str]:
             return [
                 "image.pixel_values", "video.pixel_values_videos",
                 "image.image_grid_thw", "video.video_grid_thw",
-                "multimodal_embedding", "mrope_config.mrope_position_ids"
+                "multimodal_embedding"
             ]
         else:
             return [
                 "image.pixel_values", "video.pixel_values_videos",
-                "multimodal_embedding", "mrope_config.mrope_position_ids"
+                "multimodal_embedding"
             ]
 
     def load_weights(self, weights, weight_mapper: BaseWeightMapper):
diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py
index 2fa2eeb4768..4099e0e104d 100644
--- a/tensorrt_llm/_torch/pyexecutor/_util.py
+++ b/tensorrt_llm/_torch/pyexecutor/_util.py
@@ -7,6 +7,8 @@
 import tensorrt_llm
 import tensorrt_llm.bindings.executor as trtllm
 from tensorrt_llm._torch.model_config import ModelConfig
+from tensorrt_llm._torch.models.modeling_utils import \
+    MODEL_CLASS_VISION_ENCODER_MAPPING
 from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
 from tensorrt_llm.bindings.executor import DecodingMode
 from tensorrt_llm.llmapi.llm_args import (EagleDecodingConfig, KvCacheConfig,
@@ -76,6 +78,7 @@ def __init__(
         pytorch_backend_config: PyTorchConfig,
         speculative_config: SpeculativeConfig,
         sparse_attention_config: SparseAttentionConfig,
+        profiling_stage_data: Optional[dict],
     ):
         self._model_engine = model_engine
         self._draft_model_engine = draft_model_engine
@@ -93,6 +96,7 @@ def __init__(
         self._max_batch_size = max_batch_size
         self._net_max_seq_len = net_max_seq_len
         self._dummy_reqs = None
+        self._profiling_stage_data = profiling_stage_data
         self._kv_cache_manager_cls = get_kv_cache_manager_cls(
             model_engine.model.model_config)
@@ -133,13 +137,76 @@ def _cal_max_memory(self, peak_memory, total_gpu_memory, fraction,
             f", tmp kv_mem { (allocated_bytes) / (GB):.2f} GiB")
         return int(available_kv_mem)
 
+    def _create_dummy_mm_context_request(
+            self, input_seq_len: int) -> List[trtllm.Request]:
+        requests = []
+        if isinstance(
+                self._profiling_stage_data,
+                dict) and not self._profiling_stage_data.get("enable_mm_reqs"):
+            return requests
+
+        input_processor = self._model_engine.input_processor
+        if not hasattr(input_processor, "get_dummy_prompt"):
+            logger.warning("The input processor of the model does not implement [get_dummy_prompt]. " \
+                "Profiling with the default dummy context request. This may not take into account the memory consumption of " \
+                "the image encoder.")
+            return requests
+        prompt = input_processor.get_dummy_prompt(input_seq_len)
+
+        prompt_token_ids, extra_processed_inputs = self._model_engine.input_processor_with_hash(
+            prompt, None)
+
+        multimodal_input = extra_processed_inputs.get('multimodal_input')
+        multimodal_data = extra_processed_inputs.get('multimodal_data')
+
+        max_num_tokens = len(prompt_token_ids)
+        assert max_num_tokens > 0, "the prompt of the dummy multimodal request must contain at least one token"
+        if max_num_tokens > input_seq_len:
+            logger.warning(f"Profiling with a multimodal prompt that contains more tokens than the allowed input_seq_len. " \
+                f"The multimodal prompt has {max_num_tokens} tokens while input_seq_len is {input_seq_len}.")
+        remaining_tokens = min(max_num_tokens, input_seq_len)
+        while remaining_tokens > 0:
+            req_mm_input = trtllm.MultimodalInput(
+                multimodal_hashes=multimodal_input.multimodal_hashes,
+                multimodal_positions=multimodal_input.multimodal_positions,
+                multimodal_lengths=multimodal_input.multimodal_lengths
+            ) if multimodal_input else None
+            request = trtllm.Request(prompt_token_ids,
+                                     max_tokens=1,
+                                     streaming=False,
+                                     sampling_config=trtllm.SamplingConfig(
+                                         beam_width=self._max_beam_width, ),
+                                     output_config=trtllm.OutputConfig(),
+                                     end_id=-1,
+                                     multimodal_input=req_mm_input)
+            # TODO:
+            # create_input_processor_with_hash shouldn't be required during profiling,
+            # but is temporarily needed due to the multimodal input dependency for chunked prefill
+            request.py_multimodal_data = multimodal_data
+            remaining_tokens -= max_num_tokens
+            requests.append(request)
+
+        if self._mapping.enable_attention_dp:
+            requests = requests * self._mapping.tp_size
+
+        return requests
+
     def _create_dummy_context_requests(
             self, input_seq_len: int) -> List[trtllm.Request]:
+        requests = []
+        if hasattr(self._model_engine.model,
+                   "original_arch") and MODEL_CLASS_VISION_ENCODER_MAPPING.get(
+                       self._model_engine.model.original_arch, None):
+            input_seq_len = min(self._max_num_tokens, input_seq_len)
+            requests = self._create_dummy_mm_context_request(input_seq_len)
+            # If profiling with multimodal requests succeeded, return them;
+            # otherwise fall back to the default dummy context requests.
+            if requests:
+                return requests
         vocab_size = self._model_engine.model.model_config.pretrained_config.vocab_size
         max_num_tokens = self._max_num_tokens
         max_beam_width = self._max_beam_width
-        requests = []
         input_seq_len = min(max_num_tokens, input_seq_len)
         remaining_tokens = max_num_tokens
         while remaining_tokens > 0:
@@ -349,6 +416,8 @@ def configure_kv_cache_capacity(self, py_executor: PyExecutor) -> None:
         )
         # set max_gpu_total_bytes
         self._kv_cache_config.max_gpu_total_bytes = kv_cache_max_memory
+        if isinstance(self._profiling_stage_data, dict):
+            self._profiling_stage_data["activation_bytes"] = activation_bytes
         # ---------------------------handle max_gpu_total_bytes---------------------------------
 
     def _create_kv_cache_manager(
diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index 53d98fc83d8..8d650469a05 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -18,6 +18,8 @@
                                  torch_dtype_to_str, trace_func)
 from tensorrt_llm.inputs.multimodal import (MultimodalParams,
                                             MultimodalRuntimeData)
+from tensorrt_llm.inputs.registry import (create_input_processor,
+                                          create_input_processor_with_hash)
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.lora_manager import LoraModelConfig
@@ -171,7 +173,9 @@ def __init__(
 
         self.attn_runtime_features = attn_runtime_features or AttentionRuntimeFeatures(
         )
-
+        self.input_processor = create_input_processor(model_path, None)
+        self.input_processor_with_hash = create_input_processor_with_hash(
+            self.input_processor)
         if model is None:
             loader = ModelLoader(
                 pytorch_backend_config=pytorch_backend_config,
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
index ff580751976..c8aafeff429 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -207,6 +207,7 @@ def create_py_executor(
     tokenizer: Optional[TokenizerBase] = None,
     lora_config: Optional[LoraConfig] = None,
     kv_connector_config: Optional[KvCacheConnectorConfig] = None,
+    profiling_stage_data: Optional[dict] = None,
 ) -> PyExecutor:
     garbage_collection_gen0_threshold = llm_args.garbage_collection_gen0_threshold
 
@@ -570,6 +571,7 @@ def drafting_loop_wrapper(model):
             kv_cache_config=kv_cache_config,
             pytorch_backend_config=pytorch_backend_config,
             speculative_config=spec_config,
+            profiling_stage_data=profiling_stage_data,
             sparse_attention_config=sparse_attention_config,
         )
         estimating_kv_cache = kv_cache_creator.try_prepare_estimation()
diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py
index 7da0930264b..9842c6d6c98 100644
--- a/tensorrt_llm/commands/serve.py
+++ b/tensorrt_llm/commands/serve.py
@@ -87,6 +87,7 @@ def get_llm_args(model: str,
                  trust_remote_code: bool = False,
                  reasoning_parser: Optional[str] = None,
                  fail_fast_on_attention_window_too_large: bool = False,
+                 enable_chunked_prefill: bool = False,
                  **llm_args_extra_dict: Any):
 
     if gpus_per_node is None:
@@ -109,44 +110,27 @@ def get_llm_args(model: str,
         dynamic_batch_config=dynamic_batch_config,
     )
     llm_args = {
-        "model":
-        model,
-        "scheduler_config":
-        scheduler_config,
-        "tokenizer":
-        tokenizer,
-        "tensor_parallel_size":
-        tensor_parallel_size,
-        "pipeline_parallel_size":
-        pipeline_parallel_size,
-        "moe_expert_parallel_size":
-        moe_expert_parallel_size,
-        "gpus_per_node":
-        gpus_per_node,
-        "trust_remote_code":
-        trust_remote_code,
-        "build_config":
-        build_config,
-        "max_batch_size":
-        max_batch_size,
-        "max_num_tokens":
-        max_num_tokens,
-        "max_beam_width":
-        max_beam_width,
-        "max_seq_len":
-        max_seq_len,
-        "kv_cache_config":
-        kv_cache_config,
-        "backend":
-        backend,
-        "num_postprocess_workers":
-        num_postprocess_workers,
-        "postprocess_tokenizer_dir":
-        tokenizer or model,
-        "reasoning_parser":
-        reasoning_parser,
+        "model": model,
+        "scheduler_config": scheduler_config,
+        "tokenizer": tokenizer,
"tensor_parallel_size": tensor_parallel_size, + "pipeline_parallel_size": pipeline_parallel_size, + "moe_expert_parallel_size": moe_expert_parallel_size, + "gpus_per_node": gpus_per_node, + "trust_remote_code": trust_remote_code, + "build_config": build_config, + "max_batch_size": max_batch_size, + "max_num_tokens": max_num_tokens, + "max_beam_width": max_beam_width, + "max_seq_len": max_seq_len, + "kv_cache_config": kv_cache_config, + "backend": backend, + "num_postprocess_workers": num_postprocess_workers, + "postprocess_tokenizer_dir": tokenizer or model, + "reasoning_parser": reasoning_parser, "fail_fast_on_attention_window_too_large": fail_fast_on_attention_window_too_large, + "enable_chunked_prefill": enable_chunked_prefill, } return llm_args, llm_args_extra_dict @@ -329,6 +313,10 @@ def convert(self, value: Any, param: Optional["click.Parameter"], help= "Exit with runtime error when attention window is too large to fit even a single sequence in the KV cache." ) +@click.option("--enable_chunked_prefill", + is_flag=True, + default=False, + help="Enable chunked prefill") def serve( model: str, tokenizer: Optional[str], host: str, port: int, log_level: str, backend: str, max_beam_width: int, max_batch_size: int, @@ -338,7 +326,8 @@ def serve( num_postprocess_workers: int, trust_remote_code: bool, extra_llm_api_options: Optional[str], reasoning_parser: Optional[str], metadata_server_config_file: Optional[str], server_role: Optional[str], - fail_fast_on_attention_window_too_large: bool): + fail_fast_on_attention_window_too_large: bool, + enable_chunked_prefill: bool): """Running an OpenAI API compatible server MODEL: model name | HF checkpoint path | TensorRT engine path @@ -363,7 +352,8 @@ def serve( trust_remote_code=trust_remote_code, reasoning_parser=reasoning_parser, fail_fast_on_attention_window_too_large= - fail_fast_on_attention_window_too_large) + fail_fast_on_attention_window_too_large, + enable_chunked_prefill=enable_chunked_prefill) llm_args_extra_dict = {} if extra_llm_api_options is not None: diff --git a/tensorrt_llm/inputs/__init__.py b/tensorrt_llm/inputs/__init__.py index d411731e519..1bf38ed5a7b 100644 --- a/tensorrt_llm/inputs/__init__.py +++ b/tensorrt_llm/inputs/__init__.py @@ -1,7 +1,8 @@ from .data import PromptInputs, TextPrompt, TokensPrompt, prompt_inputs from .multimodal import MultimodalInput -from .registry import (BaseMultimodalInputProcessor, ExtraProcessedInputs, - InputProcessor, MultimodalPlaceholderMetadata, +from .registry import (BaseDummyInputsBuilder, BaseMultimodalInputProcessor, + ExtraProcessedInputs, InputProcessor, + MultimodalPlaceholderMetadata, MultimodalPlaceholderPlacement, create_input_processor, create_input_processor_with_hash, register_input_processor, @@ -30,6 +31,7 @@ "register_input_processor", "support_multimodal_disaggregated", "ExtraProcessedInputs", + "BaseDummyInputsBuilder", "BaseMultimodalInputProcessor", "MultimodalPlaceholderMetadata", "MultimodalPlaceholderPlacement", diff --git a/tensorrt_llm/inputs/registry.py b/tensorrt_llm/inputs/registry.py index 99c984a77fb..b1c93ca5898 100644 --- a/tensorrt_llm/inputs/registry.py +++ b/tensorrt_llm/inputs/registry.py @@ -42,6 +42,16 @@ def __call__( ... +class BaseDummyInputsBuilder: + """ + Base class for generating dummy inputs. 
Specially for profiling + """ + + def get_dummy_prompt(self, input_seq_len: int): + raise NotImplementedError( + "Please ensure this method is implemented in your inherited class") + + class BaseMultimodalInputProcessor: """ Base class for multimodal input processors with default implementations diff --git a/tests/integration/test_lists/test-db/l0_a100.yml b/tests/integration/test_lists/test-db/l0_a100.yml index 5d8d5162c36..619bb57b275 100644 --- a/tests/integration/test_lists/test-db/l0_a100.yml +++ b/tests/integration/test_lists/test-db/l0_a100.yml @@ -15,6 +15,7 @@ l0_a100: tests: - unittest/llmapi/test_llm_pytorch.py - unittest/llmapi/test_mpi_session.py ISOLATION + - unittest/llmapi/test_memory_profiling.py # profile kvcache for vision encoder - unittest/trt/model_api/test_model_quantization.py # executor - unittest/executor/test_base_worker.py diff --git a/tests/unittest/llmapi/test_memory_profiling.py b/tests/unittest/llmapi/test_memory_profiling.py new file mode 100644 index 00000000000..57c668e9e15 --- /dev/null +++ b/tests/unittest/llmapi/test_memory_profiling.py @@ -0,0 +1,77 @@ +import pytest +import torch + +from tensorrt_llm._torch.pyexecutor.py_executor_creator import \ + create_py_executor +from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy, + DynamicBatchConfig, SchedulerConfig) +from tensorrt_llm.llmapi.llm_args import (CudaGraphConfig, KvCacheConfig, + TorchLlmArgs) + +# isort: off +from .test_llm import get_model_path +# isort: on + +pytestmark = pytest.mark.threadleak(enabled=False) + + +def test_profile_kvcache(): + kv_cache_config = KvCacheConfig(enable_block_reuse=False, + free_gpu_memory_fraction=0.9) + cuda_graph_config = CudaGraphConfig(max_batch_size=512) + VLM_MODEL = "Qwen2.5-VL-7B-Instruct" + VLM_MODEL_PATH = get_model_path(VLM_MODEL) + + build_config = BuildConfig(max_beam_width=1, max_num_tokens=16384) + dynamic_batch_config = DynamicBatchConfig( + enable_batch_size_tuning=True, + enable_max_num_tokens_tuning=False, + dynamic_batch_moving_average_window=128) + scheduler_config = SchedulerConfig( + capacity_scheduler_policy=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT, + dynamic_batch_config=dynamic_batch_config, + ) + backend = "pytorch" + llm_args = { + "model": VLM_MODEL, + "scheduler_config": scheduler_config, + "tokenizer": None, + "tensor_parallel_size": 1, + "pipeline_parallel_size": 1, + "moe_expert_parallel_size": None, + "gpus_per_node": 1, + "trust_remote_code": False, + "build_config": build_config, + "max_batch_size": build_config.max_batch_size, + "max_num_tokens": build_config.max_num_tokens, + "max_beam_width": build_config.max_beam_width, + "max_seq_len": build_config.max_seq_len, + "kv_cache_config": kv_cache_config, + "backend": backend, + "num_postprocess_workers": 0, + "postprocess_tokenizer_dir": VLM_MODEL, + "reasoning_parser": None, + "fail_fast_on_attention_window_too_large": False, + "cuda_graph_config": cuda_graph_config, + } + + torchllm_args = TorchLlmArgs(**llm_args) + + profiling_data = {"enable_mm_reqs": True} + py_executor = create_py_executor(llm_args=torchllm_args, + checkpoint_dir=VLM_MODEL_PATH, + profiling_stage_data=profiling_data) + vlm_activation_bytes_with_mm_reqs = profiling_data["activation_bytes"] + py_executor.shutdown() + torch.cuda.empty_cache() + + profiling_data = {"enable_mm_reqs": False} + torchllm_args = TorchLlmArgs(**llm_args) + py_executor_2 = create_py_executor(llm_args=torchllm_args, + checkpoint_dir=VLM_MODEL_PATH, + profiling_stage_data=profiling_data) + 
+    vlm_activation_bytes_no_mm_reqs = profiling_data["activation_bytes"]
+    py_executor_2.shutdown()
+    torch.cuda.empty_cache()
+
+    assert vlm_activation_bytes_with_mm_reqs > vlm_activation_bytes_no_mm_reqs, f"Activation bytes should be higher with mm reqs, but got {vlm_activation_bytes_with_mm_reqs} for mm reqs and {vlm_activation_bytes_no_mm_reqs} without mm reqs"