 #include "conv2d.cuh"
 #include "convert.cuh"
 
-#include <mma.h>
-using namespace nvcuda;
+#ifdef FP16_MMA_AVAILABLE
+#    if !defined(GGML_USE_HIP)
+#        include <mma.h>
+#        ifdef GGML_USE_MUSA
+namespace wmma = mtmusa::wmma;
+#        else // GGML_USE_MUSA
+namespace wmma = nvcuda::wmma;
+#        endif // GGML_USE_MUSA
+#    elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
+#        include <rocwmma/rocwmma.hpp>
+namespace wmma = rocwmma;
+#    endif // !defined(GGML_USE_HIP)
+#endif // FP16_MMA_AVAILABLE
 
 struct conv_params {
     const int64_t IW, IH;
@@ -111,6 +122,8 @@ class float_mma {
     __device__ __forceinline__ float * store_result() const { return buf; }
 };
 
+#if (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || defined(FP16_MMA_AVAILABLE))
+
 class half_mma {
   private:
     wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc;
@@ -136,6 +149,42 @@ class half_mma {
     }
 };
 
+#else
+
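+// Fallback when WMMA is unavailable: accumulate the WMMA_M x WMMA_N tile with scalar FMAs in a caller-provided float scratch buffer.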
+class half_mma {
+  public:
+    float * buf;
+
+    __device__ __forceinline__ half_mma(float * scratch) {
+        buf = scratch;
+        const int lane_id = threadIdx.x % warpSize;
+#    pragma unroll
+        for (int i = lane_id; i < WMMA_M * WMMA_N; i += warpSize) {
+            buf[i] = 0.0f;
+        }
+    }
+
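+    // Scalar stand-in for wmma::mma_sync: each lane accumulates a strided subset of the M x N output tile.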
+    __device__ __forceinline__ void mma(const half * A_sh, const half * B_sh, const int strideA, const int strideB) {
+        const int lane_id = threadIdx.x % warpSize;
+#    pragma unroll
+        for (int e = lane_id; e < (WMMA_M * WMMA_N); e += warpSize) {
+            int   m   = e / WMMA_N;
+            int   n   = e % WMMA_N;
+            float sum = buf[m * WMMA_N + n];
+#    pragma unroll
+            for (int k = 0; k < WMMA_K; k++) {
+                const half a = A_sh[m * strideA + k];
+                const half b = B_sh[k * strideB + n];
+                sum = fmaf(__half2float(a), __half2float(b), sum);
+            }
+            buf[m * WMMA_N + n] = sum;
+        }
+    }
+
+    __device__ __forceinline__ float * store_result() const { return buf; }
+};
+#endif // (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || defined(FP16_MMA_AVAILABLE))
+
 template <typename T, typename layout, typename mma>
 static __global__ void conv2d_kernel(const float * IN, const T * IK, float * OUT, const conv_params P) {
     extern __shared__ unsigned char smem_raw[];