@@ -622,8 +622,7 @@ def __init__(
622622 assert quant_method is not None
623623 self .quant_method = quant_method
624624
625- dispatch_combine = self ._construct_dispatch_combine (
626- moe , quant_config )
625+ dispatch_combine = self ._construct_dispatch_combine (moe , quant_config )
627626
628627 success = self .quant_method .set_dispatch_combine (dispatch_combine )
629628
@@ -1029,13 +1028,12 @@ def forward(self, hidden_states: torch.Tensor,
10291028 return torch .ops .vllm .moe_forward (hidden_states , router_logits ,
10301029 self .layer_name )
10311030
1032-
10331031 def forward_impl_chunked (self , full_hidden_states : torch .Tensor ,
10341032 full_router_logits : torch .Tensor ):
10351033
10361034 full_final_hidden_states = torch .empty_like (full_hidden_states )
10371035
1038- def process_chunk (chunk_start , chunk_end , skip_result_store = False ):
1036+ def process_chunk (chunk_start , chunk_end , skip_result_store = False ):
10391037 hidden_states = full_hidden_states [chunk_start :chunk_end , :]
10401038 router_logits = full_router_logits [chunk_start :chunk_end , :]
10411039
@@ -1088,18 +1086,23 @@ def process_chunk(chunk_start, chunk_end, skip_result_store = False):
10881086 full_final_hidden_states [chunk_start :chunk_end , :].copy_ (
10891087 final_hidden_states )
10901088
1091- max_tokens_across_dp = get_forward_context ().dp_metadata .max_tokens_across_dp
1089+ max_tokens_across_dp = get_forward_context (
1090+ ).dp_metadata .max_tokens_across_dp
10921091 moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // self .dp_size
10931092
10941093 num_tokens = full_hidden_states .size (0 )
1095- for chunk_start_ in range (0 , max_tokens_across_dp , moe_dp_chunk_size_per_rank ):
1096- chunk_start = chunk_start_
1097- chunk_end = min (chunk_start + moe_dp_chunk_size_per_rank , max_tokens_across_dp )
1094+ for chunk_start_ in range (0 , max_tokens_across_dp ,
1095+ moe_dp_chunk_size_per_rank ):
1096+ chunk_start = chunk_start_
1097+ chunk_end = min (chunk_start + moe_dp_chunk_size_per_rank ,
1098+ max_tokens_across_dp )
10981099 # clamp start and end
10991100 chunk_start = min (chunk_start , num_tokens - 1 )
11001101 chunk_end = min (chunk_end , num_tokens )
11011102
1102- process_chunk (chunk_start , chunk_end , skip_result_store = chunk_start_ >= num_tokens )
1103+ process_chunk (chunk_start ,
1104+ chunk_end ,
1105+ skip_result_store = chunk_start_ >= num_tokens )
11031106
11041107 return full_final_hidden_states
11051108
0 commit comments