 #include "ggml.h"
 #include "ggml-backend-impl.h"

+#define CC_PASCAL     600
 #define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 #define CC_VOLTA      700
 #define CC_OFFSET_AMD 1000000
@@ -585,6 +586,14 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
     return a;
 }

+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a += __shfl_xor_sync(0xffffffff, a, mask, 32);
+    }
+    return a;
+}
+
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
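The new `half2` overload follows the same XOR-butterfly pattern as the existing `float`/`float2` reductions: five `__shfl_xor_sync` steps (masks 16, 8, 4, 2, 1) leave every lane of the warp holding the full sum, with no shared memory traffic and no `__syncthreads()`. A minimal standalone sketch of the pattern, not part of the patch and written with plain `float` for clarity:

#include <cstdio>

__global__ void butterfly_demo() {
    float v = (float) threadIdx.x;              // lane i starts with the value i
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) { // partners sit 16, 8, 4, 2, 1 lanes apart
        v += __shfl_xor_sync(0xffffffff, v, mask, 32);
    }
    if (threadIdx.x == 0) {
        printf("warp sum = %.0f\n", v);         // 0 + 1 + ... + 31 = 496
    }
}

int main() {
    butterfly_demo<<<1, 32>>>();
    cudaDeviceSynchronize();
    return 0;
}

Because the shuffles exchange registers directly between lanes, this per-warp stage costs the softmax kernels below essentially nothing; shared memory is only needed to combine results across warps.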
@@ -593,6 +602,19 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
     return x;
 }

+static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    (void) x;
+    bad_arch();
+#else
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+}
+
 static __device__ __forceinline__ float op_repeat(const float a, const float b) {
     return b;
     GGML_UNUSED(a);
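`warp_reduce_max(half2)` runs the same butterfly with `__hmax2`, which takes the element-wise maximum of the two packed halves, so one reduction processes two values per lane; callers collapse the pair afterwards with `__hmax`. On the HIP/AMD path the function routes to `bad_arch()`, since this patch does not use the packed-half intrinsics there. A standalone sketch, illustration only, assuming an FP16-capable CUDA target (e.g. `nvcc -arch=sm_70`) and a toolkit recent enough to provide `__hmax2`:

#include <cstdio>
#include <cuda_fp16.h>

__global__ void half2_max_demo() {
    // lane i holds the pair (i, 63 - i); the butterfly leaves every lane with (31, 63)
    half2 v = __floats2half2_rn((float) threadIdx.x, 63.0f - (float) threadIdx.x);
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        v = __hmax2(v, __shfl_xor_sync(0xffffffff, v, mask, 32));
    }
    if (threadIdx.x == 0) {
        // collapse the surviving pair to a scalar, as the softmax kernel does
        printf("max = %.0f\n", __half2float(__hmax(__low2half(v), __high2half(v)))); // prints 63
    }
}

int main() {
    half2_max_demo<<<1, 32>>>();
    cudaDeviceSynchronize();
    return 0;
}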
@@ -5201,75 +5223,227 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
 }

-static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
+template <int ncols_template, int block_size_template, bool need_check>
+static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+    const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
+    const int ncols_smem = GGML_PAD(ncols_data/2, WARP_SIZE);
+
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+
+    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
+
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+
+    extern __shared__ half2 data_soft_max_f16[];
+    half2 * vals   = data_soft_max_f16 + 0;                     // shared memory buffer to cache values between iterations
+    half  * buf_iw = (half *) (data_soft_max_f16 + ncols_smem); // shared memory buffer for inter-warp communication
+
+    half2 max_val = make_half2(-INFINITY, -INFINITY);
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
+        const int col_smem = col0 + tid;
+        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
+
+        if (ncols_template == 0 && col_smem >= ncols_smem) {
+            break;
+        }
+
+        const int ix = rowx*ncols_data + col_data;
+        const int iy = rowy*ncols_data + col_data;
+
+        half2 val;
+        val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f);
+        if (need_check && col_data + WARP_SIZE >= ncols_data) {
+            val.y = -INFINITY;
+        } else {
+            val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
+        }
+        vals[col_smem] = val;
+        max_val = __hmax2(max_val, val);
+    }
+
+    // find the max value in the block
+    max_val = warp_reduce_max(max_val);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf_iw[lane_id] = -INFINITY;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf_iw[warp_id] = __hmax(max_val.x, max_val.y);
+        }
+        __syncthreads();
+
+        max_val = __half2half2(buf_iw[lane_id]);
+        max_val = warp_reduce_max(max_val);
+    } else {
+        max_val = __half2half2(__hmax(max_val.x, max_val.y));
+    }
+
+    half2 tmp = make_half2(0.0f, 0.0f); // partial sums
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
+        const int col_smem = col0 + tid;
+
+        if (ncols_template == 0 && col_smem >= ncols_smem) {
+            break;
+        }
+
+        const half2 val = h2exp(vals[col_smem] - max_val);
+
+        tmp += val;
+        vals[col_smem] = val;
+    }
+
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf_iw[lane_id] = 0.0f;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf_iw[warp_id] = tmp.x + tmp.y;
+        }
+        __syncthreads();
+
+        tmp = __half2half2(buf_iw[lane_id]);
+        tmp = warp_reduce_sum(tmp);
+    } else {
+        tmp = __half2half2(tmp.x + tmp.y);
+    }
+
+    const half2 inv_sum = make_half2(1.0f, 1.0f) / tmp;
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
+        const int col_smem = col0 + tid;
+        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
+
+        if (ncols_template == 0 && col_smem >= ncols_smem) {
+            return;
+        }
+
+        const int idst = rowx*ncols_data + col_data;
+        const half2 result = vals[col_smem] * inv_sum;
+        dst[idst] = result.x;
+
+        if (need_check && col_data + WARP_SIZE >= ncols_data) {
+            return;
+        }
+
+        dst[idst + WARP_SIZE] = result.y;
+    }
+#else
+    (void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale;
+    bad_arch();
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+}
+
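The index arithmetic above is the heart of the f16 kernel: shared-memory slot `col_smem` caches the column pair `(col_data, col_data + WARP_SIZE)` as a single `half2`, so on each iteration a warp covers two contiguous 32-column stripes, both global loads stay coalesced, and the shared-memory footprint is half that of the f32 kernel. A host-side sketch, illustration only, that prints the mapping for a toy two-warp block:

#include <cstdio>

int main() {
    // toy config: 2 warps per block, 64 half2 slots covering 128 data columns
    const int WARP_SIZE = 32, block_size = 64, ncols_smem = 64;
    for (int tid = 0; tid < block_size; ++tid) {
        const int warp_id = tid / WARP_SIZE, lane_id = tid % WARP_SIZE;
        for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
            const int col_smem = col0 + tid;                            // half2 slot in shared memory
            const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id; // first float column of the pair
            printf("tid %2d -> vals[%2d] = {x[%3d], x[%3d]}\n", tid, col_smem, col_data, col_data + WARP_SIZE);
        }
    }
    return 0;
}

Running it shows warp 0 covering columns 0-63 and warp 1 covering 64-127; `need_check` guards the `.y` element where a row width that is not a multiple of 2*WARP_SIZE would run past the end of the data.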
+template <int ncols_template, int block_size_template>
+static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
+    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+
     const int tid  = threadIdx.x;
     const int rowx = blockIdx.x;
     const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension

-    const int block_size = blockDim.x;
+    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;

     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;

-    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
+    extern __shared__ float data_soft_max_f32[];
+    float * vals   = data_soft_max_f32 + 0;     // shared memory buffer to cache values between iterations
+    float * buf_iw = data_soft_max_f32 + ncols; // shared memory buffer for inter-warp communication

     float max_val = -INFINITY;

-    for (int col = tid; col < ncols; col += block_size) {
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
-        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
+
+        const float val = x[ix]*scale + (y ? y[iy] : 0.0f);
+        vals[col] = val;
+        max_val = max(max_val, val);
     }

     // find the max value in the block
     max_val = warp_reduce_max(max_val);
     if (block_size > WARP_SIZE) {
         if (warp_id == 0) {
-            buf[lane_id] = -INFINITY;
+            buf_iw[lane_id] = -INFINITY;
         }
         __syncthreads();

         if (lane_id == 0) {
-            buf[warp_id] = max_val;
+            buf_iw[warp_id] = max_val;
         }
         __syncthreads();

-        max_val = buf[lane_id];
+        max_val = buf_iw[lane_id];
         max_val = warp_reduce_max(max_val);
     }

-    float tmp = 0.f;
+    float tmp = 0.0f; // partial sum

-    for (int col = tid; col < ncols; col += block_size) {
-        const int ix = rowx*ncols + col;
-        const int iy = rowy*ncols + col;
-        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
+        const float val = expf(vals[col] - max_val);
         tmp += val;
-        dst[ix] = val;
+        vals[col] = val;
     }

     // find the sum of exps in the block
     tmp = warp_reduce_sum(tmp);
     if (block_size > WARP_SIZE) {
         if (warp_id == 0) {
-            buf[lane_id] = 0.f;
+            buf_iw[lane_id] = 0.0f;
         }
         __syncthreads();

         if (lane_id == 0) {
-            buf[warp_id] = tmp;
+            buf_iw[warp_id] = tmp;
         }
         __syncthreads();

-        tmp = buf[lane_id];
+        tmp = buf_iw[lane_id];
         tmp = warp_reduce_sum(tmp);
     }

-    const float inv_tmp = 1.f / tmp;
+    const float inv_sum = 1.0f / tmp;

-    for (int col = tid; col < ncols; col += block_size) {
-        const int i = rowx*ncols + col;
-        dst[i] *= inv_tmp;
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            return;
+        }
+
+        const int idst = rowx*ncols + col;
+        dst[idst] = vals[col] * inv_sum;
     }
 }

@@ -6543,12 +6717,80 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }

+static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+    int nth = WARP_SIZE;
+    while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth,     1, 1);
+    const dim3 block_nums(nrows_x, 1, 1);
+    const int64_t shmem = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half); // vals is padded to whole half2 warps, plus WARP_SIZE halves for buf_iw
+    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+    switch (ncols_x) {
+        case 32:
+            soft_max_f16<32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+            break;
+        case 64:
+            soft_max_f16<64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+            break;
+        case 128:
+            soft_max_f16<128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+            break;
+        case 256:
+            soft_max_f16<256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+            break;
+        case 512:
+            soft_max_f16<512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+            break;
+        case 1024:
+            soft_max_f16<1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+            break;
+        case 2048:
+            soft_max_f16<2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+            break;
+        case 4096:
+            soft_max_f16<4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+            break;
+        default:
+            soft_max_f16<0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+            break;
+    }
+}
+
 static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
     int nth = WARP_SIZE;
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth,     1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
-    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+    const int64_t shmem = (ncols_x + WARP_SIZE)*sizeof(float);
+    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+    switch (ncols_x) {
+        case 32:
+            soft_max_f32<32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+            break;
+        case 64:
+            soft_max_f32<64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+            break;
+        case 128:
+            soft_max_f32<128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+            break;
+        case 256:
+            soft_max_f32<256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+            break;
+        case 512:
+            soft_max_f32<512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+            break;
+        case 1024:
+            soft_max_f32<1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+            break;
+        case 2048:
+            soft_max_f32<2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+            break;
+        case 4096:
+            soft_max_f32<4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+            break;
+        default:
+            soft_max_f32<0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+            break;
+    }
 }

 static void im2col_f32_f16_cuda(const float * x, half* dst,
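Both launchers dispatch on `ncols_x` so the common power-of-two row widths get compile-time `ncols_template`/`block_size_template` values, letting `#pragma unroll` fully unroll the kernel loops and drop the runtime bounds check; anything else falls through to the `<0, 0>` variant. Note that the f16 path searches block sizes against `ncols_x/2`, since each of its threads owns a `half2` pair. A quick host-side sketch, illustration only, of what the f16 launcher would pick (with `GGML_PAD` rounding up to a multiple, as in `ggml.h`):

#include <cstdio>
#include <initializer_list>

#define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n)) // round x up to a multiple of n

int main() {
    const int WARP_SIZE = 32, CUDA_SOFT_MAX_BLOCK_SIZE = 1024;
    for (int ncols : {32, 100, 512, 4096, 16384}) {
        int nth = WARP_SIZE;
        while (nth < ncols/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;      // same search as soft_max_f16_cuda
        const long shmem = (GGML_PAD(ncols, 2*WARP_SIZE) + WARP_SIZE)*2L;      // *2: sizeof(half)
        printf("ncols %5d -> %4d threads/block, %6ld bytes shared memory\n", ncols, nth, shmem);
    }
    return 0;
}

For the f32 path the sizing is exact: `ncols_x` floats for `vals` plus `WARP_SIZE` floats for `buf_iw`.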
@@ -7873,7 +8115,21 @@ static void ggml_cuda_op_soft_max(
     float scale = 1.0f;
     memcpy(&scale, dst->op_params, sizeof(float));

-    soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    const bool use_f16_soft_max = false;
+#else
+#ifdef GGML_CUDA_F16
+    const bool use_f16_soft_max = true;
+#else
+    const bool use_f16_soft_max = g_device_caps[g_main_device].cc >= CC_VOLTA;
+#endif // GGML_CUDA_F16
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+
+    if (use_f16_soft_max) {
+        soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+    } else {
+        soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+    }

     (void) dst;
 }
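The f16 kernel is compiled from `CC_PASCAL` (600) up, but by default it is only selected from Volta (`CC_VOLTA`, 700) onward, where `half2` arithmetic is fast; `GGML_CUDA_F16` forces it on, and HIP/AMD builds always take the f32 path. The selection restated as a plain host function, illustration only:

#include <cstdio>
#include <initializer_list>

static bool use_f16_soft_max(bool hip_amd, bool cuda_f16_forced, int cc) {
    if (hip_amd)         return false; // no f16 kernel on the ROCm path (bad_arch)
    if (cuda_f16_forced) return true;  // GGML_CUDA_F16 build option
    return cc >= 700;                  // CC_VOLTA
}

int main() {
    for (int cc : {610, 700, 750, 860}) {
        printf("cc %d -> %s softmax\n", cc, use_f16_soft_max(false, false, cc) ? "f16" : "f32");
    }
    return 0;
}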