diff --git a/sgl-kernel/csrc/elementwise/rope.cu b/sgl-kernel/csrc/elementwise/rope.cu
index 23bc87660396..b6abd1cce122 100644
--- a/sgl-kernel/csrc/elementwise/rope.cu
+++ b/sgl-kernel/csrc/elementwise/rope.cu
@@ -14,8 +14,13 @@
  * limitations under the License.
  */
 
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/all.h>
+
 #include "pos_enc.cuh"
-#include "pytorch_extension_utils.h"
+#include "utils.h"
 
 using namespace flashinfer;
 
@@ -88,7 +93,7 @@ void apply_rope_pos_ids_cos_sin_cache(
   size_t k_rope_stride_h = k_rope.stride(1);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(q.scalar_type(), c_type, [&] {
     // TODO temporarily only use `BatchQKApplyRotaryPosIdsCosSinCacheEnhanced` when save_kv_cache
     // to avoid changing original code path; but this branch is feature-complete and should switch to this later
     if (save_kv_cache) {
diff --git a/sgl-kernel/tests/test_rotary_embedding.py b/sgl-kernel/tests/test_rotary_embedding.py
index d9f9364b0fa3..cc5374dbe8db 100644
--- a/sgl-kernel/tests/test_rotary_embedding.py
+++ b/sgl-kernel/tests/test_rotary_embedding.py
@@ -47,6 +47,12 @@
         (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8, False),
         (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4, False),
         (512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2, False),
+        (64, 64, 32, 8000, True, torch.float32, "cuda", 32, 32, 1, 1, False),
+        (256, 128, 4096, 10000, True, torch.float32, "cuda", 2, 512, 4, 2, False),
+        (512, 128, 311, 10000, True, torch.float32, "cuda", 3, 39, 4, 2, False),
+        (128, 128, 2048, 10000, False, torch.float32, "cuda", 2, 512, 32, 8, False),
+        (128, 128, 2048, 10000, False, torch.float32, "cuda", 2, 512, 16, 4, False),
+        (512, 128, 311, 10000, False, torch.float32, "cuda", 3, 39, 4, 2, False),
     ],
 )
 def test_correctness(