@@ -340,28 +340,32 @@ def schedule(
340340 prefix_scheduler_metadata = None
341341
342342 if self .dcp_world_size > 1 :
343- query_kv_lens_cpu = common_attn_metadata .query_start_loc_cpu [1 :] \
343+ query_kv_lens_cpu = (
344+ common_attn_metadata .query_start_loc_cpu [1 :]
344345 - common_attn_metadata .query_start_loc_cpu [:- 1 ]
346+ )
345347 dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu
346- dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu \
347- // self .dcp_world_size + ( self .dcp_rank \
348- <= ( dcp_context_kv_lens_cpu - 1 ) % self . dcp_world_size )
348+ dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu // self . dcp_world_size + (
349+ self .dcp_rank <= ( dcp_context_kv_lens_cpu - 1 ) % self .dcp_world_size
350+ )
349351 dcp_context_kv_lens = dcp_context_kv_lens_cpu .to (self .device )
350352 max_dcp_context_kv_len = dcp_context_kv_lens .max ().item ()
351353
352- scheduler_metadata = schedule (batch_size = num_reqs ,
353- cu_query_lens = query_start_loc ,
354- max_query_len = max_query_len ,
355- seqlens = dcp_context_kv_lens ,
356- max_seq_len = max_dcp_context_kv_len ,
357- causal = False )
354+ scheduler_metadata = schedule (
355+ batch_size = num_reqs ,
356+ cu_query_lens = query_start_loc ,
357+ max_query_len = max_query_len ,
358+ seqlens = dcp_context_kv_lens ,
359+ max_seq_len = max_dcp_context_kv_len ,
360+ causal = False ,
361+ )
358362 elif use_cascade :
359- cu_prefix_query_lens = torch .tensor ([ 0 , num_actual_tokens ],
360- dtype = torch .int32 ,
361- device = self . device )
362- prefix_kv_lens = torch .tensor ([ common_prefix_len ],
363- dtype = torch .int32 ,
364- device = self . device )
363+ cu_prefix_query_lens = torch .tensor (
364+ [ 0 , num_actual_tokens ], dtype = torch .int32 , device = self . device
365+ )
366+ prefix_kv_lens = torch .tensor (
367+ [ common_prefix_len ], dtype = torch .int32 , device = self . device
368+ )
365369 suffix_kv_lens = (seq_lens_cpu [:num_reqs ] - common_prefix_len ).to (
366370 self .device , non_blocking = True
367371 )
@@ -683,60 +687,57 @@ def _forward_with_dcp(
683687
684688 query = query .contiguous ()
685689 query_across_dcp = get_dcp_group ().all_gather (query , dim = 1 )
686- context_attn_out , context_lse = \
687- flash_attn_varlen_func (
688- q = query_across_dcp ,
689- k = key_cache ,
690- v = value_cache ,
691- out = None ,
692- cu_seqlens_q = cu_seqlens_q ,
693- max_seqlen_q = max_seqlen_q ,
694- seqused_k = attn_metadata .dcp_context_kv_lens ,
695- max_seqlen_k = attn_metadata .max_dcp_context_kv_len ,
696- softmax_scale = self .scale ,
697- causal = False ,
698- alibi_slopes = self .alibi_slopes ,
699- window_size = self .sliding_window ,
700- block_table = block_table ,
701- softcap = self .logits_soft_cap ,
702- return_softmax_lse = True ,
703- scheduler_metadata = attn_metadata .scheduler_metadata ,
704- fa_version = self .vllm_flash_attn_version ,
705- q_descale = q_descale ,
706- k_descale = k_descale ,
707- v_descale = v_descale ,
708- )
690+ context_attn_out , context_lse = flash_attn_varlen_func (
691+ q = query_across_dcp ,
692+ k = key_cache ,
693+ v = value_cache ,
694+ out = None ,
695+ cu_seqlens_q = cu_seqlens_q ,
696+ max_seqlen_q = max_seqlen_q ,
697+ seqused_k = attn_metadata .dcp_context_kv_lens ,
698+ max_seqlen_k = attn_metadata .max_dcp_context_kv_len ,
699+ softmax_scale = self .scale ,
700+ causal = False ,
701+ alibi_slopes = self .alibi_slopes ,
702+ window_size = self .sliding_window ,
703+ block_table = block_table ,
704+ softcap = self .logits_soft_cap ,
705+ return_softmax_lse = True ,
706+ scheduler_metadata = attn_metadata .scheduler_metadata ,
707+ fa_version = self .vllm_flash_attn_version ,
708+ q_descale = q_descale ,
709+ k_descale = k_descale ,
710+ v_descale = v_descale ,
711+ )
709712 # FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ]
710- context_attn_out_cor , context_lse_cor = \
711- cp_lse_ag_out_rs (
712- context_attn_out ,
713- context_lse .transpose (0 , 1 ),
714- get_dcp_group (),
715- return_lse = True
716- )
713+ context_attn_out_cor , context_lse_cor = cp_lse_ag_out_rs (
714+ context_attn_out ,
715+ context_lse .transpose (0 , 1 ),
716+ get_dcp_group (),
717+ return_lse = True ,
718+ )
717719 context_lse_cor = context_lse_cor .transpose (0 , 1 ).contiguous ()
718720
719- query_attn_out , query_lse = \
720- flash_attn_varlen_func (
721- q = query ,
722- k = key ,
723- v = value ,
724- out = None ,
725- cu_seqlens_q = cu_seqlens_q ,
726- max_seqlen_q = max_seqlen_q ,
727- cu_seqlens_k = cu_seqlens_q ,
728- max_seqlen_k = max_seqlen_q ,
729- softmax_scale = self .scale ,
730- causal = attn_metadata .causal ,
731- alibi_slopes = self .alibi_slopes ,
732- window_size = self .sliding_window ,
733- softcap = self .logits_soft_cap ,
734- return_softmax_lse = True ,
735- fa_version = self .vllm_flash_attn_version ,
736- q_descale = q_descale ,
737- k_descale = k_descale ,
738- v_descale = v_descale ,
739- )
721+ query_attn_out , query_lse = flash_attn_varlen_func (
722+ q = query ,
723+ k = key ,
724+ v = value ,
725+ out = None ,
726+ cu_seqlens_q = cu_seqlens_q ,
727+ max_seqlen_q = max_seqlen_q ,
728+ cu_seqlens_k = cu_seqlens_q ,
729+ max_seqlen_k = max_seqlen_q ,
730+ softmax_scale = self .scale ,
731+ causal = attn_metadata .causal ,
732+ alibi_slopes = self .alibi_slopes ,
733+ window_size = self .sliding_window ,
734+ softcap = self .logits_soft_cap ,
735+ return_softmax_lse = True ,
736+ fa_version = self .vllm_flash_attn_version ,
737+ q_descale = q_descale ,
738+ k_descale = k_descale ,
739+ v_descale = v_descale ,
740+ )
740741 assert context_attn_out_cor .shape == query_attn_out .shape
741742 assert context_lse_cor .shape == query_lse .shape
742743 merge_attn_states (
0 commit comments