From f6623327d8eeb85c024c93f20f7edbc793089595 Mon Sep 17 00:00:00 2001
From: grimoire
Date: Wed, 18 Sep 2024 11:17:42 +0800
Subject: [PATCH] fix fused moe

---
 lmdeploy/pytorch/kernels/cuda/fused_moe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lmdeploy/pytorch/kernels/cuda/fused_moe.py b/lmdeploy/pytorch/kernels/cuda/fused_moe.py
index 46e4458cd2..6a60aec0ef 100644
--- a/lmdeploy/pytorch/kernels/cuda/fused_moe.py
+++ b/lmdeploy/pytorch/kernels/cuda/fused_moe.py
@@ -353,7 +353,7 @@ def __get_sorted_idx(topk_ids: torch.Tensor):
 
     # activate
     if intermediate_cache1.size(-1) % 2048 == 0:
-        unflat_size = intermediate_cache1.shape[:-2]
+        unflat_size = intermediate_cache1.shape[:-1]
         intermediate_cache1 = intermediate_cache1.flatten(0, -2)
         gate_cache = silu_and_mul(intermediate_cache1)
         gate_cache = gate_cache.unflatten(0, unflat_size)