From fe0743a5fff1473cba7d5ef375bc45a266ae9db5 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Thu, 13 Nov 2025 11:19:42 +0800 Subject: [PATCH 1/7] Add FP32 dtype support for RoPE --- python/sglang/srt/layers/rotary_embedding.py | 4 +--- sgl-kernel/csrc/elementwise/rope.cu | 2 +- sgl-kernel/tests/test_rotary_embedding.py | 6 ++++++ 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 7c0a27c59e42..0b64afb4219c 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -113,8 +113,7 @@ def __init__( if not _is_cuda: cache = cache.to(dtype) - if dtype == torch.float32 or ( - (not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512]) + if ((not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512]) and not (_is_cpu and _is_cpu_amx_available) and not (_is_xpu) ): @@ -276,7 +275,6 @@ def forward_cuda( if ( _is_cuda and (self.head_size in [64, 128, 256, 512]) - and self.dtype != torch.float32 ): apply_rope_with_cos_sin_cache_inplace( positions=positions, diff --git a/sgl-kernel/csrc/elementwise/rope.cu b/sgl-kernel/csrc/elementwise/rope.cu index 23bc87660396..3d8b7713b8de 100644 --- a/sgl-kernel/csrc/elementwise/rope.cu +++ b/sgl-kernel/csrc/elementwise/rope.cu @@ -88,7 +88,7 @@ void apply_rope_pos_ids_cos_sin_cache( size_t k_rope_stride_h = k_rope.stride(1); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(q.scalar_type(), c_type, [&] { // TODO temporarily only use `BatchQKApplyRotaryPosIdsCosSinCacheEnhanced` when save_kv_cache // to avoid changing original code path; but this branch is feature-complete and should switch to this later if (save_kv_cache) { diff --git a/sgl-kernel/tests/test_rotary_embedding.py 
b/sgl-kernel/tests/test_rotary_embedding.py index d9f9364b0fa3..cc5374dbe8db 100644 --- a/sgl-kernel/tests/test_rotary_embedding.py +++ b/sgl-kernel/tests/test_rotary_embedding.py @@ -47,6 +47,12 @@ (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8, False), (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4, False), (512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2, False), + (64, 64, 32, 8000, True, torch.float32, "cuda", 32, 32, 1, 1, False), + (256, 128, 4096, 10000, True, torch.float32, "cuda", 2, 512, 4, 2, False), + (512, 128, 311, 10000, True, torch.float32, "cuda", 3, 39, 4, 2, False), + (128, 128, 2048, 10000, False, torch.float32, "cuda", 2, 512, 32, 8, False), + (128, 128, 2048, 10000, False, torch.float32, "cuda", 2, 512, 16, 4, False), + (512, 128, 311, 10000, False, torch.float32, "cuda", 3, 39, 4, 2, False), ], ) def test_correctness( From e1a90af8a5a99334396a207dd387833d1a05a7ce Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Thu, 13 Nov 2025 11:40:37 +0800 Subject: [PATCH 2/7] fix codestyle --- python/sglang/srt/layers/rotary_embedding.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 0b64afb4219c..28de1fe1e727 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -113,7 +113,8 @@ def __init__( if not _is_cuda: cache = cache.to(dtype) - if ((not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512]) + if ( + (not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512]) and not (_is_cpu and _is_cpu_amx_available) and not (_is_xpu) ): @@ -272,10 +273,7 @@ def forward_cuda( offsets: Optional[torch.Tensor] = None, fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - if ( - _is_cuda - and (self.head_size in [64, 128, 256, 512]) - 
): + if _is_cuda and (self.head_size in [64, 128, 256, 512]): apply_rope_with_cos_sin_cache_inplace( positions=positions, query=query, From 211046b34da4450d50d8a26915efcf7866093166 Mon Sep 17 00:00:00 2001 From: "Jin, Youzhi" Date: Thu, 13 Nov 2025 19:08:29 +0800 Subject: [PATCH 3/7] fix sgl-kernel build fail --- sgl-kernel/csrc/elementwise/rope.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/sgl-kernel/csrc/elementwise/rope.cu b/sgl-kernel/csrc/elementwise/rope.cu index 3d8b7713b8de..e9a7bc692e88 100644 --- a/sgl-kernel/csrc/elementwise/rope.cu +++ b/sgl-kernel/csrc/elementwise/rope.cu @@ -16,6 +16,7 @@ #include "pos_enc.cuh" #include "pytorch_extension_utils.h" +#include "utils.h" using namespace flashinfer; From 592c219afae468c2bb2e2d303ffce01dfbf356b0 Mon Sep 17 00:00:00 2001 From: "Jin, Youzhi" Date: Thu, 13 Nov 2025 19:20:07 +0800 Subject: [PATCH 4/7] fix sgl-kernel build fail --- sgl-kernel/csrc/elementwise/rope.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sgl-kernel/csrc/elementwise/rope.cu b/sgl-kernel/csrc/elementwise/rope.cu index e9a7bc692e88..9ec5f080cf39 100644 --- a/sgl-kernel/csrc/elementwise/rope.cu +++ b/sgl-kernel/csrc/elementwise/rope.cu @@ -15,8 +15,8 @@ */ #include "pos_enc.cuh" -#include "pytorch_extension_utils.h" #include "utils.h" +#include "pytorch_extension_utils.h" using namespace flashinfer; From fba6fb00f080d10cdca802b7fcb9139db307a649 Mon Sep 17 00:00:00 2001 From: "Jin, Youzhi" Date: Thu, 13 Nov 2025 20:16:03 +0800 Subject: [PATCH 5/7] cleanup --- sgl-kernel/csrc/elementwise/rope.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/sgl-kernel/csrc/elementwise/rope.cu b/sgl-kernel/csrc/elementwise/rope.cu index 9ec5f080cf39..07ad819b0b97 100644 --- a/sgl-kernel/csrc/elementwise/rope.cu +++ b/sgl-kernel/csrc/elementwise/rope.cu @@ -16,7 +16,6 @@ #include "pos_enc.cuh" #include "utils.h" -#include "pytorch_extension_utils.h" using namespace flashinfer; From a97d3e8d6a9dbaf4c75a6b45a4c15ee92592da9a Mon Sep 17 
00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Fri, 14 Nov 2025 02:09:31 +0800 Subject: [PATCH 6/7] fix compile error --- sgl-kernel/csrc/elementwise/rope.cu | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sgl-kernel/csrc/elementwise/rope.cu b/sgl-kernel/csrc/elementwise/rope.cu index 07ad819b0b97..b6abd1cce122 100644 --- a/sgl-kernel/csrc/elementwise/rope.cu +++ b/sgl-kernel/csrc/elementwise/rope.cu @@ -14,6 +14,11 @@ * limitations under the License. */ +#include +#include +#include +#include + #include "pos_enc.cuh" #include "utils.h" From 7a49d773a3ffab22535c9391dd901c48dfd9c2b5 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Sat, 15 Nov 2025 17:11:59 +0800 Subject: [PATCH 7/7] revert arch side modification --- python/sglang/srt/layers/rotary_embedding.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 28de1fe1e727..7c0a27c59e42 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -113,7 +113,7 @@ def __init__( if not _is_cuda: cache = cache.to(dtype) - if ( + if dtype == torch.float32 or ( (not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512]) and not (_is_cpu and _is_cpu_amx_available) and not (_is_xpu) @@ -273,7 +273,11 @@ def forward_cuda( offsets: Optional[torch.Tensor] = None, fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - if _is_cuda and (self.head_size in [64, 128, 256, 512]): + if ( + _is_cuda + and (self.head_size in [64, 128, 256, 512]) + and self.dtype != torch.float32 + ): apply_rope_with_cos_sin_cache_inplace( positions=positions, query=query,