@@ -214,6 +214,11 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
 #endif
 
+struct ggml_tensor_extra_gpu {
+    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
+    cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
+};
+
 static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
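The new struct hangs one device buffer and one synchronization event per GPU off a tensor via `tensor->extra`. As a rough illustration only (not part of the patch), a per-tensor extra like this would typically be zero-initialized so that unused slots stay `nullptr` and can be skipped when freeing; the helper name below is made up and the value of `GGML_CUDA_MAX_DEVICES` is assumed:

```cpp
// Hypothetical sketch: allocating a zero-initialized per-tensor extra.
#include <cuda_runtime.h>
#include <cstring>

#define GGML_CUDA_MAX_DEVICES 16 // assumed value for this sketch

struct ggml_tensor_extra_gpu {
    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
    cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
};

static ggml_tensor_extra_gpu * alloc_extra_sketch() {
    ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
    memset(extra, 0, sizeof(*extra)); // unused buffer/event slots stay nullptr
    return extra;
}
```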
@@ -1995,7 +2000,6 @@ inline void ggml_cuda_op_add(
     } else {
         GGML_ASSERT(false);
     }
-    CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
@@ -2027,7 +2031,6 @@ inline void ggml_cuda_op_mul(
 
         // compute
         mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
-        CUDA_CHECK(cudaGetLastError());
     }
 
     (void) dst;
@@ -2048,7 +2051,6 @@ inline void ggml_cuda_op_silu(
 
     // compute
     silu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
@@ -2071,7 +2073,6 @@ inline void ggml_cuda_op_rms_norm(
 
     // compute
     rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
@@ -2150,7 +2151,6 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
             GGML_ASSERT(false);
             break;
     }
-    CUDA_CHECK(cudaGetLastError());
 
 #ifdef GGML_CUDA_DMMV_F16
     if (src1_convert_f16) {
@@ -2230,7 +2230,6 @@ inline void ggml_cuda_op_rope(
 
     // compute
     rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) dst;
     (void) src0_ddq_i;
@@ -2254,7 +2253,6 @@ inline void ggml_cuda_op_diag_mask_inf(
 
     // compute
     diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) dst;
     (void) src0_ddq_i;
@@ -2276,7 +2274,6 @@ inline void ggml_cuda_op_soft_max(
 
     // compute
     soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
@@ -2372,10 +2369,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
     size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
 
-    // if multiple GPUs are used they need to wait for the main GPU to finish
+    // if multiple devices are used they need to wait for the main device
+    // here an event is recorded that signifies that the main device has finished calculating the input data
     if (split && g_device_count > 1) {
         CUDA_CHECK(cudaSetDevice(g_main_device));
-        CUDA_CHECK(cudaDeviceSynchronize());
+        CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device], g_cudaStreams_main[g_main_device]));
     }
 
     for (int id = 0; id < g_device_count; ++id) {
@@ -2401,6 +2399,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
         int64_t row_diff = row_high - row_low;
 
         cudaSetDevice(id);
+        cudaStream_t cudaStream_main = g_cudaStreams_main[id];
+
+        // wait for main GPU data if necessary
+        if (split && id != g_main_device) {
+            CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, src0_extra->events[g_main_device]));
+        }
 
         if (src0_on_device && src0_is_contiguous) {
             if (src0_is_f32) {
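Together with the hunk above, this implements a record/wait handshake: the main device records an event on its stream once the input data is in place, and every other device's stream waits on that event before any work queued behind the wait can start. A minimal, self-contained sketch of this pattern follows; the variable names (`n_devices`, `main_device`, `streams`, `src_ready`) are placeholders and error checking is omitted for brevity:

```cpp
// Sketch of the producer/consumer event handshake used by the patch.
#include <cuda_runtime.h>

void wait_for_main_device_sketch(int n_devices, int main_device,
                                 cudaStream_t * streams, cudaEvent_t * src_ready) {
    // main device: mark the point in its stream where the input data is ready
    cudaSetDevice(main_device);
    cudaEventRecord(src_ready[main_device], streams[main_device]);

    // other devices: make their streams wait for that point before running
    // anything submitted after the wait
    for (int id = 0; id < n_devices; ++id) {
        if (id == main_device) {
            continue;
        }
        cudaSetDevice(id);
        cudaStreamWaitEvent(streams[id], src_ready[main_device], 0);
    }
}
```

Unlike the previous `cudaDeviceSynchronize()`, this keeps the host thread free and only orders work between the GPU streams involved.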
@@ -2476,8 +2480,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                 }
                 const int64_t i11 = i13*ne12 + i12;
 
-                cudaStream_t cudaStream_main = g_cudaStreams_main[id];
-
                 // for split tensors the data begins at i0 == i0_offset_low
                 char  * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
                 float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
@@ -2537,6 +2539,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
 
                 // do the computation
                 op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
+                CUDA_CHECK(cudaGetLastError());
 
                 // copy dst to host or other device if necessary
                 if (!dst_on_device) {
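This hunk is where the per-op `CUDA_CHECK(cudaGetLastError())` calls removed earlier are consolidated: the launch-error check now happens once, right after the `op()` callback queues its kernels. As a reminder of what that check does, here is a small generic sketch (the kernel and helper names are placeholders, not code from the patch); `cudaGetLastError()` reports launch failures such as bad launch configurations immediately, without waiting for the kernel to finish:

```cpp
// Sketch of the "check right after an async launch" pattern.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void scale_f32(float * x, float v, int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i < k) {
        x[i] *= v;
    }
}

void launch_and_check(float * d_x, int k, cudaStream_t stream) {
    scale_f32<<<(k + 255)/256, 256, 0, stream>>>(d_x, 2.0f, k);
    const cudaError_t err = cudaGetLastError(); // catches launch errors, does not synchronize
    if (err != cudaSuccess) {
        fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err));
    }
}
```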
@@ -2566,6 +2569,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                     CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), kind, cudaStream_main));
                 }
 
+                // signify to main device that other device is done
+                if (split && g_device_count > 1 && id != g_main_device) {
+                    CUDA_CHECK(cudaEventRecord(src0_extra->events[id], cudaStream_main));
+                }
             }
         }
     }
@@ -2577,7 +2585,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
         }
 
         CUDA_CHECK(cudaSetDevice(id));
-        CUDA_CHECK(cudaDeviceSynchronize());
 
         if (src0_asq[id] > 0) {
             ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
@@ -2592,6 +2599,21 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
             ggml_cuda_pool_free(dst_ddf[id], dst_asf[id]);
         }
     }
+
+    // main device waits for all other devices to be finished
+    if (split && g_device_count > 1) {
+        CUDA_CHECK(cudaSetDevice(g_main_device));
+        for (int id = 0; id < g_device_count; ++id) {
+            if (id != g_main_device) {
+                CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams_main[g_main_device], src0_extra->events[id]));
+            }
+        }
+    }
+
+    if (dst->backend == GGML_BACKEND_CPU) {
+        CUDA_CHECK(cudaSetDevice(g_main_device));
+        CUDA_CHECK(cudaDeviceSynchronize());
+    }
 }
 
 void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
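The tail end of `ggml_cuda_op` now mirrors the handshake: each worker device records a "done" event (previous hunk), the main device's stream waits on all of them, and the host only blocks with `cudaDeviceSynchronize()` when the result has to be visible to the CPU next. A condensed sketch of that ordering, with placeholder names (`n_devices`, `streams`, `done`, `result_needed_on_host`):

```cpp
// Sketch of the tail-end synchronization: stream-level waits stay on the GPU,
// the host only blocks if it needs to read the output.
#include <cuda_runtime.h>

void finish_split_op_sketch(int n_devices, int main_device,
                            cudaStream_t * streams, cudaEvent_t * done,
                            bool result_needed_on_host) {
    // make the main stream wait until every other device has recorded "done"
    cudaSetDevice(main_device);
    for (int id = 0; id < n_devices; ++id) {
        if (id != main_device) {
            cudaStreamWaitEvent(streams[main_device], done[id], 0);
        }
    }

    // block the host only if the caller reads the output on the CPU
    if (result_needed_on_host) {
        cudaDeviceSynchronize();
    }
}
```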
@@ -2831,6 +2853,10 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
         cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
 
         extra->data_device[id] = buf;
+
+        if (backend == GGML_BACKEND_GPU_SPLIT) {
+            CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id], cudaEventDisableTiming));
+        }
     }
 
     tensor->extra = extra;
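The events are created with `cudaEventDisableTiming` because they are used purely for ordering; skipping the timestamp makes `cudaEventRecord`/`cudaStreamWaitEvent` cheaper. A minimal standalone sketch of creating such a synchronization-only event (the helper name is made up):

```cpp
// Sketch: creating an event intended only for stream synchronization.
#include <cuda_runtime.h>

cudaEvent_t make_sync_event() {
    cudaEvent_t ev = nullptr;
    cudaEventCreateWithFlags(&ev, cudaEventDisableTiming); // no timing data recorded
    return ev;
}
```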
@@ -2844,12 +2870,15 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
 
     for (int id = 0; id < g_device_count; ++id) {
-        if (extra->data_device[id] == nullptr) {
-            continue;
+        if (extra->data_device[id] != nullptr) {
+            CUDA_CHECK(cudaSetDevice(id));
+            CUDA_CHECK(cudaFree(extra->data_device[id]));
         }
 
-        CUDA_CHECK(cudaSetDevice(id));
-        CUDA_CHECK(cudaFree(extra->data_device[id]));
+        if (extra->events[id] != nullptr) {
+            CUDA_CHECK(cudaSetDevice(id));
+            CUDA_CHECK(cudaEventDestroy(extra->events[id]));
+        }
     }
 
     delete extra;
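Cleanup now checks buffers and events independently instead of `continue`-ing past a device, since events only exist for split tensors while a non-split tensor only has a buffer on the main device. A small sketch of the same per-device release pattern with placeholder parameters:

```cpp
// Sketch: releasing per-device buffers and events independently.
#include <cuda_runtime.h>

void free_extra_sketch(int n_devices, void ** data_device, cudaEvent_t * events) {
    for (int id = 0; id < n_devices; ++id) {
        if (data_device[id] != nullptr) {
            cudaSetDevice(id);
            cudaFree(data_device[id]);
        }
        if (events[id] != nullptr) {
            cudaSetDevice(id);
            cudaEventDestroy(events[id]);
        }
    }
}
```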