diff --git a/projects/composablekernel/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp b/projects/composablekernel/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp index 12d28f572c5d..684219b58432 100644 --- a/projects/composablekernel/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp +++ b/projects/composablekernel/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp @@ -256,7 +256,8 @@ struct DeviceMoeGemmBlockScale if(arg_.KBatch > 1) hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, 0, - arg_.M * arg_.N * sizeof(CDataType) * + arg_.NumTokens * arg_.TopK * arg_.N * + sizeof(CDataType) * (IsInputGemm && IsSplitK ? 2 : 1), stream_config.stream_id_)); }; @@ -273,12 +274,12 @@ struct DeviceMoeGemmBlockScale else { if(arg.KBatch > 1) - hipGetErrorString(hipMemsetAsync(arg.p_c_grid, - 0, - arg.M * arg.N * sizeof(CDataType) * - (IsInputGemm && IsSplitK ? 2 : 1), - stream_config.stream_id_)); - + hipGetErrorString( + hipMemsetAsync(arg.p_c_grid, + 0, + arg.NumTokens * arg.TopK * arg.N * sizeof(CDataType) * + (IsInputGemm && IsSplitK ? 2 : 1), + stream_config.stream_id_)); ave_time = launch_and_time_kernel( stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); }