diff --git a/python/sgl_kernel_npu/sgl_kernel_npu/fla/utils.py b/python/sgl_kernel_npu/sgl_kernel_npu/fla/utils.py index 6d676b2c5..f4c4ae21f 100644 --- a/python/sgl_kernel_npu/sgl_kernel_npu/fla/utils.py +++ b/python/sgl_kernel_npu/sgl_kernel_npu/fla/utils.py @@ -135,10 +135,8 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: @tensor_cache def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: - device = cu_seqlens.device - cu_seqlens = cu_seqlens.to(torch.int64).cpu().clone() - tmp = cu_seqlens[1:] - cu_seqlens[:-1] - return tmp.to(device) + cu_seqlens_i64 = cu_seqlens.to(torch.int64) + return cu_seqlens_i64[1:] - cu_seqlens_i64[:-1] @tensor_cache