Skip to content

Commit eb90a71

Browse files
committed
Address comments and fix test
Signed-off-by: Chang Liu (Enterprise Products) <[email protected]>
1 parent 19e70fb commit eb90a71

File tree

2 files changed

+22
-6
lines changed

2 files changed

+22
-6
lines changed

tensorrt_llm/disaggregated_params.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,8 @@
11
from dataclasses import dataclass
22
from typing import Any, Dict, List, Optional
33

4+
import numpy as np
5+
46
# isort: off
57
# needed before trying to import bindings to load tensorrt_libs
68
import tensorrt as trt # noqa
@@ -58,6 +60,17 @@ def get_request_type(self) -> tllme.RequestType:
5860
)
5961

6062
def __post_init__(self):
63+
if self.request_type is not None:
64+
self.request_type = self.request_type.lower()
65+
if self.request_type not in [
66+
"context_only",
67+
"generation_only",
68+
"context_and_generation",
69+
]:
70+
raise ValueError(
71+
f"Unknown request type: {self.request_type}. Must be context_only, generation_only or "
72+
"context_and_generation"
73+
)
6174
if self.multimodal_embedding_handles is not None:
6275
if self.multimodal_hashes is not None:
6376
# if mm hashes are provided, kvcache reuse can be enabled
@@ -69,8 +82,6 @@ def __post_init__(self):
6982
assert len(mm_hash) == 8, "mm_hash must be a list of 8 integers"
7083
assert all(isinstance(x, int) for x in mm_hash), "mm_hash must contain integers"
7184
else:
72-
import numpy as np
73-
7485
# if user did not provide mm embedding handles, kvcache reuse will be disabled
7586
assert len(self.multimodal_embedding_handles) > 0, (
7687
"multimodal_embedding_handles must be provided"

tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,17 +1,21 @@
11
import json
22
import os
3+
from pathlib import Path
34

45
import pytest
6+
from utils.llm_data import llm_models_root
57

68
from tensorrt_llm import MultimodalEncoder
79
from tensorrt_llm.inputs import default_multimodal_input_loader
810
from tensorrt_llm.llmapi import KvCacheConfig
911
from tensorrt_llm.llmapi.llm import LLM, SamplingParams
1012

13+
test_data_root = Path(
14+
os.path.join(llm_models_root(), "multimodals", "test_data"))
1115
example_images = [
12-
"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png",
13-
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
14-
"https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg",
16+
str(test_data_root / "seashore.png"),
17+
str(test_data_root / "inpaint.png"),
18+
str(test_data_root / "61.jpg"),
1519
]
1620

1721

@@ -184,7 +188,8 @@ def test_multi_request_batch_chat(model_key, multimodal_model_config):
184188

185189
sampling_params = SamplingParams(max_tokens=max_tokens)
186190
kv_cache_config = KvCacheConfig(
187-
enable_block_reuse=True,
191+
enable_block_reuse=
192+
False, # Disable block reuse for output 1-1 matching check
188193
free_gpu_memory_fraction=free_gpu_memory_fraction,
189194
)
190195

0 commit comments

Comments (0)