Skip to content

Commit b8c89c9

Browse files
committed
update merge state
1 parent 2462ecc commit b8c89c9

File tree

1 file changed

+3
-10
lines changed

1 file changed

+3
-10
lines changed

python/sglang/srt/models/deepseek_v2.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,7 @@
7373
from sglang.srt.managers.schedule_batch import global_server_args_dict
7474
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
7575
from sglang.srt.model_loader.weight_utils import default_weight_loader
76-
from sglang.srt.utils import (
77-
DeepEPMode,
78-
add_prefix,
79-
is_cuda,
80-
is_flashinfer_available,
81-
is_hip,
82-
)
76+
from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip
8377

8478
_is_hip = is_hip()
8579
_is_cuda = is_cuda()
@@ -1060,8 +1054,7 @@ def _chunked_prefix_attn_mha(
10601054
forward_batch: ForwardBatch,
10611055
) -> torch.Tensor:
10621056

1063-
assert is_flashinfer_available()
1064-
from flashinfer.cascade import merge_state
1057+
from sgl_kernel import merge_state
10651058

10661059
assert forward_batch.num_prefix_chunks is not None
10671060
for i in range(forward_batch.num_prefix_chunks):
@@ -1100,7 +1093,7 @@ def _chunked_prefix_attn_mha(
11001093

11011094
output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
11021095
lse = torch.transpose(lse, 0, 1).contiguous()
1103-
accum_output, accum_lse = merge_state(accum_output, accum_lse, output, lse)
1096+
accum_output, accum_lse = merge_state(output, lse, accum_output, accum_lse)
11041097

11051098
return accum_output
11061099

0 commit comments

Comments
 (0)