@@ -2164,38 +2164,6 @@ struct mmid_row_mapping {
     int32_t i2;
 };
 
-static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous,
-                                                 int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping,
-                                                 const char * __restrict__ ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
-                                                 int64_t ne11, int64_t ne10,
-                                                 size_t nb11, size_t nb12) {
-    int32_t iid1 = blockIdx.x;
-    int32_t id = blockIdx.y;
-
-    const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
-
-    if (row_id_i != i02) {
-        return;
-    }
-
-    const int64_t i11 = id % ne11;
-    const int64_t i12 = iid1;
-
-    __shared__ int src1_row;
-    if (threadIdx.x == 0) {
-        src1_row = atomicAdd(cur_src1_row, 1);
-        row_mapping[src1_row] = {id, iid1};
-    }
-    __syncthreads();
-
-    const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
-    float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
-
-    for (int i = threadIdx.x; i < ne10; i += blockDim.x) {
-        src1_row_contiguous[i] = src1_row_original[i];
-    }
-}
-
 static __global__ void k_copy_src_to_contiguous(const char * __restrict__ src_original, char * __restrict__ src_contiguous,
                                                 const mmid_row_mapping * __restrict__ row_mapping,
                                                 int64_t ne10, int64_t ne11, size_t nb11, size_t nb12) {
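
The body of k_copy_src_to_contiguous lies outside this hunk; only its signature is visible. A minimal sketch consistent with that signature and with the call sites further down (one thread block per precomputed row, mirroring the copy loop of the removed kernel):

    // Sketch, not the committed implementation: each block copies one gathered row.
    // row_mapping is already filtered and grouped per expert, so no atomics and no
    // early-exit comparison against i02 are needed.
    static __global__ void k_copy_src_to_contiguous(const char * __restrict__ src_original, char * __restrict__ src_contiguous,
                                                    const mmid_row_mapping * __restrict__ row_mapping,
                                                    int64_t ne10, int64_t ne11, size_t nb11, size_t nb12) {
        const int32_t i = blockIdx.x;                  // index into this expert's slice of the mapping

        const int64_t i11 = row_mapping[i].i1 % ne11;  // source row within src1
        const int64_t i12 = row_mapping[i].i2;         // token index

        const float * src_row_original   = (const float *)(src_original + i11*nb11 + i12*nb12);
        float       * src_row_contiguous = (float *)(src_contiguous + i*nb11);

        // blockDim.x threads stride over the ne10 floats of the row
        for (int j = threadIdx.x; j < ne10; j += blockDim.x) {
            src_row_contiguous[j] = src_row_original[j];
        }
    }
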
@@ -2229,6 +2197,51 @@ static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_origin
     }
 }
 
+static inline void prepare_row_mappings(ggml_backend_cuda_context & ctx, int64_t n_as, int64_t n_ids,
+        const ggml_tensor * ids, std::vector<int> & moe_counts, std::vector<int> & cum_moe_counts,
+        ggml_cuda_pool_alloc<mmid_row_mapping> & dev_row_mapping) {
+
+    GGML_ASSERT(moe_counts.empty() && cum_moe_counts.empty());
+
+    auto stream = ctx.stream();
+
+    std::vector<char> ids_host(ggml_nbytes(ids));
+    const char * ids_dev = (const char *) ids->data;
+    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
+    //CUDA_CHECK(cudaStreamSynchronize(stream));
+
+    std::vector<mmid_row_mapping> rmapping(ids->ne[1]*n_ids);
+    moe_counts.resize(n_as, 0);
+    cum_moe_counts.resize(n_as + 1);
+
+    for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+        for (int64_t id = 0; id < n_ids; id++) {
+            const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+            if (row_id_i >= 0 && row_id_i < n_as) ++moe_counts[row_id_i];
+        }
+    }
+    cum_moe_counts[0] = 0;
+    for (int i = 0; i < (int)n_as; ++i) {
+        cum_moe_counts[i+1] = cum_moe_counts[i] + moe_counts[i];
+    }
+
+    dev_row_mapping.alloc(cum_moe_counts[n_as]);
+
+    for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+        for (int64_t id = 0; id < n_ids; id++) {
+            const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+            if (row_id_i >= 0 && row_id_i < n_as) {
+                rmapping[cum_moe_counts[row_id_i]++] = {(int)id, (int)iid1};
+            }
+        }
+    }
+
+    for (int i = 0; i < (int)n_as; ++i) cum_moe_counts[i] -= moe_counts[i];
+
+    CUDA_CHECK(cudaMemcpyAsync(dev_row_mapping.get(), rmapping.data(), cum_moe_counts[n_as]*sizeof(mmid_row_mapping), cudaMemcpyHostToDevice, stream));
+}
+
 static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
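
prepare_row_mappings replaces the device-side, atomicAdd-based compaction with a host-side counting sort over the expert ids: one pass counts the rows destined for each expert (moe_counts), an exclusive prefix sum turns the counts into per-expert offsets (cum_moe_counts), and a second pass scatters the {id, iid1} pairs into a single buffer grouped by expert. A self-contained illustration of the same three steps, with hypothetical sizes (3 experts, 2 tokens each selecting 2 experts):

    #include <cstdio>
    #include <vector>

    struct mmid_row_mapping { int i1, i2; };

    int main() {
        // Hypothetical ids matrix: ids[iid1][id] = expert chosen for slot id of token iid1.
        const int n_as = 3, n_ids = 2, n_tokens = 2;
        const int ids[n_tokens][n_ids] = {{2, 0}, {2, 2}};

        // Pass 1: count how many rows each expert receives.
        std::vector<int> moe_counts(n_as, 0), cum_moe_counts(n_as + 1, 0);
        for (int iid1 = 0; iid1 < n_tokens; iid1++)
            for (int id = 0; id < n_ids; id++)
                if (ids[iid1][id] >= 0 && ids[iid1][id] < n_as) ++moe_counts[ids[iid1][id]];

        // Exclusive prefix sum: cum_moe_counts[e] is expert e's offset into the mapping.
        for (int e = 0; e < n_as; ++e) cum_moe_counts[e+1] = cum_moe_counts[e] + moe_counts[e];

        // Pass 2: scatter {id, iid1} pairs into their expert's slice, advancing offsets.
        std::vector<mmid_row_mapping> rmapping(cum_moe_counts[n_as]);
        for (int iid1 = 0; iid1 < n_tokens; iid1++)
            for (int id = 0; id < n_ids; id++) {
                const int e = ids[iid1][id];
                if (e >= 0 && e < n_as) rmapping[cum_moe_counts[e]++] = {id, iid1};
            }

        // Restore the offsets (each was advanced by exactly moe_counts[e]).
        for (int e = 0; e < n_as; ++e) cum_moe_counts[e] -= moe_counts[e];

        // Prints: expert 0 owns rmapping[0], expert 1 is empty, expert 2 owns rmapping[1..3].
        for (int e = 0; e < n_as; ++e)
            printf("expert %d: rows [%d, %d)\n", e, cum_moe_counts[e], cum_moe_counts[e+1]);
    }
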
@@ -2289,10 +2302,10 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
     const int64_t n_as = ne02;
     const int64_t n_ids = ids->ne[0];
 
-    std::vector<char> ids_host(ggml_nbytes(ids));
-    const char * ids_dev = (const char *) ids->data;
-    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
-    CUDA_CHECK(cudaStreamSynchronize(stream));
+    //std::vector<char> ids_host(ggml_nbytes(ids));
+    //const char * ids_dev = (const char *) ids->data;
+    //CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
+    //CUDA_CHECK(cudaStreamSynchronize(stream));
 
     ggml_tensor src0_row = *src0;
     ggml_tensor src1_row = *src1;
@@ -2319,6 +2332,9 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
     dst_row.nb[3] = nb1;
 
     if (ne12 == 1) {
+        std::vector<char> ids_host(ggml_nbytes(ids));
+        const char * ids_dev = (const char *) ids->data;
+        CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
         for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
             for (int64_t id = 0; id < n_ids; id++) {
                 const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
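
In the ne12 == 1 path each selected (token, expert) pair is dispatched as its own single-row multiplication on tensor views in the surrounding loop, so no row mapping is needed and the ids copy stays local to this branch. Note that neither this copy nor the one in prepare_row_mappings is followed by an explicit synchronization before the host reads ids_host: this relies on CUDA's documented behavior that cudaMemcpyAsync from device memory to pageable host memory returns only once the copy has completed. If ids_host were ever backed by pinned memory instead, an explicit sync would be required before the host-side reads, e.g.:

    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
    CUDA_CHECK(cudaStreamSynchronize(stream)); // mandatory if ids_host is pinned (cudaMallocHost) memory
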
@@ -2340,44 +2356,32 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
             }
         }
     } else {
+
+        ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool());
+        std::vector<int> moe_counts, cum_moe_counts;
+        prepare_row_mappings(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping);
+
         ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
         ggml_cuda_pool_alloc<char>  dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
 
         src1_row.data = src1_contiguous.get();
         dst_row.data  =  dst_contiguous.get();
 
         for (int64_t i02 = 0; i02 < n_as; i02++) {
-            int64_t num_src1_rows = 0;
 
-            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-                for (int64_t id = 0; id < n_ids; id++) {
-                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
-
-                    if (row_id_i != i02) {
-                        continue;
-                    }
-
-                    num_src1_rows++;
-                }
-            }
+            int64_t num_src1_rows = moe_counts[i02];
 
             if (num_src1_rows == 0) {
                 continue;
             }
 
-            ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
-            ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
-            CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
+            size_t mapping_offset = cum_moe_counts[i02];
 
             {
                 dim3 block_dims(std::min((unsigned int)ne10, 768u));
-                dim3 grid_dims(ids->ne[1], n_ids);
-                k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
-                        src1_original, src1_contiguous.get(),
-                        dev_cur_src1_row.get(), dev_row_mapping.get(),
-                        ids_dev, i02, ids->nb[1], ids->nb[0],
-                        ne11, ne10,
-                        nb11, nb12);
+                dim3 grid_dims(num_src1_rows);
+                k_copy_src_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                        src1_original, src1_contiguous.get(), dev_row_mapping.get() + mapping_offset, ne10, ne11, nb11, nb12);
                 CUDA_CHECK(cudaGetLastError());
             }
 
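
With the mapping prebuilt, the per-expert loop needs neither a device-side atomic counter nor a fresh allocation per expert: expert i02's rows occupy the contiguous slice [cum_moe_counts[i02], cum_moe_counts[i02] + moe_counts[i02]) of the single dev_row_mapping buffer. Continuing the toy example above (moe_counts = {1, 0, 3}, cum_moe_counts = {0, 1, 1, 4}):

    // i02 = 0: num_src1_rows = 1, mapping_offset = 0 -> 1 copy block over rmapping[0]
    // i02 = 1: num_src1_rows = 0 -> skipped entirely: no allocation, no memset, no launch
    // i02 = 2: num_src1_rows = 3, mapping_offset = 1 -> 3 copy blocks over rmapping[1..3]
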
@@ -2403,7 +2407,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
                 dim3 grid_dims(num_src1_rows);
                 k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
                         dst_original, dst_contiguous.get(),
-                        dev_row_mapping.get(),
+                        dev_row_mapping.get() + mapping_offset,
                         ne0,
                         nb1, nb2);
                 CUDA_CHECK(cudaGetLastError());
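
k_copy_dst_from_contiguous, also defined outside this diff, performs the inverse scatter; only its argument changes here, so that it reads from the expert's slice of the shared mapping. A sketch of the expected body, assuming the same one-block-per-row layout as the gather kernel:

    // Sketch, not the committed implementation: scatter each contiguous result row
    // back to its (expert slot, token) position in the original dst tensor.
    static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous,
                                                      const mmid_row_mapping * __restrict__ row_mapping,
                                                      int64_t ne0, size_t nb1, size_t nb2) {
        const int32_t i = blockIdx.x;

        const int32_t i1 = row_mapping[i].i1;   // expert slot (id)
        const int32_t i2 = row_mapping[i].i2;   // token index (iid1)

        const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
        float       * dst_row_original   = (float *)(dst_original + i1*nb1 + i2*nb2);

        for (int j = threadIdx.x; j < ne0; j += blockDim.x) {
            dst_row_original[j] = dst_row_contiguous[j];
        }
    }
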
@@ -2658,77 +2662,22 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
 
         bool first = false; //true;
 
-        std::vector<mmid_row_mapping> rmapping(ids->ne[1]*n_ids);
-        std::vector<int> moe_counts(n_as, 0), cum_moe_counts(n_as+1);
+        ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool());
+        std::vector<int> moe_counts, cum_moe_counts;
 
-        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-            for (int64_t id = 0; id < n_ids; id++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
-                if (row_id_i >= 0 && row_id_i < n_as) ++moe_counts[row_id_i];
-            }
-        }
-        cum_moe_counts[0] = 0;
-        for (int i = 0; i < (int)n_as; ++i) {
-            cum_moe_counts[i+1] = cum_moe_counts[i] + moe_counts[i];
-            //printf("moe_counts[%2d] = %d, cum = %d\n", i, moe_counts[i], cum_moe_counts[i+1]);
-        }
-
-        ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), cum_moe_counts[n_as]);
-
-        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-            for (int64_t id = 0; id < n_ids; id++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
-                if (row_id_i >= 0 && row_id_i < n_as) {
-                    rmapping[cum_moe_counts[row_id_i]++] = {(int)id, (int)iid1};
-                }
-            }
-        }
-
-        for (int i = 0; i < (int)n_as; ++i) cum_moe_counts[i] -= moe_counts[i];
-
-        CUDA_CHECK(cudaMemcpyAsync(dev_row_mapping.get(), rmapping.data(), cum_moe_counts[n_as]*sizeof(mmid_row_mapping), cudaMemcpyHostToDevice, stream));
+        prepare_row_mappings(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping);
 
         for (int64_t i02 = 0; i02 < n_as; i02++) {
             int64_t num_src1_rows = moe_counts[i02];
-            //printf("Processing i02 = %d with %d counts\n", (int)i02, (int)num_src1_rows);
-            //int64_t num_src1_rows = 0;
-
-            //for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-            //    for (int64_t id = 0; id < n_ids; id++) {
-            //        const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
-            //        if (row_id_i == i02) {
-            //            //if (id >= ne11) printf("Oops: id = %ld, ne11 = %ld\n", id, ne11);
-            //            //rmapping[num_src1_rows++] = {(int)(id%ne11), (int)iid1};
-            //            rmapping[num_src1_rows++] = {(int)id, (int)iid1};
-            //        }
-            //    }
-            //}
 
             if (num_src1_rows == 0) continue;
             size_t mapping_offset = cum_moe_counts[i02];
 
-            //ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
-            //CUDA_CHECK(cudaMemcpyAsync(dev_row_mapping.get(), rmapping.data(), num_src1_rows*sizeof(mmid_row_mapping), cudaMemcpyHostToDevice, stream));
-            //CUDA_CHECK(cudaStreamSynchronize(stream));
-
-            //ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
-            //ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
-            //CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
-
             {
-                //printf("Invoking k_copy_src_to_contiguous kernel using offset %zu\n", offset);
                 dim3 block_dims(std::min((unsigned int)ne10, 768u));
                 dim3 grid_dims(num_src1_rows);
                 k_copy_src_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
                         src1_original, src1_contiguous.get(), dev_row_mapping.get() + mapping_offset, ne10, ne11, nb11, nb12);
-                //dim3 block_dims(std::min((unsigned int)ne10, 768u));
-                //dim3 grid_dims(ids->ne[1], n_ids);
-                //k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
-                //        src1_original, src1_contiguous.get(),
-                //        dev_cur_src1_row.get(), dev_row_mapping.get(),
-                //        ids_dev, i02, ids->nb[1], ids->nb[0],
-                //        ne11, ne10,
-                //        nb11, nb12);
                 CUDA_CHECK(cudaGetLastError());
             }
 
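
The same refactor applied in ggml_cuda_mul_mat_id is repeated here in ggml_cuda_up_gate_unary, with the inline counting sort now factored into prepare_row_mappings. The payoff is easiest to see in launch and allocation counts: the removed k_copy_src1_to_contiguous was launched once per expert with an ids->ne[1] x n_ids grid, i.e. n_as x n_tokens x n_ids blocks in total, nearly all of which exited after a single comparison, plus a pool allocation and a cudaMemsetAsync of the atomic counter for every expert. With hypothetical sizes n_as = 64, n_ids = 8 and 512 tokens, that is 64 x 512 x 8 = 262,144 gather blocks per invocation; the new path launches exactly moe_counts[i02] blocks per expert, 512 x 8 = 4,096 in total, and allocates the mapping buffer once.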