diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 5f62b3f2ea3b..dfddc1d67a62 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -113,7 +113,7 @@ def __init__( if not _is_cuda: cache = cache.to(dtype) - if dtype == torch.float32 or ( + if ( (not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512]) and not (_is_cpu and _is_cpu_amx_available) and not (_is_xpu) @@ -273,11 +273,7 @@ def forward_cuda( offsets: Optional[torch.Tensor] = None, fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - if ( - _is_cuda - and (self.head_size in [64, 128, 256, 512]) - and self.dtype != torch.float32 - ): + if _is_cuda and (self.head_size in [64, 128, 256, 512]): apply_rope_with_cos_sin_cache_inplace( positions=positions, query=query, diff --git a/test/srt/cpu/test_rope.py b/test/srt/cpu/test_rope.py index 8c1dfe9aa168..22824e0ca57e 100644 --- a/test/srt/cpu/test_rope.py +++ b/test/srt/cpu/test_rope.py @@ -146,6 +146,12 @@ def single_test( (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 32, 8), (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 16, 4), (512, 128, 311, 10000, False, torch.bfloat16, "cpu", 3, 39, 4, 2), + (64, 64, 32, 8000, True, torch.float32, "cpu", 32, 32, 1, 1), + (256, 128, 4096, 10000, True, torch.float32, "cpu", 2, 512, 32, 8), + (512, 128, 311, 10000, True, torch.float32, "cpu", 3, 39, 4, 2), + (128, 128, 2048, 10000, False, torch.float32, "cpu", 2, 512, 32, 8), + (128, 128, 2048, 10000, False, torch.float32, "cpu", 2, 512, 16, 4), + (512, 128, 311, 10000, False, torch.float32, "cpu", 3, 39, 4, 2), ] for ( diff --git a/test/srt/rotary_embedding/test_mrope.py b/test/srt/rotary_embedding/test_mrope.py index ad6412ec4d2f..4fbfd06911bc 100644 --- a/test/srt/rotary_embedding/test_mrope.py +++ b/test/srt/rotary_embedding/test_mrope.py @@ -76,7 +76,7 @@ class MRoPETestInfo(NamedTuple): ], ) @pytest.mark.parametrize("tp_size", [1, 2]) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @pytest.mark.parametrize("num_tokens", num_tokens_list) def test_mrope( model_name: str,