diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 7e28341dc373..c3fbaa422076 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -564,7 +564,9 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens cum_seq_lens_q = torch.cat( ( - torch.tensor([0], device=forward_batch.seq_lens.device), + torch.zeros( + 1, dtype=torch.int32, device=forward_batch.seq_lens.device + ), torch.cumsum(seq_lens, dim=0), ) ).int()