csrc/dispatch_utils.h (2 additions, 1 deletion)
@@ -19,12 +19,13 @@
 AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)    \
 AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)     \
 AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
+AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \
 AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
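For context on the new Float8_e5m2 dispatch case: it lets the kernels accept KV-cache tensors whose dtype is torch.float8_e5m2 directly, rather than only Byte (torch.uint8) storage. The sketch below is illustrative only and is not code from this diff; it assumes PyTorch >= 2.1, which exposes torch.float8_e5m2, and the tensor names are made up.

    import torch

    # Illustrative only: quantize an fp32 tensor to fp8 E5M2 and show that the
    # result is one byte per element, reinterpretable as raw bytes.
    key = torch.randn(16, 8)                   # fp32 reference data
    fp8_cache = key.to(torch.float8_e5m2)      # cast to the fp8 E5M2 dtype
    assert fp8_cache.element_size() == 1       # same storage footprint as Byte
    raw_bytes = fp8_cache.view(torch.uint8)    # reinterpret the same bits
    print(fp8_cache.dtype, raw_bytes.dtype)    # torch.float8_e5m2 torch.uint8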
requirements.txt (1 addition, 1 deletion)
@@ -11,6 +11,6 @@ uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
pynvml == 11.5.0
-triton >= 2.1.0
+triton >= 2.2.0
outlines >= 0.0.27
cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.
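As an aside, the floor raised here (triton >= 2.2.0) can be checked at runtime. The snippet below is a hypothetical sanity check, not code from the repository, and it assumes the packaging package is available in the environment.

    import triton
    from packaging.version import Version

    # Hypothetical check: fail fast if the installed Triton predates the new floor.
    assert Version(triton.__version__) >= Version("2.2.0"), (
        f"found triton {triton.__version__}, expected >= 2.2.0")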
tests/kernels/test_cache.py (8 additions, 0 deletions)
@@ -92,9 +92,17 @@ def test_copy_blocks(

     # Compare the results.
     for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
+        # NOTE: torch.allclose does not support the
+        # torch.float8_e5m2/torch.float8_e4m3fn dtypes.
+        if kv_cache_dtype == "fp8_e5m2":
+            key_cache = key_cache.view(torch.half)
+            cloned_key_cache = cloned_key_cache.view(torch.half)
         assert torch.allclose(key_cache, cloned_key_cache)
     for value_cache, cloned_value_cache in zip(value_caches,
                                                cloned_value_caches):
+        if kv_cache_dtype == "fp8_e5m2":
+            value_cache = value_cache.view(torch.half)
+            cloned_value_cache = cloned_value_cache.view(torch.half)
         assert torch.allclose(value_cache, cloned_value_cache)


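The viewing trick used in the test only reinterprets the raw cache bits, so the allclose call amounts to an exact, bit-for-bit comparison. A small, hypothetical helper (not part of the test file) that captures the same workaround, assuming PyTorch >= 2.1 for the float8 dtypes:

    import torch

    def assert_caches_equal(actual: torch.Tensor, expected: torch.Tensor) -> None:
        # Hypothetical helper: torch.allclose is not implemented for the fp8
        # dtypes, so reinterpret the raw bits as fp16 and compare those instead.
        # The 1-byte -> 2-byte view requires an even, contiguous last dimension,
        # which the KV-cache layout in the test satisfies.
        if actual.dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
            actual = actual.view(torch.half)
            expected = expected.view(torch.half)
        assert torch.allclose(actual, expected)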