diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 0b76b3556d9d..1a01274683bf 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -121,6 +121,8 @@ def __init__( init_new_workspace: bool = False, ): super().__init__() + self.prefill_backend = "fa2" + self.decode_backend = "fa2" # Store multi-item scoring delimiter for efficient access self.multi_item_scoring_delimiter = ( @@ -264,19 +266,21 @@ def __init__( BatchPrefillWithPagedKVCacheWrapper( self.workspace_buffer, "NHD", - backend="fa2", + backend=self.prefill_backend, ) ) self.prefill_wrappers_verify.append( BatchPrefillWithPagedKVCacheWrapper( self.workspace_buffer, "NHD", + backend=self.prefill_backend, ) ) self.decode_wrappers.append( BatchDecodeWithPagedKVCacheWrapper( self.workspace_buffer, "NHD", + backend=self.decode_backend, use_tensor_cores=self.decode_use_tensor_cores, ) ) @@ -555,6 +559,7 @@ def init_forward_metadata_capture_cuda_graph( BatchDecodeWithPagedKVCacheWrapper( self.workspace_buffer, "NHD", + backend=self.decode_backend, use_cuda_graph=True, use_tensor_cores=self.decode_use_tensor_cores, paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1], @@ -590,6 +595,7 @@ def init_forward_metadata_capture_cuda_graph( self.workspace_buffer, "NHD", use_cuda_graph=True, + backend=self.prefill_backend, qo_indptr_buf=self.cuda_graph_qo_indptr[i][: bs + 1], paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1], paged_kv_indices_buf=self.cuda_graph_kv_indices[i], @@ -619,7 +625,7 @@ def init_forward_metadata_capture_cuda_graph( BatchPrefillWithPagedKVCacheWrapper( self.workspace_buffer, "NHD", - backend="fa2", + backend=self.prefill_backend, use_cuda_graph=True, qo_indptr_buf=self.cuda_graph_qo_indptr[i][: bs + 1], paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1], @@ -649,7 +655,7 @@ def init_forward_metadata_capture_cuda_graph( BatchPrefillWithPagedKVCacheWrapper( self.workspace_buffer, "NHD", - backend="fa2", + backend=self.prefill_backend, use_cuda_graph=True, qo_indptr_buf=self.cuda_graph_qo_indptr[i][: bs + 1], paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1],