diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index b434780e19a2..7b8919859985 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -629,28 +629,32 @@ def invoke_fused_moe_wna16_triton_kernel( EM = sorted_token_ids.size(0) if A.size(0) < config["BLOCK_SIZE_M"]: # optimize for small batch_size. - # We assume that top_ids of each token is unique, - # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, - # and we can skip some invalid blocks. EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"]) + + config = config.copy() + + block_config = get_moe_wna16_block_config( + config=config, + use_moe_wna16_cuda=False, + num_valid_tokens=num_tokens, + size_k=A.size(1), + size_n=B.size(1), + num_experts=B.size(1), + group_size=block_shape[1], + real_top_k=top_k, + block_size_m=config["BLOCK_SIZE_M"], + ) + config.update(block_config) + + if "BLOCK_SIZE_N" not in config: + config["BLOCK_SIZE_N"] = 64 + if "BLOCK_SIZE_K" not in config: + config["BLOCK_SIZE_K"] = 32 + grid = lambda META: ( triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]), ) - config = config.copy() - config.update( - get_moe_wna16_block_config( - config=config, - use_moe_wna16_cuda=False, - num_valid_tokens=num_tokens, - size_k=A.size(1), - size_n=B.size(1), - num_experts=B.size(1), - group_size=block_shape[1], - real_top_k=top_k, - block_size_m=config["BLOCK_SIZE_M"], - ) - ) fused_moe_kernel_gptq_awq[grid]( A, @@ -2373,12 +2377,13 @@ def apply( topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map ) - invoke_fused_moe_triton_kernel( + dispatch_fused_moe_kernel( hidden_states, w1, intermediate_cache1, a1q_scale, self.w1_scale, + getattr(self, "w1_zp", None), None, # topk_weights sorted_token_ids, expert_ids, @@ -2410,12 +2415,13 @@ def apply( self.block_shape, ) - invoke_fused_moe_triton_kernel( + dispatch_fused_moe_kernel( qintermediate_cache2, w2, intermediate_cache3, a2q_scale, self.w2_scale, + getattr(self, "w2_zp", None), topk_weights, sorted_token_ids, expert_ids,