@@ -1231,7 +1231,7 @@ static void ggml_cuda_op_mul_mat_cublas(

     if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
-        ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool());
+        ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
@@ -1241,7 +1241,7 @@ static void ggml_cuda_op_mul_mat_cublas(
         }
         const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();

-        ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool());
+        ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool(id));
         if (src1->type != GGML_TYPE_F16) {
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
@@ -1250,7 +1250,7 @@ static void ggml_cuda_op_mul_mat_cublas(
             to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
-        ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(), row_diff*src1_ncols);
+        ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);

         const half alpha_f16 = 1.0f;
         const half beta_f16 = 0.0f;
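Note on the hunk above: the only functional change is that the temporary fp16 buffers are now taken from the pool of the device given by `id` (`ctx.pool(id)`) rather than the context's default pool, so in multi-GPU runs each device's scratch memory comes from its own pool. Below is a minimal sketch of the RAII pattern this relies on; `simple_pool` and `pool_alloc` are illustrative stand-ins, not the real ggml CUDA pool API.

```cpp
#include <cstdlib>
#include <vector>

// Illustrative stand-in for a per-device memory pool; the real ggml_cuda_pool
// hands out device memory and reuses freed blocks.
struct simple_pool {
    void * alloc(size_t size) { return std::malloc(size); }
    void   free(void * ptr)   { std::free(ptr); }
};

// RAII helper in the spirit of ggml_cuda_pool_alloc<T>: whatever is taken from
// the pool in the constructor is returned to the same pool in the destructor.
template <typename T>
struct pool_alloc {
    simple_pool & pool;
    T * ptr = nullptr;

    explicit pool_alloc(simple_pool & p) : pool(p) {}
    pool_alloc(simple_pool & p, size_t n) : pool(p), ptr((T *) p.alloc(n * sizeof(T))) {}
    ~pool_alloc() { if (ptr != nullptr) { pool.free(ptr); } }

    T * get() { return ptr; }
};

int main() {
    std::vector<simple_pool> pools(2);               // one pool per device
    const int id = 1;                                // device currently doing the work
    pool_alloc<float> dst_buf(pools[id], 1024);      // allocate from that device's pool
    return dst_buf.get() != nullptr ? 0 : 1;
}
```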
@@ -1960,20 +1960,84 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     }
 }

+struct mmid_row_mapping {
+    int64_t i1;
+    int64_t i2;
+};
+
+static __global__ void k_copy_src1_to_contiguous(const char * src1_original, char * src1_contiguous,
+        int * cur_src1_row, mmid_row_mapping * row_mapping,
+        const char * ids_dev, int64_t i02, int64_t ids_nb1, int64_t ids_nb0,
+        int64_t ids_ne1, int64_t n_ids,
+        int64_t ne11,
+        size_t nb11, size_t nb12) {
+    int64_t iid1 = blockIdx.x;
+    int64_t id = blockIdx.y;
+
+    if (iid1 >= ids_ne1 || id >= n_ids) {
+        return;
+    }
+
+    const int32_t row_id_i = *(const int32_t *) (ids_dev + iid1*ids_nb1 + id*ids_nb0);
+
+    if (row_id_i != i02) {
+        return;
+    }
+
+    const int64_t i11 = id % ne11;
+    const int64_t i12 = iid1;
+
+    __shared__ int src1_row;
+    if (threadIdx.x == 0) {
+        src1_row = atomicAdd(cur_src1_row, 1);
+        row_mapping[src1_row] = {id, iid1};
+    }
+    __syncthreads();
+
+    const char * src1_row_original = src1_original + i11*nb11 + i12*nb12;
+    char * src1_row_contiguous = src1_contiguous + src1_row*nb11;
+
+    for (int i = threadIdx.x; i < nb11; i += blockDim.x) {
+        src1_row_contiguous[i] = src1_row_original[i];
+    }
+}
+
+static __global__ void k_copy_dst_from_contiguous(char * dst_original, const char * dst_contiguous,
+        const mmid_row_mapping * row_mapping,
+        int64_t n_rows,
+        int64_t nb1, int64_t nb2) {
+    int64_t i = blockIdx.x;
+
+    if (i >= n_rows) {
+        return;
+    }
+
+    const int64_t i1 = row_mapping[i].i1;
+    const int64_t i2 = row_mapping[i].i2;
+
+    const char * dst_row_contiguous = dst_contiguous + i*nb1;
+    char * dst_row_original = dst_original + i1*nb1 + i2*nb2;
+
+    for (int j = threadIdx.x; j < nb1; j += blockDim.x) {
+        dst_row_original[j] = dst_row_contiguous[j];
+    }
+}
+
+// #define MMID_MEMCPY
+
 static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * ids = dst->src[2];

+    GGML_TENSOR_BINARY_OP_LOCALS
+
     GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");

     cudaStream_t stream = ctx.stream();

-    const size_t nb11 = src1->nb[1];
-    const size_t nb1 = dst->nb[1];
-
-    const int32_t id = ((int32_t *) dst->op_params)[0];
-    const int32_t n_as = src0->ne[2];
+    const int64_t n_as = ne02;
+    const int64_t n_ids = ids->ne[0];

     std::vector<char> ids_host(ggml_nbytes(ids));
     const char * ids_dev = (const char *) ids->data;
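The kernels added above form a gather/scatter pair for the contiguous path: for a fixed expert `i02`, `k_copy_src1_to_contiguous` copies every src1 row whose routing id matches into a packed buffer and records its destination coordinates via `atomicAdd`, and `k_copy_dst_from_contiguous` later scatters the result rows back using that mapping. Below is a serial CPU sketch of the same gather bookkeeping; the function name, the flat `ids[i12*n_ids + id]` layout, and the plain byte buffers standing in for device memory are all illustrative.

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Same shape as the mmid_row_mapping added above: i1 = expert slot, i2 = token.
struct row_mapping { int64_t i1, i2; };

// Serial equivalent of k_copy_src1_to_contiguous for one expert i02: pack the
// matching src1 rows back to back and remember where their results must go.
static int64_t gather_rows_for_expert(
        const std::vector<int32_t> & ids, int64_t n_ids, int64_t n_tokens, int64_t i02,
        const uint8_t * src1, size_t nb11, size_t nb12, int64_t ne11,
        uint8_t * src1_contiguous, std::vector<row_mapping> & mapping) {
    int64_t num_rows = 0;
    for (int64_t i12 = 0; i12 < n_tokens; ++i12) {        // token index
        for (int64_t id = 0; id < n_ids; ++id) {          // expert slot for that token
            if (ids[i12*n_ids + id] != i02) {
                continue;                                  // slot routed to another expert
            }
            const int64_t i11 = id % ne11;                 // src1 row used for this slot
            std::memcpy(src1_contiguous + num_rows*nb11, src1 + i11*nb11 + i12*nb12, nb11);
            mapping.push_back({id, i12});                  // dst row: i1 = id, i2 = i12
            num_rows++;
        }
    }
    return num_rows;                                       // rows packed for expert i02
}
```

Doing this packing in a single kernel launch per expert replaces the one-`cudaMemcpyAsync`-per-matching-row approach that the `MMID_MEMCPY` fallback (kept under an `#ifdef` in the hunks below) still performs.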
@@ -1982,27 +2046,47 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *

     ggml_tensor src0_row = *src0;
     ggml_tensor src1_row = *src1;
-    ggml_tensor dst_row = *dst;
+    ggml_tensor dst_row  = *dst;

     char * src0_original = (char *) src0->data;
     char * src1_original = (char *) src1->data;
     char * dst_original = (char *) dst->data;

     src0_row.ne[2] = 1;
     src0_row.ne[3] = 1;
-    src0_row.nb[3] = src0->nb[2];
+    src0_row.nb[3] = nb02;

-    if (src1->ne[1] == 1) {
-        for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-            const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+    src1_row.ne[1] = 1;
+    src1_row.ne[2] = 1;
+    src1_row.ne[3] = 1;
+    src1_row.nb[2] = nb11;
+    src1_row.nb[3] = nb11;

-            GGML_ASSERT(row_id >= 0 && row_id < n_as);
+    dst_row.ne[1] = 1;
+    dst_row.ne[2] = 1;
+    dst_row.ne[3] = 1;
+    dst_row.nb[2] = nb1;
+    dst_row.nb[3] = nb1;

-            src0_row.data = src0_original + row_id*src0->nb[2];
-            src1_row.data = src1_original + i01*src1->nb[1];
-            dst_row.data = dst_original + i01*dst->nb[1];
+    if (ne12 == 1) {
+        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+            for (int64_t id = 0; id < n_ids; id++) {
+                const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);

-            ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+                GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+                const int64_t i11 = id % ne11;
+                const int64_t i12 = iid1;
+
+                const int64_t i1 = id;
+                const int64_t i2 = i12;
+
+                src0_row.data = src0_original + i02*nb02;
+                src1_row.data = src1_original + i11*nb11 + i12*nb12;
+                dst_row.data = dst_original + i1*nb1 + i2*nb2;
+
+                ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+            }
         }
     } else {
         ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
@@ -2011,55 +2095,104 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         src1_row.data = src1_contiguous.get();
         dst_row.data = dst_contiguous.get();

-        for (int32_t row_id = 0; row_id < n_as; ++row_id) {
+        for (int64_t i02 = 0; i02 < n_as; i02++) {
             int64_t num_src1_rows = 0;
-            for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);

-                if (row_id_i != row_id) {
-                    continue;
-                }
+            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+                for (int64_t id = 0; id < n_ids; id++) {
+                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+
+                    if (row_id_i != i02) {
+                        continue;
+                    }

-                GGML_ASSERT(row_id >= 0 && row_id < n_as);
+                    GGML_ASSERT(i02 >= 0 && i02 < n_as);

-                CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11, src1_original + i01*nb11,
-                    nb11, cudaMemcpyDeviceToDevice, stream));
-                num_src1_rows++;
+#ifdef MMID_MEMCPY
+                    const int64_t i11 = id % ne11;
+                    const int64_t i12 = iid1;
+                    CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11,
+                        src1_original + i11*nb11 + i12*nb12,
+                        nb11, cudaMemcpyDeviceToDevice, stream));
+#endif
+                    num_src1_rows++;
+                }
             }

             if (num_src1_rows == 0) {
                 continue;
             }

-            src0_row.data = src0_original + row_id*src0->nb[2];
+#ifndef MMID_MEMCPY
+            ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
+            ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
+            CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));

-            src1_row.ne[1] = num_src1_rows;
-            dst_row.ne[1] = num_src1_rows;
+            {
+                dim3 block_dims(std::min((uint)nb11, 1024u));
+                dim3 grid_dims(ids->ne[1], n_ids);
+                k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                    src1_original, src1_contiguous.get(),
+                    dev_cur_src1_row.get(), dev_row_mapping.get(),
+                    ids_dev, i02, ids->nb[1], ids->nb[0],
+                    ids->ne[1], n_ids,
+                    ne11,
+                    nb11, nb12);
+                CUDA_CHECK(cudaGetLastError());
+            }
+#endif
+
+            src0_row.data = src0_original + i02*nb02;

+            GGML_ASSERT(nb11 == sizeof(float)*ne10);
+            GGML_ASSERT(nb1 == sizeof(float)*ne0);
+
+            src1_row.ne[1] = num_src1_rows;
             src1_row.nb[1] = nb11;
             src1_row.nb[2] = num_src1_rows*nb11;
             src1_row.nb[3] = num_src1_rows*nb11;

+            dst_row.ne[1] = num_src1_rows;
             dst_row.nb[1] = nb1;
             dst_row.nb[2] = num_src1_rows*nb1;
             dst_row.nb[3] = num_src1_rows*nb1;

             ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);

+#ifndef MMID_MEMCPY
+            {
+                dim3 block_dims(std::min((uint)nb1, 1024u));
+                dim3 grid_dims(num_src1_rows);
+                k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                    dst_original, dst_contiguous.get(),
+                    dev_row_mapping.get(),
+                    num_src1_rows, nb1, nb2);
+                CUDA_CHECK(cudaGetLastError());
+            }
+#endif
+
+#ifdef MMID_MEMCPY
             num_src1_rows = 0;
-            for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+                for (int64_t id = 0; id < n_ids; id++) {
+                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);

-                if (row_id_i != row_id) {
-                    continue;
-                }
+                    if (row_id_i != i02) {
+                        continue;
+                    }

-                GGML_ASSERT(row_id >= 0 && row_id < n_as);
+                    GGML_ASSERT(i02 >= 0 && i02 < n_as);

-                CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous.get() + num_src1_rows*nb1,
-                    nb1, cudaMemcpyDeviceToDevice, stream));
-                num_src1_rows++;
+                    const int64_t i1 = id;
+                    const int64_t i2 = iid1;
+
+                    CUDA_CHECK(cudaMemcpyAsync(dst_original + i1*nb1 + i2*nb2,
+                        dst_contiguous.get() + num_src1_rows*nb1,
+                        nb1, cudaMemcpyDeviceToDevice, stream));
+                    num_src1_rows++;
+                }
             }
+#endif
         }
     }
 }
@@ -2487,7 +2620,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
 GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
     const int min_batch_size = 32;

-    return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
+    return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
+           (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);

     GGML_UNUSED(backend);
 }
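The last hunk widens the offload heuristic: besides the existing batch check on `ne[1]`, a MUL_MAT_ID node is now also offloaded when its third dimension is at least `min_batch_size`, which with this layout is where the per-token batch of the expert matmul shows up (compare the `ne12 == 1` single-token path above). A hedged restatement as a standalone predicate, with a stripped-down stand-in for `ggml_tensor`:

```cpp
#include <cstdint>

// Illustrative subset of ops; the real code uses the ggml_op enum.
enum fake_op { OP_GET_ROWS, OP_MUL_MAT, OP_MUL_MAT_ID };

// Stripped-down stand-in for ggml_tensor: only what the check reads.
struct fake_tensor {
    fake_op op;
    int64_t ne[4];
};

// Mirrors the updated ggml_backend_cuda_offload_op logic: offloading pays off
// once the batch dimension is large enough; for MUL_MAT_ID the batch is ne[2].
static bool should_offload(const fake_tensor & t) {
    const int64_t min_batch_size = 32;
    return (t.ne[1] >= min_batch_size && t.op != OP_GET_ROWS) ||
           (t.ne[2] >= min_batch_size && t.op == OP_MUL_MAT_ID);
}

// e.g. a MUL_MAT_ID node over 64 tokens with 2 experts used per token:
// should_offload({OP_MUL_MAT_ID, {4096, 2, 64, 1}}) == true
```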