@@ -1163,13 +1163,24 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
11631163 // not realy a GGML_TYPE_Q8_0 but same size.
11641164 switch (op->op ) {
11651165 case GGML_OP_MUL_MAT:
1166- size = ggml_row_size (PARAM_TYPE, ggml_nelements (op->src [1 ]));
1167- return true ;
1166+ {
1167+ size = ggml_row_size (PARAM_TYPE, ggml_nelements (op->src [1 ]));
1168+ return true ;
1169+ }
11681170 case GGML_OP_MUL_MAT_ID:
1169- size = ggml_row_size (PARAM_TYPE, ggml_nelements (op->src [1 ]));
1170- size = GGML_PAD (size, sizeof (int64_t )); // + padding for next bloc.
1171- size += sizeof (int64_t ) * (1 +op->src [0 ]->ne [2 ]) * op->src [1 ]->ne [2 ];
1172- return true ;
1171+ {
1172+ size = ggml_row_size (PARAM_TYPE, ggml_nelements (op->src [1 ]));
1173+ size = GGML_PAD (size, sizeof (int64_t )); // + padding for next bloc.
1174+
1175+ const int64_t ne02 = op->src [0 ]->ne [2 ]; // n_as, n_expert
1176+ const int64_t ne12 = op->src [1 ]->ne [2 ]; // n_tokens
1177+
1178+ const size_t sizeof_mmid_row_mapping = sizeof (int64_t );
1179+
1180+ size += sizeof_mmid_row_mapping*ne02*(ne12 + 1 );
1181+
1182+ return true ;
1183+ }
11731184 default :
11741185 // GGML_ABORT("fatal error");
11751186 break ;
@@ -1305,14 +1316,17 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
13051316 int32_t i2;
13061317 };
13071318
1308- GGML_ASSERT (params->wsize >= (GGML_PAD (nbw3, sizeof (int64_t )) + n_as * sizeof (int64_t ) +
1309- n_as * ne12 * sizeof (mmid_row_mapping)));
1319+ GGML_ASSERT (params->wsize >=
1320+ (GGML_PAD (nbw3, sizeof (int64_t )) +
1321+ n_as*(ne12 + 1 )*sizeof (mmid_row_mapping))
1322+ );
13101323
1311- auto * wdata = (char *) params->wdata ;
1312- auto * wdata_src1_end = (char *) wdata + GGML_PAD (nbw3, sizeof (int64_t ));
1313- auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
1324+ auto * wdata = (char *)params->wdata ;
1325+ auto * wdata_src1_end = (char *)wdata + GGML_PAD (nbw3, sizeof (int64_t ));
13141326
1315- struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
1327+ // total of [n_as][ne12 + 1] elemets of type mmid_row_mapping (2*int32_t = int64_t)
1328+ auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
1329+ struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
13161330
13171331 // src1: float32 => param type
13181332 for (int64_t i12 = 0 ; i12 < ne12; ++i12) {
0 commit comments