File tree (expand / collapse): 3 files changed, +7 -4 lines changed.
Original file line number | Diff line number | Diff line change
@@ -174,7 +174,7 @@ def __init__(
174174 calculate_kv_scales = False
175175 self .block_size = block_size
176176 self .kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype (
177- kv_cache_dtype , vllm_config .model_config . dtype
177+ kv_cache_dtype , vllm_config .model_config
178178 )
179179 if num_kv_heads is None :
180180 num_kv_heads = num_heads
Original file line number Diff line number Diff line change @@ -157,9 +157,12 @@ def set_default_torch_num_threads(num_threads: int):
157157 torch .set_num_threads (old_num_threads )
158158
159159
160- def kv_cache_dtype_str_to_dtype (kv_cache_dtype : str , model_dtype : str ) -> torch .dtype :
160+ def kv_cache_dtype_str_to_dtype (
161+ kv_cache_dtype : str , model_config : ModelConfig
162+ ) -> torch .dtype :
161163 if kv_cache_dtype == "auto" :
162- return model_dtype
164+ # Model config may not be specified for unit tests, default to float16
165+ return model_config .dtype if model_config else torch .half
163166 return STR_DTYPE_TO_TORCH_DTYPE [kv_cache_dtype ]
164167
165168
Original file line number Diff line number Diff line change @@ -235,7 +235,7 @@ def __init__(
235235 self .pin_memory = is_pin_memory_available ()
236236 self .dtype = self .model_config .dtype
237237 self .kv_cache_dtype = kv_cache_dtype_str_to_dtype (
238- cache_config .cache_dtype , self .dtype
238+ cache_config .cache_dtype , self .model_config
239239 )
240240
241241 self .is_pooling_model = model_config .runner_type == "pooling"
You can’t perform that action at this time.
0 commit comments