@@ -13,6 +13,29 @@ __device__ float __forceinline__ t2f32<half>(half val) {
     return __half2float(val);
 }
 
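+// Launch-time parameters for the soft_max kernels; ne*/nb* follow the usual ggml
+// element-count / byte-stride naming.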
+struct soft_max_params {
+
+    int64_t nheads;
+    uint32_t n_head_log2;
+    int64_t ncols;
+    int64_t nrows_x;
+    int64_t nrows_y;
+    int64_t ne00;
+    int64_t ne01;
+    int64_t ne02;
+    int64_t ne03;
+    int64_t nb11;
+    int64_t nb12;
+    int64_t nb13;
+
+    int64_t ne12;
+    int64_t ne13;
+    float scale;
+    float max_bias;
+    float m0;
+    float m1;
+};
+
 // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
 // As we want to keep pragma unroll for all other cases we suppress the clang transformation warning here.
 #ifdef __clang__
@@ -21,24 +44,32 @@ __device__ float __forceinline__ t2f32<half>(half val) {
 #endif // __clang__
 template <bool use_shared, int ncols_template, int block_size_template, typename T>
 static __global__ void soft_max_f32(
-        const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
-        const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
-    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+        const float * x, const T * mask, float * dst, const soft_max_params p) {
+    const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
 
     const int tid  = threadIdx.x;
-    const int rowx = blockIdx.x;
-    const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
+
+    const int64_t i03 = blockIdx.z;
+    const int64_t i02 = blockIdx.y;
+    const int64_t i01 = blockIdx.x;
+
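+    // x and dst are treated as contiguous here, so the flattened 3D block index doubles as the row index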
+    // TODO: noncontiguous inputs/outputs
+    const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
+
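+    // the % wraps i12/i13 so a mask with ne12 == 1 or ne13 == 1 is broadcast across that dimension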
+    const int64_t i11 = i01;
+    const int64_t i12 = i02 % p.ne12;
+    const int64_t i13 = i03 % p.ne13;
 
     x    += int64_t(rowx)*ncols;
-    mask += int64_t(rowy)*ncols * (mask != nullptr);
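+    // nb11/nb12/nb13 are byte strides, hence the division by sizeof(T) to obtain an element offset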
+    mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
     dst  += int64_t(rowx)*ncols;
 
     const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
 
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;
 
-    const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);
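+    // i02 indexes the head dimension (src0->ne[2]), which is what the ALiBi slope depends on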
+    const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
 
     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
@@ -55,7 +86,7 @@ static __global__ void soft_max_f32(
             break;
         }
 
-        const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
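+        // apply the softmax scale and the optional slope-weighted mask before the max/sum reduction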
+        const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
 
         vals[col] = val;
         max_val = max(max_val, val);
@@ -151,63 +182,60 @@ static __global__ void soft_max_back_f32(
 }
 
 template<typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
     int nth = WARP_SIZE;
+    const int64_t ncols_x = params.ncols;
+
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth,     1, 1);
-    const dim3 block_nums(nrows_x, 1, 1);
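+    // one block per row: the grid spans src0 rows (ne01), heads (ne02) and batches (ne03)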
+    const dim3 block_nums(params.ne01, params.ne02, params.ne03);
     const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
 
-    const uint32_t n_head      = nrows_x/nrows_y;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     // FIXME: this limit could be raised by ~2-4x on Ampere or newer
     if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
         switch (ncols_x) {
             case 32:
                 soft_max_f32<true,   32,   32><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 64:
                 soft_max_f32<true,   64,   64><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 128:
                 soft_max_f32<true,  128,  128><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 256:
                 soft_max_f32<true,  256,  256><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 512:
                 soft_max_f32<true,  512,  512><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 1024:
                 soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 2048:
                 soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 4096:
                 soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             default:
                 soft_max_f32<true,    0,    0><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
         }
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
     }
 }
 
@@ -235,10 +263,11 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
 
-    const int64_t ne00    = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
     const int64_t nrows_y = src0->ne[1];
 
+    const int64_t ne00 = src0->ne[0];
+
     float scale    = 1.0f;
     float max_bias = 0.0f;
 
@@ -247,10 +276,44 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
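+    // mask shape/strides; the fallback value 1 is only a placeholder, the kernel zeroes the mask offset when there is no mask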
+    const int64_t nb11 = src1 ? src1->nb[1] : 1;
+    const int64_t nb12 = src1 ? src1->nb[2] : 1;
+    const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+    const int64_t ne12 = src1 ? src1->ne[2] : 1;
+    const int64_t ne13 = src1 ? src1->ne[3] : 1;
+
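+    // ALiBi slope bases; when max_bias == 0.0f both m0 and m1 evaluate to 1.0f and no positional bias is applied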
+    const uint32_t n_head      = src0->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+
+    soft_max_params params = {};
+    params.nheads = src0->ne[2];
+    params.n_head_log2 = n_head_log2;
+    params.ncols = ne00;
+    params.nrows_x = nrows_x;
+    params.nrows_y = nrows_y;
+    params.ne00 = src0->ne[0];
+    params.ne01 = src0->ne[1];
+    params.ne02 = src0->ne[2];
+    params.ne03 = src0->ne[3];
+    params.nb11 = nb11;
+    params.nb12 = nb12;
+    params.nb13 = nb13;
+    params.ne12 = ne12;
+    params.ne13 = ne13;
+    params.scale = scale;
+    params.max_bias = max_bias;
+    params.m0 = m0;
+    params.m1 = m1;
+
     if (use_f16) {
-        soft_max_f32_cuda(src0_d, (const half  *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+        soft_max_f32_cuda(src0_d, (const half  *) src1_d, dst_d, params, stream);
     } else {
-        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream);
     }
 }
 