diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 50d7763dcdd..033c667706a 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1554,6 +1554,12 @@ struct ggml_cuda_pdl_config { }; #endif //defined(GGML_CUDA_USE_PDL) +// PDL and __restrict__ need to be mutually exclusive, see https://github.com/ggml-org/llama.cpp/pull/24030 +# if (defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER) +# define GGML_CUDA_RESTRICT +# else +# define GGML_CUDA_RESTRICT __restrict__ +# endif // defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER template static __inline__ void ggml_cuda_kernel_launch(Kernel kernel, const ggml_cuda_kernel_launch_params & launch_params, Args&&... args) { diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index d650b5fbd0f..48636adc64f 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -678,8 +678,8 @@ static __global__ void flash_attn_mask_to_KV_max( template // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup_uniform( - float * __restrict__ dst, - const float2 * __restrict__ dst_fixup, + float * dst_ptr, + const float2 * dst_fixup_ptr, const int ne01, const int ne02, const int ne12, const int nblocks_stream_k, const int gqa_ratio, @@ -689,6 +689,8 @@ static __global__ void flash_attn_stream_k_fixup_uniform( const uint3 fd_iter_j) { constexpr int ncols = ncols1*ncols2; ggml_cuda_pdl_lc(); + float * GGML_CUDA_RESTRICT dst = dst_ptr; + const float2 * GGML_CUDA_RESTRICT dst_fixup = dst_fixup_ptr; const int tile_idx = blockIdx.x; // One block per output tile. const int j = blockIdx.y; @@ -760,8 +762,8 @@ static __global__ void flash_attn_stream_k_fixup_uniform( template // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup_general( - float * __restrict__ dst, - const float2 * __restrict__ dst_fixup, + float * dst_ptr, + const float2 * dst_fixup_ptr, const int ne01, const int ne02, const int gqa_ratio, const int total_work, @@ -769,6 +771,8 @@ static __global__ void flash_attn_stream_k_fixup_general( const uint3 fd_iter_k_j_z, const uint3 fd_iter_k_j, const uint3 fd_iter_k) { + float * GGML_CUDA_RESTRICT dst = dst_ptr; + const float2 * GGML_CUDA_RESTRICT dst_fixup = dst_fixup_ptr; constexpr int ncols = ncols1*ncols2; const int bidx0 = blockIdx.x; @@ -867,11 +871,14 @@ static __global__ void flash_attn_stream_k_fixup_general( template // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_combine_results( - const float * __restrict__ VKQ_parts, - const float2 * __restrict__ VKQ_meta, - float * __restrict__ dst, + const float * VKQ_parts_ptr, + const float2 * VKQ_meta_ptr, + float * dst_ptr, const int parallel_blocks) { ggml_cuda_pdl_lc(); + const float * GGML_CUDA_RESTRICT VKQ_parts = VKQ_parts_ptr; + const float2 * GGML_CUDA_RESTRICT VKQ_meta = VKQ_meta_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; // Dimension 0: threadIdx.x // Dimension 1: blockIdx.x // Dimension 2: blockIdx.y @@ -1153,8 +1160,8 @@ void launch_fattn( GGML_ASSERT(block_dim.x % warp_size == 0); - // disabled PDL enrollment for now due to a compiler bug. - fattn_kernel<<>>( + ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num, block_dim, nbytes_shared, main_stream); + ggml_cuda_kernel_launch(fattn_kernel, launch_params, (const char *) Q->data, K_data, V_data, diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 3c8b6eaaf24..4ce3fe84552 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1703,14 +1703,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( template __launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2)) static __global__ void flash_attn_ext_f16( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, + const char * Q_ptr, + const char * K_ptr, + const char * V_ptr, + const char * mask_ptr, + const char * sinks_ptr, + const int * KV_max_ptr, + float * dst_ptr, + float2 * dst_meta_ptr, const float scale, const float max_bias, const float m0, @@ -1726,6 +1726,14 @@ static __global__ void flash_attn_ext_f16( const int32_t nb31, const int32_t nb32, const int64_t nb33) { ggml_cuda_pdl_sync(); // TODO optimize placement #if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)) + const char * GGML_CUDA_RESTRICT Q = Q_ptr; + const char * GGML_CUDA_RESTRICT K = K_ptr; + const char * GGML_CUDA_RESTRICT V = V_ptr; + const char * GGML_CUDA_RESTRICT mask = mask_ptr; + const char * GGML_CUDA_RESTRICT sinks = sinks_ptr; + const int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; + float2 * GGML_CUDA_RESTRICT dst_meta = dst_meta_ptr; // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(DKQ == 128 || DKQ == 256 || DKQ == 512)) { @@ -1871,7 +1879,7 @@ static __global__ void flash_attn_ext_f16( (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop); #else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + GGML_UNUSED_VARS(Q_ptr, K_ptr, V_ptr, mask_ptr, sinks_ptr, KV_max_ptr, dst_ptr, dst_meta_ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, ne00, ne01, ne02, ne03, nb01, nb02, nb03, diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index fac76f13593..0a099810e14 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -788,14 +788,14 @@ static __device__ __forceinline__ void flash_attn_tile_iter( template // D == head size __launch_bounds__(ggml_cuda_fattn_tile_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_tile_get_occupancy(DKQ, DV, ncols1*ncols2)) static __global__ void flash_attn_tile( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, + const char * Q_ptr, + const char * K_ptr, + const char * V_ptr, + const char * mask_ptr, + const char * sinks_ptr, + const int * KV_max_ptr, + float * dst_ptr, + float2 * dst_meta_ptr, const float scale, const float max_bias, const float m0, @@ -810,6 +810,14 @@ static __global__ void flash_attn_tile( const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { #ifdef FLASH_ATTN_AVAILABLE + const char * GGML_CUDA_RESTRICT Q = Q_ptr; + const char * GGML_CUDA_RESTRICT K = K_ptr; + const char * GGML_CUDA_RESTRICT V = V_ptr; + const char * GGML_CUDA_RESTRICT mask = mask_ptr; + const char * GGML_CUDA_RESTRICT sinks = sinks_ptr; + const int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; + float2 * GGML_CUDA_RESTRICT dst_meta = dst_meta_ptr; // Skip unused kernel variants for faster compilation: @@ -1126,7 +1134,7 @@ static __global__ void flash_attn_tile( } } #else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + GGML_UNUSED_VARS(Q_ptr, K_ptr, V_ptr, mask_ptr, sinks_ptr, KV_max_ptr, dst_ptr, dst_meta_ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, ne00, ne01, ne02, ne03, nb01, nb02, nb03, diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index b0a6cf67f1a..69dd9368624 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -19,14 +19,14 @@ static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() { template // D == head size __launch_bounds__(ggml_cuda_fattn_vec_get_nthreads_device(), 1) static __global__ void flash_attn_ext_vec( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, + const char * Q_ptr, + const char * K_ptr, + const char * V_ptr, + const char * mask_ptr, + const char * sinks_ptr, + const int * KV_max_ptr, + float * dst_ptr, + float2 * dst_meta_ptr, const float scale, const float max_bias, const float m0, @@ -42,6 +42,14 @@ static __global__ void flash_attn_ext_vec( const int32_t nb31, const int32_t nb32, const int64_t nb33) { ggml_cuda_pdl_lc(); #ifdef FLASH_ATTN_AVAILABLE + const char * GGML_CUDA_RESTRICT Q = Q_ptr; + const char * GGML_CUDA_RESTRICT K = K_ptr; + const char * GGML_CUDA_RESTRICT V = V_ptr; + const char * GGML_CUDA_RESTRICT mask = mask_ptr; + const char * GGML_CUDA_RESTRICT sinks = sinks_ptr; + const int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; + float2 * GGML_CUDA_RESTRICT dst_meta = dst_meta_ptr; // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { @@ -506,7 +514,7 @@ static __global__ void flash_attn_ext_vec( dst_meta[((sequence*int(ne01.z) + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]); } #else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + GGML_UNUSED_VARS(Q_ptr, K_ptr, V_ptr, mask_ptr, sinks_ptr, KV_max_ptr, dst_ptr, dst_meta_ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, ne00, ne01, ne02, ne03, nb01, nb02, nb03, diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 4b6f6501094..6850716fc0d 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -24,14 +24,14 @@ namespace wmma = rocwmma; template __launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1) static __global__ void flash_attn_ext_f16( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, + const char * Q_ptr, + const char * K_ptr, + const char * V_ptr, + const char * mask_ptr, + const char * sinks_ptr, + const int * KV_max_ptr, + float * dst_ptr, + float2 * dst_meta_ptr, const float scale, const float max_bias, const float m0, @@ -46,6 +46,14 @@ static __global__ void flash_attn_ext_f16( const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { #if defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)) + const char * GGML_CUDA_RESTRICT Q = Q_ptr; + const char * GGML_CUDA_RESTRICT K = K_ptr; + const char * GGML_CUDA_RESTRICT V = V_ptr; + const char * GGML_CUDA_RESTRICT mask = mask_ptr; + const char * GGML_CUDA_RESTRICT sinks = sinks_ptr; + const int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; + float2 * GGML_CUDA_RESTRICT dst_meta = dst_meta_ptr; // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -494,7 +502,7 @@ static __global__ void flash_attn_ext_f16( dst_meta[j_dst_unrolled] = dst_meta_val; } #else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + GGML_UNUSED_VARS(Q_ptr, K_ptr, V_ptr, mask_ptr, sinks_ptr, KV_max_ptr, dst_ptr, dst_meta_ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, ne00, ne01, ne02, ne03, nb01, nb02, nb03, diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 457b695eb2a..eb157b8baf2 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -42,7 +42,7 @@ static __global__ void k_get_rows( template static __global__ void k_get_rows_float( - const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, + const src0_t * src0_ptr, const int32_t * src1_ptr, dst_t * dst_ptr, const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ /*const int64_t ne10,*/ const int64_t ne11, const uint3 ne12_fdv, /*const int64_t ne13,*/ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, @@ -50,6 +50,9 @@ static __global__ void k_get_rows_float( const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { ggml_cuda_pdl_lc(); + const src0_t * GGML_CUDA_RESTRICT src0 = src0_ptr; + const int32_t * GGML_CUDA_RESTRICT src1 = src1_ptr; + dst_t * GGML_CUDA_RESTRICT dst = dst_ptr; ggml_cuda_pdl_sync(); for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) { for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) { diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu index 09d95f309b4..44b1031663a 100644 --- a/ggml/src/ggml-cuda/mmvf.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -6,11 +6,15 @@ template static __global__ void mul_mat_vec_f( - const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, + const T * x_ptr, const float * y_ptr, const int32_t * ids_ptr, const ggml_cuda_mm_fusion_args_device fusion, float * dst_ptr, const int ncols2, const uint3 nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, const int ids_stride) { + const T * GGML_CUDA_RESTRICT x = x_ptr; + const float * GGML_CUDA_RESTRICT y = y_ptr; + const int32_t * GGML_CUDA_RESTRICT ids = ids_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; const int row = blockIdx.x; // for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens) const int channel_dst = blockIdx.y; diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index ecb6fdedadd..8ce8de85ecf 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -476,12 +476,16 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int template __launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( - const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, + const void * vx_ptr, const void * vy_ptr, const int32_t * ids_ptr, const ggml_cuda_mm_fusion_args_device fusion, float * dst_ptr, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst, const uint32_t ids_stride) { + const void * GGML_CUDA_RESTRICT vx = vx_ptr; + const void * GGML_CUDA_RESTRICT vy = vy_ptr; + const int32_t * GGML_CUDA_RESTRICT ids = ids_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; constexpr int qk = ggml_cuda_type_traits::qk; constexpr int qi = ggml_cuda_type_traits::qi; diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index 49516965cad..39a500a1704 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -3,10 +3,12 @@ __launch_bounds__(CUDA_QUANTIZE_BLOCK_SIZE, 1) static __global__ void quantize_q8_1( - const float * __restrict__ x, void * __restrict__ vy, + const float * x_ptr, void * vy_ptr, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, const int64_t ne0, const uint32_t ne1, const uint3 ne2) { ggml_cuda_pdl_lc(); + const float * GGML_CUDA_RESTRICT x = x_ptr; + void * GGML_CUDA_RESTRICT vy = vy_ptr; const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i0 >= ne0) { diff --git a/ggml/src/ggml-cuda/reduce_rows.cuh b/ggml/src/ggml-cuda/reduce_rows.cuh index 5895d3bf8e5..968c47aa20a 100644 --- a/ggml/src/ggml-cuda/reduce_rows.cuh +++ b/ggml/src/ggml-cuda/reduce_rows.cuh @@ -2,7 +2,9 @@ // Row reduction kernel template - compute sum (norm=false) or mean (norm=true) template -static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) { +static __global__ void reduce_rows_f32(const float * x_ptr, float * dst_ptr, const int ncols) { + const float * GGML_CUDA_RESTRICT x = x_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; const int row = blockIdx.x; const int col = threadIdx.x; diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index e14f96b824c..3b4f004c946 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -111,9 +111,9 @@ static void set_rows_cuda_quant( } template -static __global__ void k_set_rows(const src_t * __restrict__ src0, - const idx_t * __restrict__ src1, - dst_t * __restrict__ dst, +static __global__ void k_set_rows(const src_t * src0_ptr, + const idx_t * src1_ptr, + dst_t * dst_ptr, const int64_t ne_total, const int64_t ne10, const int64_t ne11, @@ -133,6 +133,9 @@ static __global__ void k_set_rows(const src_t * __restrict__ src0, const uint3 ne02, const uint3 ne11_fd, const uint3 ne12_fd) { + const src_t * GGML_CUDA_RESTRICT src0 = src0_ptr; + const idx_t * GGML_CUDA_RESTRICT src1 = src1_ptr; + dst_t * GGML_CUDA_RESTRICT dst = dst_ptr; const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x; if (i >= ne_total) { diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index 48787b4b890..1463169cf78 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -3,12 +3,16 @@ #include "unary.cuh" template -static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1, - const float * __restrict__ bias, +static __global__ void ssm_conv_f32(const float * src0_ptr, const float * src1_ptr, + const float * bias_ptr, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, - float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, + float * dst_ptr, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t n_t) { ggml_cuda_pdl_lc(); + const float * GGML_CUDA_RESTRICT src0 = src0_ptr; + const float * GGML_CUDA_RESTRICT src1 = src1_ptr; + const float * GGML_CUDA_RESTRICT bias = bias_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; GGML_UNUSED(src0_nb0); const int tid = threadIdx.x; const int bidx = blockIdx.x; diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu index 412980376ac..2e3f97c7284 100644 --- a/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -17,14 +17,22 @@ using namespace cub; #endif // __clang__ template __global__ void __launch_bounds__(splitD, 1) - ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2, - const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5, - const int32_t * __restrict__ src6, float * __restrict__ dst, + ssm_scan_f32(const float * src0_ptr, const float * src1_ptr, const float * src2_ptr, + const float * src3_ptr, const float * src4_ptr, const float * src5_ptr, + const int32_t * src6_ptr, float * dst_ptr, const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1, const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, const int64_t s_off, const int64_t d_inner, const int64_t L_param) { + const float * GGML_CUDA_RESTRICT src0 = src0_ptr; + const float * GGML_CUDA_RESTRICT src1 = src1_ptr; + const float * GGML_CUDA_RESTRICT src2 = src2_ptr; + const float * GGML_CUDA_RESTRICT src3 = src3_ptr; + const float * GGML_CUDA_RESTRICT src4 = src4_ptr; + const float * GGML_CUDA_RESTRICT src5 = src5_ptr; + const int32_t * GGML_CUDA_RESTRICT src6 = src6_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; const size_t L = L_template == 0 ? L_param : L_template; ggml_cuda_pdl_sync(); const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2); @@ -118,13 +126,21 @@ __global__ void __launch_bounds__(splitD, 1) template __global__ void __launch_bounds__(d_state, 1) ssm_scan_f32_group( - const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, - const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5, - const int32_t * __restrict__ src6, float * __restrict__ dst, + const float * src0_ptr, const float * src1_ptr, const float * src2_ptr, + const float * src3_ptr, const float * src4_ptr, const float * src5_ptr, + const int32_t * src6_ptr, float * dst_ptr, const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1, const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) { + const float * GGML_CUDA_RESTRICT src0 = src0_ptr; + const float * GGML_CUDA_RESTRICT src1 = src1_ptr; + const float * GGML_CUDA_RESTRICT src2 = src2_ptr; + const float * GGML_CUDA_RESTRICT src3 = src3_ptr; + const float * GGML_CUDA_RESTRICT src4 = src4_ptr; + const float * GGML_CUDA_RESTRICT src5 = src5_ptr; + const int32_t * GGML_CUDA_RESTRICT src6 = src6_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; const int warp = threadIdx.x / WARP_SIZE; const int lane = threadIdx.x % WARP_SIZE;