diff --git a/lmdeploy/pytorch/kernels/cuda/fused_moe.py b/lmdeploy/pytorch/kernels/cuda/fused_moe.py index 46e4458cd2..6a60aec0ef 100644 --- a/lmdeploy/pytorch/kernels/cuda/fused_moe.py +++ b/lmdeploy/pytorch/kernels/cuda/fused_moe.py @@ -353,7 +353,7 @@ def __get_sorted_idx(topk_ids: torch.Tensor): # activate if intermediate_cache1.size(-1) % 2048 == 0: - unflat_size = intermediate_cache1.shape[:-2] + unflat_size = intermediate_cache1.shape[:-1] intermediate_cache1 = intermediate_cache1.flatten(0, -2) gate_cache = silu_and_mul(intermediate_cache1) gate_cache = gate_cache.unflatten(0, unflat_size)