|
73 | 73 | from sglang.srt.managers.schedule_batch import global_server_args_dict |
74 | 74 | from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode |
75 | 75 | from sglang.srt.model_loader.weight_utils import default_weight_loader |
76 | | -from sglang.srt.utils import ( |
77 | | - DeepEPMode, |
78 | | - add_prefix, |
79 | | - is_cuda, |
80 | | - is_flashinfer_available, |
81 | | - is_hip, |
82 | | -) |
| 76 | +from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip |
83 | 77 |
|
84 | 78 | _is_hip = is_hip() |
85 | 79 | _is_cuda = is_cuda() |
@@ -1060,8 +1054,7 @@ def _chunked_prefix_attn_mha( |
1060 | 1054 | forward_batch: ForwardBatch, |
1061 | 1055 | ) -> torch.Tensor: |
1062 | 1056 |
|
1063 | | - assert is_flashinfer_available() |
1064 | | - from flashinfer.cascade import merge_state |
| 1057 | + from sgl_kernel import merge_state |
1065 | 1058 |
|
1066 | 1059 | assert forward_batch.num_prefix_chunks is not None |
1067 | 1060 | for i in range(forward_batch.num_prefix_chunks): |
@@ -1100,7 +1093,7 @@ def _chunked_prefix_attn_mha( |
1100 | 1093 |
|
1101 | 1094 | output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False) |
1102 | 1095 | lse = torch.transpose(lse, 0, 1).contiguous() |
1103 | | - accum_output, accum_lse = merge_state(accum_output, accum_lse, output, lse) |
| 1096 | + accum_output, accum_lse = merge_state(output, lse, accum_output, accum_lse) |
1104 | 1097 |
|
1105 | 1098 | return accum_output |
1106 | 1099 |
|
|
0 commit comments