From 1c0009c5fe845774b24c8d3c3abb0b8babfa4d88 Mon Sep 17 00:00:00 2001 From: Hashem Hashemi Date: Fri, 7 Feb 2025 00:35:27 +0000 Subject: [PATCH 1/8] Optimization for quantized gemm skinny sizes --- csrc/rocm/custom.cu | 18 + csrc/rocm/custom_kernels.cu | 454 ++++++++++++++++-- csrc/rocm/ops.h | 4 + csrc/rocm/torch_bindings.cpp | 5 + vllm/_custom_ops.py | 5 + .../model_executor/layers/quantization/fp8.py | 1 + .../layers/quantization/utils/w8a8_utils.py | 27 +- vllm/model_executor/layers/tuned_gemm.py | 10 + 8 files changed, 477 insertions(+), 47 deletions(-) diff --git a/csrc/rocm/custom.cu b/csrc/rocm/custom.cu index fae1b4fbfbe3..5b6d0063676f 100644 --- a/csrc/rocm/custom.cu +++ b/csrc/rocm/custom.cu @@ -48,6 +48,24 @@ void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, at::cuda::getCurrentCUDAStream(), CuCount); } +void wvSpltKQ_(void* in_a, void* in_b, void* out_c, + void* scale_a, void* scale_b, + const int M, const int K, const int Kp, + const int N, cudaStream_t stream, const int CuCount); + +void wvSpltKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + at::Tensor& scale_a, at::Tensor& scale_b, + const int64_t N_in, const int64_t CuCount) { + auto M = in_a.size(0); + auto K = in_a.size(1); + auto Kp = in_a.stride(0); + int N = N_in; + wvSpltKQ_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), + scale_a.data_ptr(), scale_b.data_ptr(), + M, K, Kp, N, + at::cuda::getCurrentCUDAStream(), CuCount); +} + void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, cudaStream_t stream, const int solidx); diff --git a/csrc/rocm/custom_kernels.cu b/csrc/rocm/custom_kernels.cu index ba90b3f75a07..0a56ab211f73 100644 --- a/csrc/rocm/custom_kernels.cu +++ b/csrc/rocm/custom_kernels.cu @@ -327,7 +327,7 @@ void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, #define DTYPE half -__device__ __forceinline__ int mindiv(int N, int div1, int div2) { +/*__device__ __forceinline__ int mindiv(int N, int div1, int div2) { int nPrRnd = div1 * div2; int rnds0 = N / nPrRnd; nPrRnd -= div1 * 3; @@ -354,14 +354,323 @@ __device__ __forceinline__ int mindiv(int N, int div1, int div2) { if (rnds0 == rnds8) rtn = div2 - 8; if (rnds0 == rnds9) rtn = div2 - 9; return rtn; +}*/ + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSpltKQ_hf_sml_(const int K, const int Kp, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const float* __restrict__ s_A, + const float* __restrict__ s_B, + const int _WvPrGrp, const int CuCount) { + using half8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + using intx2 = __attribute__((__vector_size__(2*sizeof(int)))) int; + using intx4 = __attribute__((__vector_size__(4*sizeof(int)))) int; + union bigType { + char f8[A_CHUNK*2]; + char2 c2[A_CHUNK]; + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + int i[A_CHUNK / 2]; + long l[A_CHUNK/ 4]; + intx4 l2[A_CHUNK/ 8]; + half8 h8; + }; + + __shared__ half s[1024 * 32]; + + for (uint32_t k = 0; k < min(K/2 * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + if (k_in >= min(K/2 * M, 32 * 1024)) break; + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + } + __syncthreads(); + + if (threadIdx.y >= _WvPrGrp) return; + + uint32_t n = (blockIdx.x * _WvPrGrp + (threadIdx.y%_WvPrGrp)) * YTILE; + + using floatx16 = __attribute__((__vector_size__(16*sizeof(float)))) float; + floatx16 sum[M][YTILE]; + float sA = *s_A; + float sB = *s_B; + + while (n < N) { + for (int i = 0; i < YTILE; i++) + for (int m = 0; m < M; m++) sum[m][i] = {0}; + + bigType bigA[M][UNRL]; + bigType bigB0[UNRL]; + bigType bigB1[UNRL]; + bigType bigB2[UNRL]; + bigType bigB3[UNRL]; + bigType bigB4[UNRL]; + bigType bigB5[UNRL]; + bigType bigB6[UNRL]; + bigType bigB7[UNRL]; + + // Fetch the weight matrix from memory! + for (uint32_t k1 = 0; k1 < K/2; k1 += THRDS * A_CHUNK * UNRL) { + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K/2) break; + + const half* B_ = &B[(n + 0) * (Kp/2) + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * Kp/2]))); + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * Kp/2]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * Kp/2]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * Kp/2]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * Kp/2]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * Kp/2]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * Kp/2]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * Kp/2]))); + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K/2) break; + for (int m = 0; m < M; m++) { + // if (k_ + K * m < 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[k_ + K/2 * m]))); + // else + // bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); + } + } + + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K/2) break; + float aV[A_CHUNK*2]; + + for (uint32_t m = 0; m < M; m++) { + for (int i=0; i= 2) + sum[m][1] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(bigA[m][k2].l[i/8], bigB1[k2].l[i/8], sum[m][1], 0, 0, 0); + } + } + } + } + + // Final reduction + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + float accm0 = sum[m][y][0]; + float accm16 = sum[m][y][8]; + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][1]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][9]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][2]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][10]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][3]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][11]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][4]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][12]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][5]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][13]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][6]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][14]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][7]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][15]), "v"(accm16)); + accm0 += __shfl(accm0, 36); + accm16 += __shfl(accm16,52); + sum[m][y][0] = accm0 + __shfl(accm16,16); + } + } + + if (threadIdx.x == 0) { + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + C[n + y + m * N] = __float2half(sum[m][y][0]*sA*sB); + } + } + } + + n += CuCount * _WvPrGrp * YTILE; + } +} +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSpltKQ_hf_sml_(const int K, const int Kp, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const float* __restrict__ s_A, + const float* __restrict__ s_B, + const int _WvPrGrp, const int CuCount) { + UNREACHABLE_CODE } +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSpltKQ_hf_(const int K, const int Kp, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const float* __restrict__ s_A, + const float* __restrict__ s_B, + const int _WvPrGrp, const int CuCount) { + using half8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + using intx2 = __attribute__((__vector_size__(2*sizeof(int)))) int; + using intx4 = __attribute__((__vector_size__(4*sizeof(int)))) int; + union bigType { + char f8[A_CHUNK*2]; + char2 c2[A_CHUNK]; + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + int i[A_CHUNK / 2]; + long l[A_CHUNK/ 4]; + intx4 l2[A_CHUNK/ 8]; + half8 h8; + }; + + __shared__ half s[1024 * 32]; + + for (uint32_t k = 0; k < min(K/2 * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + if (k_in >= min(K/2 * M, 32 * 1024)) break; + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + } + __syncthreads(); + + if (threadIdx.y >= _WvPrGrp) return; + + uint32_t n = (blockIdx.x * _WvPrGrp + (threadIdx.y%_WvPrGrp)) * YTILE; + + using floatx16 = __attribute__((__vector_size__(16*sizeof(float)))) float; + floatx16 sum[M][YTILE]; + float sA = *s_A; + float sB = *s_B; + + while (n < N) { + for (int i = 0; i < YTILE; i++) + for (int m = 0; m < M; m++) sum[m][i] = {0}; + + bigType bigA[M][UNRL]; + bigType bigB0[UNRL]; + bigType bigB1[UNRL]; + bigType bigB2[UNRL]; + bigType bigB3[UNRL]; + bigType bigB4[UNRL]; + bigType bigB5[UNRL]; + bigType bigB6[UNRL]; + bigType bigB7[UNRL]; + + // Fetch the weight matrix from memory! + for (uint32_t k1 = 0; k1 < K/2; k1 += THRDS * A_CHUNK * UNRL) { + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K/2) break; + + const half* B_ = &B[(n + 0) * (Kp/2) + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * Kp/2]))); + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * Kp/2]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * Kp/2]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * Kp/2]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * Kp/2]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * Kp/2]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * Kp/2]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * Kp/2]))); + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K/2) break; + for (int m = 0; m < M; m++) { + if (k_ + K/2 * m < 64 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[k_ + K/2 * m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_ + K/2 * m]))); + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K/2) break; + float aV[A_CHUNK*2]; + + for (uint32_t m = 0; m < M; m++) { + for (int i=0; i= 2) + sum[m][1] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(bigA[m][k2].l[i/8], bigB1[k2].l[i/8], sum[m][1], 0, 0, 0); + } + } + } + } + + // Final reduction + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + float accm0 = sum[m][y][0]; + float accm16 = sum[m][y][8]; + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][1]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][9]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][2]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][10]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][3]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][11]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][4]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][12]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][5]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][13]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][6]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][14]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][7]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][15]), "v"(accm16)); + accm0 += __shfl(accm0, 36); + accm16 += __shfl(accm16,52); + sum[m][y][0] = accm0 + __shfl(accm16,16); + } + } + + if (threadIdx.x == 0) { + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + C[n + y + m * N] = __float2half(sum[m][y][0]*sA*sB); + } + } + } + + n += CuCount * _WvPrGrp * YTILE; + } +} +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSpltKQ_hf_(const int K, const int Kp, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const float* __restrict__ s_A, + const float* __restrict__ s_B, + const int _WvPrGrp, const int CuCount) { + UNREACHABLE_CODE + } +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support // This version targets cases where A[] fits LDS capacity template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, const int CuCount) { using half8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; union bigType { @@ -381,34 +690,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- __shared__ half s[1024 * 32]; - //---------------------------------------------------- - // Computation of columns that need to be committed to memory! - //---------------------------------------------------- - // uint32_t commitColumn[YTILE]; - // for (uint32_t i = 0; i < YTILE; i++) { - // commitColumn[i] = 1; - //} - - // It's worth trying to load-balance... - int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); - - //---------------------------------------------------- - // Indexing function into the column of weight matrix B - // Algorithm does 64 lane k-splitting / wave and uses - // WG ID and Thread ID to find the index. - //---------------------------------------------------- - uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; - - // Check whether there will be fragmenation! - // This will happen only for the last wave! - // if (n < N && (n + YTILE) >= N) { - // uint32_t startColumn = N - YTILE; - // for (uint32_t i = 0; i < (n - startColumn); i++) { - // commitColumn[i] = 0; - // } - // n = startColumn; - //} - //---------------------------------------------------- // Fetch the activation matrix to LDS // Loop iteration: @@ -434,7 +715,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } __syncthreads(); + //int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); if (threadIdx.y >= _WvPrGrp) return; + + uint32_t n = (blockIdx.x * _WvPrGrp + (threadIdx.y%_WvPrGrp)) * YTILE; float sum[M][YTILE]; @@ -490,6 +774,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // // TODO: Logic below will only work when K is multiple of 8 //---------------------------------------------------- + //for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { // Fetch the weight matrix from memory! #pragma unroll @@ -632,6 +917,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) template __global__ void wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, + const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } @@ -642,7 +928,7 @@ __global__ void wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltK_hf_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, const int CuCount) { using half8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; union bigType { @@ -670,14 +956,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) commitColumn[i] = 1; } - // It's worth trying to load-balance... - int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); //---------------------------------------------------- // Indexing function into the column of weight matrix B // Algorithm does 64 lane k-splitting / wave and uses // WG ID and Thread ID to find the index. //---------------------------------------------------- + //int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; // Check whether there will be fragmenation! @@ -713,6 +998,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; } + __syncthreads(); if (threadIdx.y >= _WvPrGrp) return; @@ -915,7 +1201,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) template __global__ void wvSpltK_hf_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, - const int CuCount) { + const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support @@ -925,7 +1211,7 @@ __global__ void wvSpltK_hf_(const int K, const int N, const DTYPE* B, template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, const int CuCount) { using half8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; @@ -954,8 +1240,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) commitColumn[i] = 1; } - // It's worth trying to load-balance... - int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + //int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); if (threadIdx.y >= _WvPrGrp) return; //---------------------------------------------------- @@ -1252,11 +1537,40 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) template __global__ void wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, - const int CuCount) { + const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support +int mindiv(int N, int div1, int div2) { + int nPrRnd = div1 * div2; + int rnds0 = N / nPrRnd; + nPrRnd -= div1 * 3; + int rnds3 = N / nPrRnd; + nPrRnd -= div1; + int rnds4 = N / nPrRnd; + nPrRnd -= div1; + int rnds5 = N / nPrRnd; + nPrRnd -= div1; + int rnds6 = N / nPrRnd; + nPrRnd -= div1; + int rnds7 = N / nPrRnd; + nPrRnd -= div1; + int rnds8 = N / nPrRnd; + nPrRnd -= div1; + int rnds9 = N / nPrRnd; + nPrRnd -= div1; + int rtn = div2; + if (rnds0 == rnds3) rtn = div2 - 3; + if (rnds0 == rnds4) rtn = div2 - 4; + if (rnds0 == rnds5) rtn = div2 - 5; + if (rnds0 == rnds6) rtn = div2 - 6; + if (rnds0 == rnds7) rtn = div2 - 7; + if (rnds0 == rnds8) rtn = div2 - 8; + if (rnds0 == rnds9) rtn = div2 - 9; + return rtn; +} + void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, const int K_in, const int N_in, cudaStream_t stream, const int CuCount = 0) { @@ -1269,17 +1583,18 @@ void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, _N) \ { \ dim3 block(64, _WvPrGrp); \ - /*wvSpltK_hf:*/ \ if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ wvSpltK_hf_sml_<64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ - <<>>(K_in, M_in, af4, bf4, c, CuCount); \ + <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, CuCount); \ } else if (K_in * N_in <= 32 * 1024 * 1.2) { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ wvSpltK_hf_<64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ - <<>>(K_in, M_in, af4, bf4, c, CuCount); \ + <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, CuCount); \ } else { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \ wvSpltK_hf_big_<64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \ - <<>>(K_in, M_in, af4, bf4, c, CuCount); \ + <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, CuCount); \ } \ } @@ -1306,4 +1621,55 @@ void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, if (cudaSuccess != err) { throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); } -} \ No newline at end of file +} + +void wvSpltKQ_(void* in_a, void* in_b, void* out_c, + void* scale_a, void* scale_b, + const int M_in, const int K_in, const int Kp_in, const int N_in, cudaStream_t stream, + const int CuCount = 0) { + dim3 grid(CuCount); + half* af4 = reinterpret_cast(in_a); + const half* bf4 = reinterpret_cast(in_b); + auto* c = reinterpret_cast(out_c); + auto* s_a = reinterpret_cast(scale_a); + auto* s_b = reinterpret_cast(scale_b); + +#define WVSPLTKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb,\ + _N) \ + { \ + dim3 block(64, _WvPrGrp); \ + if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ + wvSpltKQ_hf_sml_<64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \ + <<>>(K_in, Kp_in, M_in, af4, bf4, c, s_a, s_b, __wvPrGrp, CuCount); \ + } else { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ + wvSpltKQ_hf_<64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ + <<>>(K_in, Kp_in, M_in, af4, bf4, c, s_a, s_b, __wvPrGrp, CuCount); \ + } \ + } + + switch (N_in) { + case 1: + WVSPLTKQ(16, 2, 2, 2, 2, 2, 2, 1) // MI308 + break; + case 2: + WVSPLTKQ(16, 2, 2, 2, 2, 2, 2, 2) // MI308 + break; + case 3: + WVSPLTKQ(16, 4, 7, 7, 1, 1, 1, 3) // MI308 + break; + case 4: + WVSPLTKQ(16, 4, 7, 7, 1, 1, 1, 4) // MI308 + break; + default: + throw std::runtime_error("Unsupported N value: " + std::to_string(M_in) + + "," + std::to_string(K_in) + "," + + std::to_string(N_in)); + } + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) { + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); + } +} diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index 59bd28e3bc12..b20d404e39d1 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -11,6 +11,10 @@ void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, const int64_t N_in, const int64_t CuCount); +void wvSpltKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + at::Tensor& scale_a, at::Tensor& scale_b, + const int64_t N_in, const int64_t CuCount); + void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index 50640a96725e..2c04ac1b89c0 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -43,6 +43,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { "wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in," " int CuCount) -> ()"); rocm_ops.impl("wvSpltK", torch::kCUDA, &wvSpltK); + rocm_ops.def( + "wvSpltKQ(Tensor in_a, Tensor in_b, Tensor! out_c, Tensor scale_a, Tensor scale_b," + " int N_in," + " int CuCount) -> ()"); + rocm_ops.impl("wvSpltKQ", torch::kCUDA, &wvSpltKQ); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 844824a54172..15b7aed1d78f 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1142,3 +1142,8 @@ def LLMM_Silu(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, def wvSpltK(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, N: int, cu_count: int) -> None: torch.ops._rocm_C.wvSpltK(a, b, out, N, cu_count) + +def wvSpltKQ(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, + scale_a: torch.Tensor, scale_b: torch.Tensor, N: int, + cu_count: int) -> None: + torch.ops._rocm_C.wvSpltKQ(a, b, out, scale_a, scale_b, N, cu_count) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 3ea954bf9144..da0f34ee62c7 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -363,6 +363,7 @@ def process_weights_after_loading(self, layer: Module) -> None: # Update layer with new values. layer.weight = Parameter(weight.t(), requires_grad=False) + #layer.weight = Parameter(weight, requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False) if self.quant_config.activation_scheme == "static": layer.input_scale = Parameter(layer.input_scale.max(), diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 99575a3922cc..1f4084e263fd 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -6,6 +6,7 @@ from vllm import _custom_ops as ops from vllm.platforms import current_platform +from vllm.model_executor.layers.tuned_gemm import tgemm # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale @@ -123,7 +124,6 @@ def apply_fp8_linear( if out_dtype is None: out_dtype = input.dtype - # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A if cutlass_fp8_supported: qinput, x_scale = ops.scaled_fp8_quant( @@ -161,15 +161,36 @@ def apply_fp8_linear( if per_tensor_weights and per_tensor_activations: # Fused GEMM_DQ - output = torch._scaled_mm(qinput, + output2 = None + n = qinput.shape[0] + if n == 1 : + weightT = weight.t() + output = tgemm.scaled_mm(qinput, + weightT, + out_dtype=out_dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias) + else : + output = torch._scaled_mm(qinput, weight, out_dtype=out_dtype, scale_a=x_scale, scale_b=weight_scale, bias=bias) + + ''' + atol = 1e-2; rtol = 1e-5 + if output2!=None : + if not torch.allclose(output,output2,atol=atol,rtol=rtol) : + print("MISSMATCH") + print("orig", output) + print("opt", output2) + ''' # A fix for discrepancy in scaled_mm which returns tuple # for torch < 2.5 and a single value in torch >= 2.5 - if type(output) is tuple and len(output) == 2: + if (weight.shape[0] == 1) : + if type(output) is tuple and len(output) == 2: output = output[0] return torch.narrow(output, 0, 0, diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index 8fb44cdc96c2..d9ea11979503 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -5,6 +5,7 @@ import pandas as pd import torch import torch.nn.functional as F +import traceback from vllm import _custom_ops as ops from vllm.envs import VLLM_USE_ROCM_SKINNY_GEMM @@ -91,6 +92,15 @@ def apply_skinny(self, m, n, k, inp_view, weights): else: return None + def scaled_mm(self, inp, weight, out_dtype=None, scale_a=None, scale_b=None, bias=None): + out = torch.empty(inp.shape[0], + weight.shape[0], + dtype=out_dtype, + device='cuda') + n = inp.shape[0] + ops.wvSpltKQ(weight, inp, out, scale_a, scale_b, n, self.cu_count) + return out + def mm(self, inp, weights, bias=None): if not support_tuned_gemms: return F.linear(inp, weights, bias) From a40b3d15cfbb62e3954af7209ea8380357dda81e Mon Sep 17 00:00:00 2001 From: Hashem Hashemi Date: Fri, 7 Feb 2025 00:41:42 +0000 Subject: [PATCH 2/8] lint fix --- csrc/rocm/custom_kernels.cu | 409 +++++++++++++++++++++--------------- 1 file changed, 236 insertions(+), 173 deletions(-) diff --git a/csrc/rocm/custom_kernels.cu b/csrc/rocm/custom_kernels.cu index 0a56ab211f73..7ca43c4044b5 100644 --- a/csrc/rocm/custom_kernels.cu +++ b/csrc/rocm/custom_kernels.cu @@ -360,40 +360,40 @@ void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltKQ_hf_sml_(const int K, const int Kp, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const float* __restrict__ s_A, - const float* __restrict__ s_B, - const int _WvPrGrp, const int CuCount) { + const DTYPE* __restrict__ A, DTYPE* C, + const float* __restrict__ s_A, + const float* __restrict__ s_B, const int _WvPrGrp, + const int CuCount) { using half8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; - using intx2 = __attribute__((__vector_size__(2*sizeof(int)))) int; - using intx4 = __attribute__((__vector_size__(4*sizeof(int)))) int; + using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int; + using intx4 = __attribute__((__vector_size__(4 * sizeof(int)))) int; union bigType { - char f8[A_CHUNK*2]; + char f8[A_CHUNK * 2]; char2 c2[A_CHUNK]; DTYPE h[A_CHUNK]; float f[A_CHUNK / 2]; int i[A_CHUNK / 2]; - long l[A_CHUNK/ 4]; - intx4 l2[A_CHUNK/ 8]; + long l[A_CHUNK / 4]; + intx4 l2[A_CHUNK / 8]; half8 h8; }; __shared__ half s[1024 * 32]; - for (uint32_t k = 0; k < min(K/2 * M, 32 * 1024); + for (uint32_t k = 0; k < min(K / 2 * M, 32 * 1024); k += THRDS * WvPrGrp * A_CHUNK) { uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - if (k_in >= min(K/2 * M, 32 * 1024)) break; + if (k_in >= min(K / 2 * M, 32 * 1024)) break; *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); } __syncthreads(); if (threadIdx.y >= _WvPrGrp) return; - uint32_t n = (blockIdx.x * _WvPrGrp + (threadIdx.y%_WvPrGrp)) * YTILE; + uint32_t n = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; - using floatx16 = __attribute__((__vector_size__(16*sizeof(float)))) float; + using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float; floatx16 sum[M][YTILE]; float sA = *s_A; float sB = *s_B; @@ -413,86 +413,115 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) bigType bigB7[UNRL]; // Fetch the weight matrix from memory! - for (uint32_t k1 = 0; k1 < K/2; k1 += THRDS * A_CHUNK * UNRL) { - #pragma unroll + for (uint32_t k1 = 0; k1 < K / 2; k1 += THRDS * A_CHUNK * UNRL) { + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K/2) break; - - const half* B_ = &B[(n + 0) * (Kp/2) + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * Kp/2]))); - if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * Kp/2]))); - if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * Kp/2]))); - if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * Kp/2]))); - if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * Kp/2]))); - if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * Kp/2]))); - if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * Kp/2]))); - if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * Kp/2]))); + if (k_ >= K / 2) break; + + const half* B_ = &B[(n + 0) * (Kp / 2) + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * Kp / 2]))); + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * Kp / 2]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * Kp / 2]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * Kp / 2]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * Kp / 2]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * Kp / 2]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * Kp / 2]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * Kp / 2]))); } - // Fetch activation matrix from either just LDS or from both LDS / memory - #pragma unroll + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K/2) break; + if (k_ >= K / 2) break; for (int m = 0; m < M; m++) { // if (k_ + K * m < 32 * 1024) - bigA[m][k2] = *((const bigType*)(&(s[k_ + K/2 * m]))); + bigA[m][k2] = *((const bigType*)(&(s[k_ + K / 2 * m]))); // else // bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); } } - - // Do the matrix multiplication in interleaved manner - #pragma unroll + // Do the matrix multiplication in interleaved manner + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K/2) break; - float aV[A_CHUNK*2]; + if (k_ >= K / 2) break; + float aV[A_CHUNK * 2]; for (uint32_t m = 0; m < M; m++) { - for (int i=0; i= 2) - sum[m][1] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(bigA[m][k2].l[i/8], bigB1[k2].l[i/8], sum[m][1], 0, 0, 0); - } - } + for (int i = 0; i < A_CHUNK * 2; i += 8) { + sum[m][0] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bigA[m][k2].l[i / 8], bigB0[k2].l[i / 8], sum[m][0], 0, 0, 0); + if (YTILE >= 2) + sum[m][1] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bigA[m][k2].l[i / 8], bigB1[k2].l[i / 8], sum[m][1], 0, 0, 0); + } + } } } // Final reduction for (int m = 0; m < M; m++) { for (int y = 0; y < YTILE; y++) { - float accm0 = sum[m][y][0]; - float accm16 = sum[m][y][8]; - asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][1]), "v"(accm0)); - asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][9]), "v"(accm16)); - asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][2]), "v"(accm0)); - asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][10]), "v"(accm16)); - asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][3]), "v"(accm0)); - asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][11]), "v"(accm16)); - asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][4]), "v"(accm0)); - asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][12]), "v"(accm16)); - asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][5]), "v"(accm0)); - asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][13]), "v"(accm16)); - asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][6]), "v"(accm0)); - asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][14]), "v"(accm16)); - asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][7]), "v"(accm0)); - asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][15]), "v"(accm16)); - accm0 += __shfl(accm0, 36); - accm16 += __shfl(accm16,52); - sum[m][y][0] = accm0 + __shfl(accm16,16); + float accm0 = sum[m][y][0]; + float accm16 = sum[m][y][8]; + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][1]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][9]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][2]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][10]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][3]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][11]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][4]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][12]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][5]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][13]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][6]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][14]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][7]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][15]), "v"(accm16)); + accm0 += __shfl(accm0, 36); + accm16 += __shfl(accm16, 52); + sum[m][y][0] = accm0 + __shfl(accm16, 16); } } if (threadIdx.x == 0) { for (int m = 0; m < M; m++) { for (int y = 0; y < YTILE; y++) { - C[n + y + m * N] = __float2half(sum[m][y][0]*sA*sB); + C[n + y + m * N] = __float2half(sum[m][y][0] * sA * sB); } } } @@ -502,12 +531,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support template -__global__ void wvSpltKQ_hf_sml_(const int K, const int Kp, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const float* __restrict__ s_A, - const float* __restrict__ s_B, - const int _WvPrGrp, const int CuCount) { - UNREACHABLE_CODE +__global__ void wvSpltKQ_hf_sml_(const int K, const int Kp, const int N, + const DTYPE* B, const DTYPE* __restrict__ A, + DTYPE* C, const float* __restrict__ s_A, + const float* __restrict__ s_B, + const int _WvPrGrp, const int CuCount) { + UNREACHABLE_CODE } #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support @@ -515,40 +544,39 @@ __global__ void wvSpltKQ_hf_sml_(const int K, const int Kp, const int N, const D template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltKQ_hf_(const int K, const int Kp, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const float* __restrict__ s_A, - const float* __restrict__ s_B, - const int _WvPrGrp, const int CuCount) { + const DTYPE* __restrict__ A, DTYPE* C, + const float* __restrict__ s_A, const float* __restrict__ s_B, + const int _WvPrGrp, const int CuCount) { using half8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; - using intx2 = __attribute__((__vector_size__(2*sizeof(int)))) int; - using intx4 = __attribute__((__vector_size__(4*sizeof(int)))) int; + using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int; + using intx4 = __attribute__((__vector_size__(4 * sizeof(int)))) int; union bigType { - char f8[A_CHUNK*2]; + char f8[A_CHUNK * 2]; char2 c2[A_CHUNK]; DTYPE h[A_CHUNK]; float f[A_CHUNK / 2]; int i[A_CHUNK / 2]; - long l[A_CHUNK/ 4]; - intx4 l2[A_CHUNK/ 8]; + long l[A_CHUNK / 4]; + intx4 l2[A_CHUNK / 8]; half8 h8; }; __shared__ half s[1024 * 32]; - for (uint32_t k = 0; k < min(K/2 * M, 32 * 1024); + for (uint32_t k = 0; k < min(K / 2 * M, 32 * 1024); k += THRDS * WvPrGrp * A_CHUNK) { uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - if (k_in >= min(K/2 * M, 32 * 1024)) break; + if (k_in >= min(K / 2 * M, 32 * 1024)) break; *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); } __syncthreads(); if (threadIdx.y >= _WvPrGrp) return; - uint32_t n = (blockIdx.x * _WvPrGrp + (threadIdx.y%_WvPrGrp)) * YTILE; + uint32_t n = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; - using floatx16 = __attribute__((__vector_size__(16*sizeof(float)))) float; + using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float; floatx16 sum[M][YTILE]; float sA = *s_A; float sB = *s_B; @@ -568,85 +596,115 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) bigType bigB7[UNRL]; // Fetch the weight matrix from memory! - for (uint32_t k1 = 0; k1 < K/2; k1 += THRDS * A_CHUNK * UNRL) { - #pragma unroll + for (uint32_t k1 = 0; k1 < K / 2; k1 += THRDS * A_CHUNK * UNRL) { + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K/2) break; - - const half* B_ = &B[(n + 0) * (Kp/2) + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * Kp/2]))); - if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * Kp/2]))); - if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * Kp/2]))); - if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * Kp/2]))); - if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * Kp/2]))); - if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * Kp/2]))); - if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * Kp/2]))); - if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * Kp/2]))); + if (k_ >= K / 2) break; + + const half* B_ = &B[(n + 0) * (Kp / 2) + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * Kp / 2]))); + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * Kp / 2]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * Kp / 2]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * Kp / 2]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * Kp / 2]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * Kp / 2]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * Kp / 2]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * Kp / 2]))); } - // Fetch activation matrix from either just LDS or from both LDS / memory - #pragma unroll + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K/2) break; + if (k_ >= K / 2) break; for (int m = 0; m < M; m++) { - if (k_ + K/2 * m < 64 * 1024) - bigA[m][k2] = *((const bigType*)(&(s[k_ + K/2 * m]))); + if (k_ + K / 2 * m < 64 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[k_ + K / 2 * m]))); else - bigA[m][k2] = *((const bigType*)(&(A[k_ + K/2 * m]))); + bigA[m][k2] = *((const bigType*)(&(A[k_ + K / 2 * m]))); } } - // Do the matrix multiplication in interleaved manner - #pragma unroll + // Do the matrix multiplication in interleaved manner + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K/2) break; - float aV[A_CHUNK*2]; + if (k_ >= K / 2) break; + float aV[A_CHUNK * 2]; for (uint32_t m = 0; m < M; m++) { - for (int i=0; i= 2) - sum[m][1] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(bigA[m][k2].l[i/8], bigB1[k2].l[i/8], sum[m][1], 0, 0, 0); - } - } + for (int i = 0; i < A_CHUNK * 2; i += 8) { + sum[m][0] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bigA[m][k2].l[i / 8], bigB0[k2].l[i / 8], sum[m][0], 0, 0, 0); + if (YTILE >= 2) + sum[m][1] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bigA[m][k2].l[i / 8], bigB1[k2].l[i / 8], sum[m][1], 0, 0, 0); + } + } } } // Final reduction for (int m = 0; m < M; m++) { for (int y = 0; y < YTILE; y++) { - float accm0 = sum[m][y][0]; - float accm16 = sum[m][y][8]; - asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][1]), "v"(accm0)); - asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][9]), "v"(accm16)); - asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][2]), "v"(accm0)); - asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][10]), "v"(accm16)); - asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][3]), "v"(accm0)); - asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][11]), "v"(accm16)); - asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][4]), "v"(accm0)); - asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][12]), "v"(accm16)); - asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][5]), "v"(accm0)); - asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][13]), "v"(accm16)); - asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][6]), "v"(accm0)); - asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][14]), "v"(accm16)); - asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " : "=v"(accm0) : "0"(accm0), "v"(sum[m][y][7]), "v"(accm0)); - asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " : "=v"(accm16) : "0"(accm16), "v"(sum[m][y][15]), "v"(accm16)); - accm0 += __shfl(accm0, 36); - accm16 += __shfl(accm16,52); - sum[m][y][0] = accm0 + __shfl(accm16,16); + float accm0 = sum[m][y][0]; + float accm16 = sum[m][y][8]; + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][1]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][9]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][2]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][10]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][3]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][11]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][4]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][12]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][5]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][13]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][6]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][14]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][7]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][15]), "v"(accm16)); + accm0 += __shfl(accm0, 36); + accm16 += __shfl(accm16, 52); + sum[m][y][0] = accm0 + __shfl(accm16, 16); } } if (threadIdx.x == 0) { for (int m = 0; m < M; m++) { for (int y = 0; y < YTILE; y++) { - C[n + y + m * N] = __float2half(sum[m][y][0]*sA*sB); + C[n + y + m * N] = __float2half(sum[m][y][0] * sA * sB); } } } @@ -656,13 +714,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support template -__global__ void wvSpltKQ_hf_(const int K, const int Kp, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const float* __restrict__ s_A, - const float* __restrict__ s_B, - const int _WvPrGrp, const int CuCount) { - UNREACHABLE_CODE - } +__global__ void wvSpltKQ_hf_(const int K, const int Kp, const int N, + const DTYPE* B, const DTYPE* __restrict__ A, + DTYPE* C, const float* __restrict__ s_A, + const float* __restrict__ s_B, const int _WvPrGrp, + const int CuCount) { + UNREACHABLE_CODE +} #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support @@ -670,7 +728,8 @@ __global__ void wvSpltKQ_hf_(const int K, const int Kp, const int N, const DTYPE template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, const int CuCount) { + const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, + const int CuCount) { using half8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; union bigType { @@ -715,10 +774,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } __syncthreads(); - //int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + // int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); if (threadIdx.y >= _WvPrGrp) return; - - uint32_t n = (blockIdx.x * _WvPrGrp + (threadIdx.y%_WvPrGrp)) * YTILE; + + uint32_t n = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; float sum[M][YTILE]; @@ -774,7 +833,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // // TODO: Logic below will only work when K is multiple of 8 //---------------------------------------------------- - //for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { // Fetch the weight matrix from memory! #pragma unroll @@ -820,8 +879,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! #pragma unroll for (uint32_t m = 0; m < M; m++) { #pragma unroll @@ -917,8 +976,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) template __global__ void wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, - const int _WvPrGrp, - const int CuCount) { + const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support @@ -928,7 +986,8 @@ __global__ void wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltK_hf_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, const int CuCount) { + const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, + const int CuCount) { using half8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; union bigType { @@ -956,13 +1015,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) commitColumn[i] = 1; } - //---------------------------------------------------- // Indexing function into the column of weight matrix B // Algorithm does 64 lane k-splitting / wave and uses // WG ID and Thread ID to find the index. //---------------------------------------------------- - //int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + // int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; // Check whether there will be fragmenation! @@ -1105,8 +1163,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { asm("v_dot2c_f32_f16 %0, %2, %3" @@ -1211,7 +1269,8 @@ __global__ void wvSpltK_hf_(const int K, const int N, const DTYPE* B, template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, const int CuCount) { + const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, + const int CuCount) { using half8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; @@ -1240,7 +1299,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) commitColumn[i] = 1; } - //int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + // int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); if (threadIdx.y >= _WvPrGrp) return; //---------------------------------------------------- @@ -1383,7 +1442,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (n >= N) continue; #endif - // Fetch the weight matrix from memory! + // Fetch the weight matrix from memory! #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; @@ -1586,15 +1645,18 @@ void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ wvSpltK_hf_sml_<64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ - <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, CuCount); \ + <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ + CuCount); \ } else if (K_in * N_in <= 32 * 1024 * 1.2) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ wvSpltK_hf_<64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ - <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, CuCount); \ + <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ + CuCount); \ } else { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \ wvSpltK_hf_big_<64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \ - <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, CuCount); \ + <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ + CuCount); \ } \ } @@ -1623,10 +1685,9 @@ void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, } } -void wvSpltKQ_(void* in_a, void* in_b, void* out_c, - void* scale_a, void* scale_b, - const int M_in, const int K_in, const int Kp_in, const int N_in, cudaStream_t stream, - const int CuCount = 0) { +void wvSpltKQ_(void* in_a, void* in_b, void* out_c, void* scale_a, + void* scale_b, const int M_in, const int K_in, const int Kp_in, + const int N_in, cudaStream_t stream, const int CuCount = 0) { dim3 grid(CuCount); half* af4 = reinterpret_cast(in_a); const half* bf4 = reinterpret_cast(in_b); @@ -1634,19 +1695,21 @@ void wvSpltKQ_(void* in_a, void* in_b, void* out_c, auto* s_a = reinterpret_cast(scale_a); auto* s_b = reinterpret_cast(scale_b); -#define WVSPLTKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb,\ - _N) \ - { \ - dim3 block(64, _WvPrGrp); \ - if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ - int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ - wvSpltKQ_hf_sml_<64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \ - <<>>(K_in, Kp_in, M_in, af4, bf4, c, s_a, s_b, __wvPrGrp, CuCount); \ - } else { \ - int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ - wvSpltKQ_hf_<64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ - <<>>(K_in, Kp_in, M_in, af4, bf4, c, s_a, s_b, __wvPrGrp, CuCount); \ - } \ +#define WVSPLTKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ + _N) \ + { \ + dim3 block(64, _WvPrGrp); \ + if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ + wvSpltKQ_hf_sml_<64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \ + <<>>(K_in, Kp_in, M_in, af4, bf4, c, s_a, \ + s_b, __wvPrGrp, CuCount); \ + } else { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ + wvSpltKQ_hf_<64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ + <<>>(K_in, Kp_in, M_in, af4, bf4, c, s_a, \ + s_b, __wvPrGrp, CuCount); \ + } \ } switch (N_in) { From 0f8846c946db0f83cb21046d2c819b7218c45ba5 Mon Sep 17 00:00:00 2001 From: Hashem Hashemi Date: Sat, 8 Feb 2025 04:51:46 +0000 Subject: [PATCH 3/8] Add support for bf16/fp16 --- csrc/rocm/custom.cu | 18 ++++----- csrc/rocm/custom_kernels.cu | 39 ++++++++++++------- csrc/rocm/ops.h | 4 +- csrc/rocm/torch_bindings.cpp | 4 +- vllm/_custom_ops.py | 9 +++-- .../layers/quantization/utils/w8a8_utils.py | 39 +++++++++---------- vllm/model_executor/layers/tuned_gemm.py | 21 +++++++--- 7 files changed, 78 insertions(+), 56 deletions(-) diff --git a/csrc/rocm/custom.cu b/csrc/rocm/custom.cu index 5b6d0063676f..c799dd273dae 100644 --- a/csrc/rocm/custom.cu +++ b/csrc/rocm/custom.cu @@ -48,22 +48,22 @@ void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, at::cuda::getCurrentCUDAStream(), CuCount); } -void wvSpltKQ_(void* in_a, void* in_b, void* out_c, - void* scale_a, void* scale_b, - const int M, const int K, const int Kp, - const int N, cudaStream_t stream, const int CuCount); +void wvSpltKQ_(void* in_a, void* in_b, void* out_c, void* scale_a, + void* scale_b, const int M, const int K, const int Kp, + const int N, const int Otp_in, cudaStream_t stream, + const int CuCount); void wvSpltKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - at::Tensor& scale_a, at::Tensor& scale_b, - const int64_t N_in, const int64_t CuCount) { + at::Tensor& scale_a, at::Tensor& scale_b, const int64_t N_in, + const int64_t Otp_in, const int64_t CuCount) { auto M = in_a.size(0); auto K = in_a.size(1); auto Kp = in_a.stride(0); int N = N_in; + int Otp = Otp_in; wvSpltKQ_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), - scale_a.data_ptr(), scale_b.data_ptr(), - M, K, Kp, N, - at::cuda::getCurrentCUDAStream(), CuCount); + scale_a.data_ptr(), scale_b.data_ptr(), M, K, Kp, N, Otp, + at::cuda::getCurrentCUDAStream(), CuCount); } void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, diff --git a/csrc/rocm/custom_kernels.cu b/csrc/rocm/custom_kernels.cu index 7ca43c4044b5..d130461b27e2 100644 --- a/csrc/rocm/custom_kernels.cu +++ b/csrc/rocm/custom_kernels.cu @@ -1,5 +1,6 @@ #include #include +#include #include #include #include "cuda_compat.h" @@ -363,7 +364,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) const DTYPE* __restrict__ A, DTYPE* C, const float* __restrict__ s_A, const float* __restrict__ s_B, const int _WvPrGrp, - const int CuCount) { + const int Otp, const int CuCount) { using half8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int; @@ -521,7 +522,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 0) { for (int m = 0; m < M; m++) { for (int y = 0; y < YTILE; y++) { - C[n + y + m * N] = __float2half(sum[m][y][0] * sA * sB); + if (Otp == 0) // fp16 + C[n + y + m * N] = __float2half(sum[m][y][0] * sA * sB); + else // if (Otp == 1) //bf16 + *reinterpret_cast<__hip_bfloat16*>(&C[n + y + m * N]) = + __float2bfloat16(sum[m][y][0] * sA * sB); } } } @@ -535,7 +540,8 @@ __global__ void wvSpltKQ_hf_sml_(const int K, const int Kp, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const float* __restrict__ s_A, const float* __restrict__ s_B, - const int _WvPrGrp, const int CuCount) { + const int _WvPrGrp, const int Otp, + const int CuCount) { UNREACHABLE_CODE } #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support @@ -546,7 +552,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltKQ_hf_(const int K, const int Kp, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const float* __restrict__ s_A, const float* __restrict__ s_B, - const int _WvPrGrp, const int CuCount) { + const int _WvPrGrp, const int Otp, const int CuCount) { using half8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int; @@ -704,7 +710,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 0) { for (int m = 0; m < M; m++) { for (int y = 0; y < YTILE; y++) { - C[n + y + m * N] = __float2half(sum[m][y][0] * sA * sB); + if (Otp == 0) // fp16 + C[n + y + m * N] = __float2half(sum[m][y][0] * sA * sB); + else // if (Otp == 12) //bf16 + *reinterpret_cast<__hip_bfloat16*>(&C[n + y + m * N]) = + __float2bfloat16(sum[m][y][0] * sA * sB); } } } @@ -718,7 +728,7 @@ __global__ void wvSpltKQ_hf_(const int K, const int Kp, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const float* __restrict__ s_A, const float* __restrict__ s_B, const int _WvPrGrp, - const int CuCount) { + const int Otp, const int CuCount) { UNREACHABLE_CODE } #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support @@ -879,8 +889,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! #pragma unroll for (uint32_t m = 0; m < M; m++) { #pragma unroll @@ -1163,8 +1173,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { asm("v_dot2c_f32_f16 %0, %2, %3" @@ -1442,7 +1452,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (n >= N) continue; #endif - // Fetch the weight matrix from memory! + // Fetch the weight matrix from memory! #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; @@ -1687,7 +1697,8 @@ void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, void wvSpltKQ_(void* in_a, void* in_b, void* out_c, void* scale_a, void* scale_b, const int M_in, const int K_in, const int Kp_in, - const int N_in, cudaStream_t stream, const int CuCount = 0) { + const int N_in, const int Otp_in, cudaStream_t stream, + const int CuCount = 0) { dim3 grid(CuCount); half* af4 = reinterpret_cast(in_a); const half* bf4 = reinterpret_cast(in_b); @@ -1703,12 +1714,12 @@ void wvSpltKQ_(void* in_a, void* in_b, void* out_c, void* scale_a, int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ wvSpltKQ_hf_sml_<64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \ <<>>(K_in, Kp_in, M_in, af4, bf4, c, s_a, \ - s_b, __wvPrGrp, CuCount); \ + s_b, __wvPrGrp, Otp_in, CuCount); \ } else { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ wvSpltKQ_hf_<64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ <<>>(K_in, Kp_in, M_in, af4, bf4, c, s_a, \ - s_b, __wvPrGrp, CuCount); \ + s_b, __wvPrGrp, Otp_in, CuCount); \ } \ } diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index b20d404e39d1..0701b5df3f07 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -12,8 +12,8 @@ void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, const int64_t N_in, const int64_t CuCount); void wvSpltKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - at::Tensor& scale_a, at::Tensor& scale_b, - const int64_t N_in, const int64_t CuCount); + at::Tensor& scale_a, at::Tensor& scale_b, const int64_t N_in, + const int64_t Otp_in, const int64_t CuCount); void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index 2c04ac1b89c0..4ca24b7c19e7 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -44,8 +44,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { " int CuCount) -> ()"); rocm_ops.impl("wvSpltK", torch::kCUDA, &wvSpltK); rocm_ops.def( - "wvSpltKQ(Tensor in_a, Tensor in_b, Tensor! out_c, Tensor scale_a, Tensor scale_b," + "wvSpltKQ(Tensor in_a, Tensor in_b, Tensor! out_c, Tensor scale_a, " + "Tensor scale_b," " int N_in," + " int Otp_in," " int CuCount) -> ()"); rocm_ops.impl("wvSpltKQ", torch::kCUDA, &wvSpltKQ); } diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 15b7aed1d78f..182f3d74eb1d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1143,7 +1143,8 @@ def wvSpltK(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, N: int, cu_count: int) -> None: torch.ops._rocm_C.wvSpltK(a, b, out, N, cu_count) -def wvSpltKQ(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, - scale_a: torch.Tensor, scale_b: torch.Tensor, N: int, - cu_count: int) -> None: - torch.ops._rocm_C.wvSpltKQ(a, b, out, scale_a, scale_b, N, cu_count) + +def wvSpltKQ(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, + scale_a: torch.Tensor, scale_b: torch.Tensor, N: int, Otp: int, + cu_count: int) -> None: + torch.ops._rocm_C.wvSpltKQ(a, b, out, scale_a, scale_b, N, Otp, cu_count) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 1f4084e263fd..ca59ff88223b 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -5,8 +5,8 @@ import torch from vllm import _custom_ops as ops -from vllm.platforms import current_platform from vllm.model_executor.layers.tuned_gemm import tgemm +from vllm.platforms import current_platform # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale @@ -163,22 +163,21 @@ def apply_fp8_linear( # Fused GEMM_DQ output2 = None n = qinput.shape[0] - if n == 1 : - weightT = weight.t() - output = tgemm.scaled_mm(qinput, - weightT, - out_dtype=out_dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias) - else : - output = torch._scaled_mm(qinput, - weight, - out_dtype=out_dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias) - + if n == 1: + weightT = weight.t() + output = tgemm.scaled_mm(qinput, + weightT, + out_dtype=out_dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias) + else: + output = torch._scaled_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias) ''' atol = 1e-2; rtol = 1e-5 if output2!=None : @@ -189,9 +188,9 @@ def apply_fp8_linear( ''' # A fix for discrepancy in scaled_mm which returns tuple # for torch < 2.5 and a single value in torch >= 2.5 - if (weight.shape[0] == 1) : - if type(output) is tuple and len(output) == 2: - output = output[0] + if (weight.shape[0] == 1): + if type(output) is tuple and len(output) == 2: + output = output[0] return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index d9ea11979503..4f58dada1c70 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -5,7 +5,6 @@ import pandas as pd import torch import torch.nn.functional as F -import traceback from vllm import _custom_ops as ops from vllm.envs import VLLM_USE_ROCM_SKINNY_GEMM @@ -92,13 +91,23 @@ def apply_skinny(self, m, n, k, inp_view, weights): else: return None - def scaled_mm(self, inp, weight, out_dtype=None, scale_a=None, scale_b=None, bias=None): + def scaled_mm(self, + inp, + weight, + out_dtype=None, + scale_a=None, + scale_b=None, + bias=None): out = torch.empty(inp.shape[0], - weight.shape[0], - dtype=out_dtype, - device='cuda') + weight.shape[0], + dtype=out_dtype, + device='cuda') n = inp.shape[0] - ops.wvSpltKQ(weight, inp, out, scale_a, scale_b, n, self.cu_count) + + Otp = 1 #default bfloat16 + if out_dtype == torch.float16: + Otp = 0 + ops.wvSpltKQ(weight, inp, out, scale_a, scale_b, n, Otp, self.cu_count) return out def mm(self, inp, weights, bias=None): From fdb983ee0cf707989b177a60b29e59238091228e Mon Sep 17 00:00:00 2001 From: Hashem Hashemi Date: Sat, 8 Feb 2025 05:01:53 +0000 Subject: [PATCH 4/8] code cleanup --- vllm/model_executor/layers/quantization/utils/w8a8_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index ca59ff88223b..4708fc0c2a30 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -161,7 +161,7 @@ def apply_fp8_linear( if per_tensor_weights and per_tensor_activations: # Fused GEMM_DQ - output2 = None + # output2 = None n = qinput.shape[0] if n == 1: weightT = weight.t() From 282b2366a463f0ed21b1e07681e8c65b3196a873 Mon Sep 17 00:00:00 2001 From: Hashem Hashemi Date: Sat, 8 Feb 2025 05:05:29 +0000 Subject: [PATCH 5/8] code cleanup --- vllm/model_executor/layers/quantization/utils/w8a8_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 4708fc0c2a30..cb1460d97006 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -188,9 +188,9 @@ def apply_fp8_linear( ''' # A fix for discrepancy in scaled_mm which returns tuple # for torch < 2.5 and a single value in torch >= 2.5 - if (weight.shape[0] == 1): - if type(output) is tuple and len(output) == 2: - output = output[0] + if (weight.shape[0] == 1) and (type(output) + is tuple) and (len(output) == 2): + output = output[0] return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) From a7e7664c7989503393259590402f101007d9e997 Mon Sep 17 00:00:00 2001 From: Hashem Hashemi Date: Wed, 12 Feb 2025 19:18:37 +0000 Subject: [PATCH 6/8] lint fix2 --- vllm/model_executor/layers/quantization/utils/w8a8_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index cb1460d97006..eb7571d9493d 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -182,7 +182,7 @@ def apply_fp8_linear( atol = 1e-2; rtol = 1e-5 if output2!=None : if not torch.allclose(output,output2,atol=atol,rtol=rtol) : - print("MISSMATCH") + print("MISMATCH") print("orig", output) print("opt", output2) ''' From 13e5f0b5e0dd49bb8340d20eaf667f44ed7468b8 Mon Sep 17 00:00:00 2001 From: Hashem Hashemi Date: Tue, 18 Feb 2025 05:26:54 +0000 Subject: [PATCH 7/8] cleanup --- vllm/model_executor/layers/quantization/fp8.py | 1 - .../layers/quantization/utils/w8a8_utils.py | 10 +--------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index db55304a8afd..b01269597339 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -368,7 +368,6 @@ def process_weights_after_loading(self, layer: Module) -> None: weight = self.add_padding_to_weight(weight) # Update layer with new values. layer.weight = Parameter(weight.t(), requires_grad=False) - #layer.weight = Parameter(weight, requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False) if self.quant_config.activation_scheme == "static": layer.input_scale = Parameter(layer.input_scale.max(), diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 9c303f6cdc61..62e3e0da2e81 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -172,7 +172,6 @@ def apply_fp8_linear( if per_tensor_weights and per_tensor_activations: # Fused GEMM_DQ - # output2 = None n = qinput.shape[0] if n == 1: weightT = weight.t() @@ -189,14 +188,7 @@ def apply_fp8_linear( scale_a=x_scale, scale_b=weight_scale, bias=bias) - ''' - atol = 1e-2; rtol = 1e-5 - if output2!=None : - if not torch.allclose(output,output2,atol=atol,rtol=rtol) : - print("MISMATCH") - print("orig", output) - print("opt", output2) - ''' + # A fix for discrepancy in scaled_mm which returns tuple # for torch < 2.5 and a single value in torch >= 2.5 if (weight.shape[0] == 1) and (type(output) From 8e768ac43fed743b851e49c5198b276baf37af77 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Tue, 18 Feb 2025 23:25:47 +0000 Subject: [PATCH 8/8] Moved the logic into tuned gemm to preserve API compatibility --- .../layers/quantization/utils/w8a8_utils.py | 27 +++++----------- vllm/model_executor/layers/tuned_gemm.py | 32 +++++++++++++------ 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 62e3e0da2e81..12a5aac9b8d1 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -135,6 +135,7 @@ def apply_fp8_linear( if out_dtype is None: out_dtype = input.dtype + # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A if cutlass_fp8_supported: qinput, x_scale = ops.scaled_fp8_quant( @@ -172,27 +173,15 @@ def apply_fp8_linear( if per_tensor_weights and per_tensor_activations: # Fused GEMM_DQ - n = qinput.shape[0] - if n == 1: - weightT = weight.t() - output = tgemm.scaled_mm(qinput, - weightT, - out_dtype=out_dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias) - else: - output = torch._scaled_mm(qinput, - weight, - out_dtype=out_dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias) - + output = tgemm.scaled_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias) # A fix for discrepancy in scaled_mm which returns tuple # for torch < 2.5 and a single value in torch >= 2.5 - if (weight.shape[0] == 1) and (type(output) - is tuple) and (len(output) == 2): + if type(output) is tuple and len(output) == 2: output = output[0] return torch.narrow(output, 0, 0, diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index 4f58dada1c70..cf3caebf3201 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import os from pathlib import Path +from typing import Optional import pandas as pd import torch @@ -91,23 +92,34 @@ def apply_skinny(self, m, n, k, inp_view, weights): else: return None - def scaled_mm(self, - inp, - weight, - out_dtype=None, - scale_a=None, - scale_b=None, - bias=None): + def scaled_mm( + self, + inp: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: Optional[torch.Tensor], + ) -> torch.Tensor: + n = inp.shape[0] + if n != 1: + return torch._scaled_mm(inp, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias) + weightT = weight.t() out = torch.empty(inp.shape[0], - weight.shape[0], + weightT.shape[0], dtype=out_dtype, device='cuda') - n = inp.shape[0] Otp = 1 #default bfloat16 if out_dtype == torch.float16: Otp = 0 - ops.wvSpltKQ(weight, inp, out, scale_a, scale_b, n, Otp, self.cu_count) + ops.wvSpltKQ(weightT, inp, out, scale_a, scale_b, n, Otp, + self.cu_count) return out def mm(self, inp, weights, bias=None):