diff --git a/csrc/cpp_itfs/pa/pa.py b/csrc/cpp_itfs/pa/pa.py index b4d26597bf..3157a59565 100644 --- a/csrc/cpp_itfs/pa/pa.py +++ b/csrc/cpp_itfs/pa/pa.py @@ -82,6 +82,7 @@ def paged_attention_rocm( torch.bfloat16: "__hip_bfloat16", torch.float16: "_Float16", torch.float8_e4m3fnuz: "uint8_t", + torch.float8_e4m3fn: "uint8_t", } warpSize = torch.cuda.get_device_properties(out.device).warp_size