Skip to content

Commit 3045ede

Browse files
committed
default unit test kv cache dtype
Signed-off-by: NickLucche <[email protected]>
1 parent 8cb76e4 commit 3045ede

File tree

3 files changed

+7
-4
lines changed

3 files changed

+7
-4
lines changed

vllm/attention/layer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def __init__(
174174
calculate_kv_scales = False
175175
self.block_size = block_size
176176
self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
177-
kv_cache_dtype, vllm_config.model_config.dtype
177+
kv_cache_dtype, vllm_config.model_config
178178
)
179179
if num_kv_heads is None:
180180
num_kv_heads = num_heads

vllm/utils/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,12 @@ def set_default_torch_num_threads(num_threads: int):
157157
torch.set_num_threads(old_num_threads)
158158

159159

160-
def kv_cache_dtype_str_to_dtype(kv_cache_dtype: str, model_dtype: str) -> torch.dtype:
160+
def kv_cache_dtype_str_to_dtype(
161+
kv_cache_dtype: str, model_config: ModelConfig
162+
) -> torch.dtype:
161163
if kv_cache_dtype == "auto":
162-
return model_dtype
164+
# Model config may not be specified for unit tests, default to float16
165+
return model_config.dtype if model_config else torch.half
163166
return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
164167

165168

vllm/v1/worker/gpu_model_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def __init__(
235235
self.pin_memory = is_pin_memory_available()
236236
self.dtype = self.model_config.dtype
237237
self.kv_cache_dtype = kv_cache_dtype_str_to_dtype(
238-
cache_config.cache_dtype, self.dtype
238+
cache_config.cache_dtype, self.model_config
239239
)
240240

241241
self.is_pooling_model = model_config.runner_type == "pooling"

0 commit comments

Comments
 (0)