#include "conv2d.cuh"
#include "convert.cuh"

- #include <mma.h>
- using namespace nvcuda;
-
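+ // Select the platform-specific wmma implementation: plain CUDA uses nvcuda::wmma,
+ // MUSA builds map to mtmusa::wmma, and HIP builds use rocwmma.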
+ #ifdef FP16_MMA_AVAILABLE
+ #    if !defined(GGML_USE_HIP)
+ #        include <mma.h>
+ #        ifdef GGML_USE_MUSA
+ namespace wmma = mtmusa::wmma;
+ #        else
+ namespace wmma = nvcuda::wmma;
+ #        endif
+ #    else
+ #        include <rocwmma/rocwmma.hpp>
+ namespace wmma = rocwmma;
+ #    endif
+ #endif
struct conv_params {
    const int64_t IW, IH;
    const int64_t OW, OH;
@@ -111,6 +121,8 @@ class float_mma {
    __device__ __forceinline__ float * store_result() const { return buf; }
};

+ #if (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || defined(FP16_MMA_AVAILABLE))
+
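+ // Hardware path: half_mma accumulates the tile in wmma fragments (tensor cores).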
class half_mma {
  private:
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc;
@@ -136,6 +148,42 @@ class half_mma {
    }
};

+ #else
+
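+ // Fallback path for targets without usable FP16 WMMA: emulate the same
+ // WMMA_M x WMMA_N x WMMA_K tile product with per-lane scalar FMAs in scratch memory.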
+ class half_mma {
+   public:
+     float * buf;
+
+     __device__ __forceinline__ half_mma(float * scratch) {
+         buf = scratch;
+         const int lane_id = threadIdx.x % warpSize;
+ #pragma unroll
+         for (int i = lane_id; i < WMMA_M * WMMA_N; i += warpSize) {
+             buf[i] = 0.0f;
+         }
+     }
+
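+     // Each lane owns a strided subset of the M x N tile; A_sh is read row-major
+     // with leading dimension strideA, B_sh row-major with leading dimension strideB.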
+     __device__ __forceinline__ void mma(const half * A_sh, const half * B_sh, const int strideA, const int strideB) {
+         const int lane_id = threadIdx.x % warpSize;
+ #pragma unroll
+         for (int e = lane_id; e < (WMMA_M * WMMA_N); e += warpSize) {
+             const int m = e / WMMA_N;
+             const int n = e % WMMA_N;
+             float sum = buf[m * WMMA_N + n];
+ #pragma unroll
+             for (int k = 0; k < WMMA_K; k++) {
+                 const half a = A_sh[m * strideA + k];
+                 const half b = B_sh[k * strideB + n];
+                 sum = fmaf(__half2float(a), __half2float(b), sum);
+             }
+             buf[m * WMMA_N + n] = sum;
+         }
+     }
+
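+     // The accumulated tile already lives in the caller-provided scratch buffer.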
+     __device__ __forceinline__ float * store_result() const { return buf; }
+ };
+ #endif // (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || defined(FP16_MMA_AVAILABLE))
+
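+ // Usage sketch (illustrative, not part of this change; K_total, scratch, and the
+ // shared-memory tiles are assumed names): a warp computes one output tile roughly as
+ //     half_mma acc(scratch);
+ //     for (int k0 = 0; k0 < K_total; k0 += WMMA_K) {
+ //         acc.mma(A_sh + k0, B_sh + k0 * strideB, strideA, strideB);
+ //     }
+ //     float * tile = acc.store_result();
+ // and the same call sequence works for both the wmma and the fallback class.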
template <typename T, typename layout, typename mma>
static __global__ void conv2d_kernel(const float * IN, const T * IK, float * OUT, const conv_params P) {
    extern __shared__ unsigned char smem_raw[];